精读笔记 · Ring Attention with Blockwise Transformers for Near-Infinite Context

论文信息
- 标题:Ring Attention with Blockwise Transformers for Near-Infinite Context
- arXiv:2310.01889
- 机构:UC Berkeley
- 发表:ICLR 2024
- 作者:Hao Liu, Matei Zaharia, Pieter Abbeel
- 原文链接:https://arxiv.org/abs/2310.01889


阅读地图

本文解决一个非常实际的工程问题:单张 GPU 显存不够,装不下超长序列的注意力计算,怎么办?

阅读顺序建议:
1. 先看「背景问题」,建立直觉
2. 再看「Abstract 精译」,了解论文在说什么
3. 再看「Introduction 精译」,理解动机
4. 重点看「方法精译」,这是核心贡献
5. 最后看「实验精译」,验证效果


背景问题:先用大白话说清楚

问题 1:注意力机制的显存杀手

Transformer 的核心是自注意力(Self-Attention)。给定一段序列,每个 token 要跟所有其他 token 计算"相关性"。序列长度是 s,这个计算量和显存占用是 O(s²)——也就是说,序列翻倍,显存要翻四倍。

类比:假设你在开一场圆桌会议。10 个人时,每两人之间都要握手,共 45 次。100 个人时,握手次数变成 4950 次。人数翻 10 倍,握手次数翻 100 倍。这就是注意力的二次方代价。

问题 2:一张 GPU 不够用

现代 GPU(如 A100)的高带宽显存(HBM)通常在 80GB 左右。论文中提到:

"To put the memory demand in perspective, even when dealing with a batch size of 1, processing 100 million tokens requires over 1000GB of memory for a modest model with a hidden size of 1024."

处理 1 亿个 token,哪怕模型很小(隐藏层维度仅 1024),也需要超过 1000GB 显存——比一张 A100 大 12 倍以上。

问题 3:多卡方案的困境

能不能用多张 GPU 分担?理论上可以,但以前的办法都有缺陷:
- 张量并行(Tensor Parallelism):只能减少部分显存,不能按需扩展
- 序列并行(Sequence Parallelism):引入大量通信开销,通信和计算无法同时进行,白白等待

Ring Attention 的解法(核心直觉)

把序列切成块,每块放一张 GPU;GPU 排成一个圆环,像击鼓传花一样轮流传递 KV 块;传数据和算注意力同时进行,通信时间被"藏起来"了。


Abstract 精译

原文

"Transformers have emerged as the architecture of choice for many state-of-the-art AI models, showcasing exceptional performance across a wide range of AI applications. However, the memory demands imposed by Transformers limit their ability to handle long sequences, thereby posing challenges in utilizing videos, actions, and other long-form sequences and modalities in complex environments. We present a novel approach, Ring Attention with Blockwise Transformers (Ring Attention), which leverages blockwise computation of self-attention and feedforward to distribute long sequences across multiple devices while fully overlapping the communication of key-value blocks with the computation of blockwise attention. Our approach enables training and inference of sequences that are up to device count times longer than those achievable by prior memory-efficient Transformers, without resorting to approximations or incurring additional communication and computation overheads. Extensive experiments on language modeling and reinforcement learning tasks demonstrate the effectiveness of our approach in allowing millions of tokens context size and improving performance."

翻译

Transformer 已成为许多最先进 AI 模型的首选架构,在各类 AI 应用中展现出卓越性能。然而,Transformer 对显存的高需求限制了其处理长序列的能力,使得在复杂环境中利用视频、动作及其他长序列和多模态数据变得困难。我们提出了一种新方法——Ring Attention with Blockwise Transformers(Ring Attention,环形注意力),它利用自注意力和前馈网络的分块计算,将长序列分布到多个设备上,同时将 key-value 块的通信与分块注意力计算完全重叠。我们的方法使得序列训练和推理的长度可达先前内存高效 Transformer 的设备数倍,无需任何近似,也不引入额外的通信和计算开销。在语言建模和强化学习任务上的大量实验证明了该方法的有效性,实现了数百万 token 的上下文规模并提升了性能。

新手讲解

这段话有几个关键词,逐一拆解:


Introduction 精译

核心段落 1:问题的规模

原文(关键句):

"The self-attention has memory cost quadratic in the input sequence length, which makes it challenging to scale to longer input sequences."

翻译:

自注意力的显存开销与输入序列长度成二次方关系,这使得扩展到更长输入序列极具挑战。

新手讲解:
二次方(quadratic)意味着:序列从 1K token 变成 4K token(增加 4 倍),显存需求变成 16 倍。这不是线性增长,而是爆炸式增长。这就是为什么 GPT-2 只能处理 1K token,而让模型处理 100K token 非常困难。


核心段落 2:量化问题的严峻性

原文(关键句):

"To put the memory demand in perspective, even when dealing with a batch size of 1, processing 100 million tokens requires over 1000GB of memory for a modest model with a hidden size of 1024."

翻译:

为了直观感受显存需求:即使批大小为 1,对于一个隐藏层维度仅为 1024 的小模型,处理 1 亿个 token 也需要超过 1000GB 显存。

新手讲解:
典型的高端 GPU(A100)显存是 80GB。1000GB ÷ 80GB ≈ 12.5 张 A100。而实际的大语言模型(如 LLaMA-70B)的隐藏层远大于 1024,需求更高。100 万 token 就已经是个工程难题,更别说 1 亿 token 了。


核心段落 3:为什么已有方案不够好

原文(关键句):

"While BPT significantly reduces memory demand in Transformers, it still presents a major challenge for scaling up context length because it requires storing the output of each layer."

"Tensor parallelism can only reduce parts of activations memory and sequence parallelism introduces a significant communication overhead that cannot be fully overlapped with computation."

翻译:

尽管 BPT(分块并行 Transformer)大幅降低了 Transformer 的显存需求,但它仍然面临扩展上下文长度的重大挑战,因为它需要存储每一层的输出。

张量并行只能减少部分激活值显存,而序列并行引入了大量通信开销,这些开销无法与计算完全重叠。

新手讲解:
- BPT(Blockwise Parallel Transformer):Ring Attention 的前身,把注意力计算分块处理,节省了很多显存。但问题是,每层的输出(激活值)仍然要保存,而这个大小和序列长度成正比。所以序列越长,BPT 也撑不住。
- 张量并行:把模型的参数矩阵拆开放到不同 GPU,但这只减少了模型参数的存储,不能有效减少长序列带来的激活值开销。
- 序列并行:把序列切开放到不同 GPU,但各 GPU 之间需要大量通信,而且通信和计算是串行的(先通信,再计算),效率低。


核心段落 4:Ring Attention 的承诺

原文(关键句):

"Ring Attention enables training and inference of sequences that are up to device count times longer than those achievable by prior memory-efficient Transformers, without resorting to approximations or incurring additional communication and computation overheads."

翻译:

Ring Attention 使序列训练和推理的长度可达先前内存高效 Transformer 的设备数倍,无需任何近似,也不引入额外的通信和计算开销。

新手讲解:
三个关键承诺:
1. 设备数倍长度:32 张 GPU → 序列长度扩大 32 倍,线性扩展
2. 无近似:不是用稀疏注意力、局部注意力等"偷工减料"的方式,而是精确计算完整的全局注意力
3. 零额外开销:通信的时间被计算"盖住了",不额外增加耗时


方法精译(核心,一段不漏)

2. 背景:分块并行 Transformer(BPT)

2.1 显存瓶颈

原文(关键句):

"Prior state-of-the-arts have led to substantial reductions in memory utilization, achieved through innovative techniques that enable attention computation without full materialization by computing attention in a block by block manner. These advancements lowered the memory overhead of attention to 2bsh bytes per layer, where b represents the batch size, s denotes the sequence length, and h stands for the hidden size of the model."

翻译:

先前的工作通过不完整实例化(不一次性把完整注意力矩阵写入显存),以逐块方式计算注意力,大幅降低了显存占用。这些进展将注意力层的显存开销降低到每层 2bsh 字节(b=批大小,s=序列长度,h=隐藏维度)。

新手讲解:
这里提到的"逐块计算注意力"就是 FlashAttention 的思想:把 Q、K、V 矩阵切成小块,在 GPU 的 SRAM(速度快但容量小的片上缓存)里逐块计算,避免把巨大的注意力矩阵(s×s)写入 HBM(速度慢但容量大的显存)。

关键术语解释:
- HBM(High Bandwidth Memory):GPU 上的主显存,容量大(80GB),但比片上 SRAM 慢
- SRAM(Static RAM):GPU 片上高速缓存,速度极快,但容量极小(通常几十 MB)
- Materialization(实例化):把中间结果写入显存。"不完整实例化"就是不把注意力矩阵完整写出来,而是边算边丢弃


2.2 BPT 的进一步优化

原文(关键句):

"To further reduce memory usage, blockwise parallel transformer (BPT) introduced a strategy where the feedforward network associated with each self-attention sub-layer is computed in a block-wise fashion. This approach effectively limits the maximum activation size of feedforward network from 8bsh to 2bsh."

翻译:

为进一步降低显存,BPT 提出了一种策略:将与每个自注意力子层对应的前馈网络也以分块方式计算。这将前馈网络的最大激活值大小从 8bsh 压缩到 2bsh。

新手讲解:
BPT 不光把注意力分块,连后面的前馈层(FFN,就是 Transformer 里的 MLP)也分块算。这样每次只在内存里保留一小块激活值,极大地减少了峰值显存。

但正如 Introduction 里说的,BPT 有一个致命缺陷:每层的输出仍然要存下来(因为反向传播需要),这个存储量依然随序列长度线性增长。序列超长,BPT 也撑不住。


3. Ring Attention:核心方法

这是论文最重要的部分。

3.1 核心思想:注意力计算的顺序无关性

原文(关键句):

"We leverage this property by conceptualizing all hosts as forming a ring structure... the self-attention between a query block and a group of key-value blocks can be computed in any order, as long as the statistics of each block are combined correctly for rescaling."

翻译:

我们利用这一性质,将所有主机(设备)概念化为构成一个环形结构……一个查询块与一组 key-value 块之间的自注意力,可以以任意顺序计算,只要对每个块的统计量(用于缩放)正确合并即可。

新手讲解:
这是 Ring Attention 的理论基础。正常的注意力计算需要把所有 K、V 都看完,才能算出每个 Q 的输出。但 FlashAttention 的 online softmax 技术证明:可以先看一部分 KV,算出"局部结果",再看另一部分 KV,把两个局部结果"合并"成全局结果,而且这个合并可以以任意顺序进行。

类比:你要评选"今天最好的菜"。不需要所有菜同时摆在桌上——可以先比较 A 组的菜选出冠军,再拿这个冠军和 B 组的菜比,最终得到总冠军。顺序不影响最终结果(只要合并逻辑正确)。

这个"合并逻辑"就是 FlashAttention 的 online softmax 重缩放(rescaling)。Ring Attention 正是站在 FlashAttention 的肩膀上,把这一性质扩展到多设备场景。


3.2 环形拓扑结构

原文(关键句):

"Host devices form a conceptual ring, where during the inner loop, each device sends a copy of its key-value blocks being used for blockwise computation to the next device in the ring, while simultaneously receiving key-value blocks from the previous one."

翻译:

主机(计算设备)形成一个概念上的环形结构。在内层循环中,每台设备将当前用于分块计算的 key-value 块复制并发送给环上的下一台设备,同时从上一台设备接收 key-value 块。

新手讲解:

这就是"击鼓传花"机制,用"一圈厨师炒菜"来类比:

想象一个圆形厨房,8 位厨师(对应 8 张 GPU)围成一圈,每人手边有一批食材(对应一个序列块的 Q)。

此外有 8 份"酱料"(对应 KV 块)在厨师之间传递。

每位厨师:
1. 用手边的食材(Q)和当前拿到的酱料(KV)炒一道菜(计算注意力)
2. 同时把这份酱料传给右边的厨师
3. 从左边厨师那里接过下一份酱料

转一圈之后,每位厨师都用 8 份酱料各炒了一次,把结果合并,就得到了完整的注意力输出。

关键是:炒菜(计算)和传酱料(通信)是同时进行的,没有人需要等待!


3.3 通信与计算重叠的条件

原文(关键句):

"To achieve an overlap between communication and computation, the following condition must hold: 4dc²/F ≥ 4cd/B. This implies that the block size, denoted as c, should be greater than or equal to F/B."

翻译:

要实现通信与计算的重叠,必须满足以下条件:4dc²/F ≥ 4cd/B。这意味着块大小 c 必须大于或等于 F/B(设备浮点运算速度 F 与互联带宽 B 之比)。

新手讲解:
这个不等式的本质是:计算 KV 块的时间 ≥ 传输 KV 块的时间

只要块大小 c 足够大(c ≥ F/B),算的时间就比传的时间长,传输就能"躲在"计算的阴影里,真正实现零通信开销。

类比:厨师炒菜(计算)需要 10 分钟,把酱料传给下一位(通信)只需要 1 分钟。那么在厨师炒菜的 10 分钟里,酱料已经传好了,完全不影响流程。

在实际硬件(如 A100 + NVLink 或 TPU + ICI)上,这个条件通常很容易满足,因为现代 GPU 算力极高,只要块不是太小,计算时间总是远大于传输时间。


3.4 显存分析:终于与序列长度无关

原文(关键句):

"A host needs to store multiple blocks...Therefore, a total of six blocks are required, which translates to 6bch bytes of memory."

翻译:

每台主机需要存储多个块……因此共需要 6 个块,即 6bch 字节显存。

新手讲解:

这是 Ring Attention 最重要的理论结果之一:

方法 显存占用(每层) 随序列长度变化?
原始 Transformer O(bs²) 是(二次方!)
FlashAttention / BPT O(bsh)(需存每层输出) 是(线性)
Ring Attention O(bc·h)(只有 6 个块) 否!

其中 c 是块大小,是固定的超参数,不随序列总长度 s 增长。

含义:序列从 1K 变到 100K,每台 GPU 的显存需求几乎不变——多出来的序列分摊到更多 GPU 上了。


3.5 算法描述(Algorithm 1)

原文(关键描述):

"Split input sequence into Nh blocks that each host has one input block. Compute query, key, and value for its input block on each host. For each transformer layer... Compute memory efficient attention incrementally using local query, key, value blocks. Send key and value blocks to next host and receive key and value blocks from previous host."

翻译:

将输入序列分成 Nh 块,每台主机持有一个输入块。每台主机在其输入块上计算 query、key、value。对于每个 Transformer 层:……用本地 Q、K、V 块增量地计算内存高效的注意力。将 key、value 块发送给下一台主机,同时从上一台主机接收 key、value 块。

步骤拆解(以 4 台 GPU 为例,序列被切成 4 块):

初始状态:
GPU-0 持有 Q0, K0, V0
GPU-1 持有 Q1, K1, V1
GPU-2 持有 Q2, K2, V2
GPU-3 持有 Q3, K3, V3

第 1 轮(每台 GPU 算自己的 KV 块):
GPU-0:用 Q0 和 K0V0 算注意力 → 同时把 K0V0 传给 GPU-1
GPU-1:用 Q1 和 K1V1 算注意力 → 同时把 K1V1 传给 GPU-2
GPU-2:用 Q2 和 K2V2 算注意力 → 同时把 K2V2 传给 GPU-3
GPU-3:用 Q3 和 K3V3 算注意力 → 同时把 K3V3 传给 GPU-0

第 2 轮(每台 GPU 拿到上一台的 KV 块):
GPU-0:用 Q0 和 K3V3 算注意力(合并到第 1 轮结果)→ 同时继续传
...

第 4 轮结束后:
每台 GPU 都用自己的 Q0/Q1/Q2/Q3 看过了 K0V0、K1V1、K2V2、K3V3 全部 4 块
→ 得到完整的注意力输出

这就是"环转一圈"的完整流程。


3.6 序列长度扩展公式

原文(关键句):

"If a model can be trained with context size s on n GPUs using the blockwise attention and feedforward, with our Ring Attention approach, it becomes possible to train a model with a context size of n·s."

翻译:

如果一个模型使用分块注意力和前馈在 n 个 GPU 上能训练上下文长度为 s 的序列,那么使用 Ring Attention,就可以训练上下文长度为 n·s 的模型。

新手讲解:
这是最直观的效果描述。单卡能跑 32K token,8 卡就能跑 256K,32 卡就能跑 1M,512 卡理论上能跑 16M+。这就是"近乎无限上下文"名称的由来。


实验精译(关键结果)

4.1 上下文长度随设备数线性增长

原文(关键句):

"Ring Attention enables training up to device count times longer sequence than baselines."

翻译:

Ring Attention 使训练序列长度最高可达基线方法的设备数倍。

关键实验数据(Table 3):

硬件配置 模型规模 可训练上下文长度 相比基线提升
8× A100 NVLink 7B 256K token
32× A100 InfiniBand 7B 4,096K(约 400 万)token 32×
TPUv3-512(512 个 TPU) 7B 2,048K(约 200 万)token 256×
TPUv4-1024(1024 个 TPU) 7B 8,192K(约 800 万)token 512×

新手讲解:
设备翻倍,上下文长度几乎精确翻倍。这是真正的线性扩展。论文还提到,在足够多的 TPU 上,可以突破 1 亿 token——这在之前是不可想象的。


4.2 MFU:硬件利用率不打折

原文(关键句):

"Ring Attention trains much longer context sizes for self-attention, resulting in higher self-attention FLOPs compared to baseline models. We can train very large context models without compromising MFU or throughput."

翻译:

Ring Attention 的自注意力训练了远更长的上下文,导致更高的自注意力 FLOPs(浮点运算量)。我们可以训练超大上下文模型,而不损失 MFU(模型浮点利用率)或吞吐量。

新手讲解:
MFU(Model FLOPS Utilization)衡量 GPU 的实际算力利用率。如果通信开销大,GPU 就在等待数据,MFU 就会很低。Ring Attention 由于通信完全被计算覆盖,MFU 与基线相比没有明显下降,说明硬件没有被浪费在等待上。


4.3 语言模型效果:长上下文确实有用

原文(关键句):

"Ring Attention-13B-512K stands out as it maintains high accuracy levels even with long contexts, outperforming GPT3.5-turbo-16K and Vicuna-16B-16K at extended sequence lengths."

翻译:

Ring Attention-13B-512K 表现突出,即使在长上下文下也能保持高准确率,在长序列任务上超越了 GPT3.5-turbo-16K 和 Vicuna-16B-16K。

新手讲解:
论文用 Ring Attention 把 LLaMA-13B 模型的上下文从 4K 微调扩展到 512K。在"行检索任务"(needle-in-a-haystack,在长文本中找特定信息)上,512K 版本明显优于上下文只有 16K 的 GPT-3.5 和 Vicuna。这证明了更长的上下文不只是工程上的成就,在实际任务中也带来真实的性能提升。


4.4 强化学习:长轨迹带来更好的决策

原文(关键句):

"AT + Ring Attention consistently outperforms original AT with BPT across all six tasks, achieving a total average return of 113.66 compared to the AT with BPT model's total average return of 111.13."

翻译:

AT + Ring Attention 在全部 6 个任务上均优于原始的 AT + BPT 方法,总平均回报为 113.66,而 AT + BPT 的总平均回报为 111.13。

新手讲解:
论文还把 Ring Attention 用在强化学习(RL)上。强化学习的"序列"是智能体的历史行动轨迹,轨迹越长,智能体能看到的历史越多,决策越好。Ring Attention 让 RL 智能体能处理更长的历史,从而做出更好的决策,在 6 个连续控制任务上均有提升。


方法总结:体系结构图

长序列(如 1M token)
         |
         v
   切分成 N 块(每块 s/N 个 token)
         |
    分配到 N 张 GPU
         |
    ┌────────────────────────────┐
    │   GPU-0    GPU-1    GPU-2    GPU-3   │
    │   Q0,K0,V0  Q1,K1,V1  Q2,K2,V2  Q3,K3,V3  │
    │        \↗        \↗        \↗        \↗     │
    │         环形传递 KV 块(共 N 轮)           │
    │   计算注意力 ←→ 传输 KV(同时进行!)      │
    └────────────────────────────┘
         |
    合并每轮的局部注意力结果(online softmax)
         |
    完整的注意力输出(等价于全局注意力,无近似)

关键创新点总结:

创新点 说明
序列分块 把超长序列切成小块,每块放一张 GPU
环形拓扑 GPU 排成圆环,KV 块沿环传递一圈
通信计算重叠 传 KV 块和算注意力同时进行,零额外开销
顺序无关合并 基于 online softmax,KV 块到达顺序不影响结果正确性
线性扩展 上下文长度随 GPU 数量线性增长,理论上接近无限

与相关工作的关系

Ring Attention 站在谁的肩膀上?

FlashAttention(Dao et al., 2022)
    ↓
    分块计算注意力,避免 O(s²) 显存
    在单 GPU 内,用 SRAM 分块读写 KV

Blockwise Parallel Transformer / BPT(Liu et al., 2023)
    ↓
    把分块思想扩展到前馈层
    进一步降低单 GPU 峰值显存
    但整个序列仍在一张卡上,无法突破单卡极限

Ring Attention(本文,2023)
    ↓
    把 BPT 的分块思想扩展到多 GPU
    KV 块沿环形传递,每张卡只需持有 1 块
    通信与计算重叠,近乎无限上下文

Ring Attention 的本质是:把 FlashAttention/BPT 的"分块"思想从单卡内部扩展到多卡之间,同时通过环形拓扑和通信计算重叠解决了多卡通信的效率问题。


局限性与后续工作


一句话概括

Ring Attention 通过把长序列切块分配给多 GPU、GPU 排成环形边算边传 KV 块、让通信时间完全被计算覆盖,实现了上下文长度随 GPU 数量线性扩展,理论上可达近乎无限长度,且精确等价于全局注意力,无任何近似。


精读范围:Abstract(全部)、Introduction(核心段落)、Section 2 背景(全部)、Section 3 方法(全部)、Section 4 实验(关键结果)。约 5200 字。