精读笔记 08:FlashAttention

Fast and Memory-Efficient Exact Attention with IO-Awareness


论文基本信息

项目 内容
论文标题 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
arXiv 编号 2205.14135
发表会议 NeurIPS 2022
作者机构 Stanford University(斯坦福大学)
主要作者 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
提交时间 2022 年 5 月 27 日
核心贡献 提出 IO 感知的精确注意力算法,大幅减少 GPU 显存读写,实现显著加速

阅读地图

本笔记按以下顺序精读论文:

[背景知识铺垫]
  └── GPU 内存层次(必读,理解一切的基础)

[论文核心结构]
  ├── Abstract(摘要):一句话知道论文做了什么
  ├── Introduction(引言):为什么要做,做到了什么
  ├── Section 2(背景):标准注意力有什么问题
  ├── Section 3(方法):FlashAttention 如何解决——核心算法
  │     ├── 3.1 在线 softmax(Online Softmax)
  │     ├── 3.2 分块 Tiling 算法
  │     ├── 3.3 重计算 Recomputation
  │     ├── 3.4 核融合 Kernel Fusion
  │     └── 3.5 IO 复杂度分析
  └── Section 4(实验):速度、显存、模型质量对比

建议阅读顺序:先读「背景知识铺垫」,再顺序阅读各节。


必读背景:GPU 内存层次(新手必懂)

在读论文之前,必须先理解 GPU 的内存结构,否则 FlashAttention 的核心思想完全无从理解。

GPU 里有两种内存,速度差异巨大

内存类型 全称 容量 带宽(读写速度) 位置
HBM High Bandwidth Memory(高带宽显存) 40–80 GB ~1.5–2.0 TB/s GPU 芯片旁边的独立内存
SRAM Static Random Access Memory(片上静态随机存储) ~192 KB/流多处理器 ~19 TB/s GPU 芯片内部

关键对比:
- SRAM 的读写速度是 HBM 的约 10 倍
- 但 SRAM 的容量是 HBM 的约 1/200,000(HBM 40GB vs SRAM 几百KB)

工厂搬运类比(帮助理解)

想象一家汽车工厂:

标准注意力的做法(低效):
每算一个步骤,都把半成品送回大仓库(HBM),下一步再从大仓库取出来。工人不停地来回奔跑,大量时间浪费在搬运上,而不是真正的生产(计算)。

FlashAttention 的做法(高效):
提前把原料切成小批,一批一批地搬到工位旁的小料架(SRAM)上,在小料架上把所有工序一次做完,最后只把成品送回仓库。大幅减少了来回奔跑的次数。

为什么注意力是"内存瓶颈"而非"计算瓶颈"?

现代 GPU 的浮点计算能力增长速度,远快于内存带宽增长速度。
这意味着:GPU 的算术单元大部分时间在等待数据从 HBM 传来,而不是真正在算。
这类操作叫做 memory-bound(内存受限),相对地,如果瓶颈在算术单元,叫 compute-bound(计算受限)

标准注意力就是典型的 memory-bound 操作——算得不多,但搬运次数极多。


Abstract(摘要)

原文关键句

"Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem, but often do not achieve wall-clock speedup."

"We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM."

"We train Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3× speedup on GPT-2 (seq. length 1K), and 2.4× speedup on long-range arena (seq. length 1K-4K)."

翻译

Transformer 在处理长序列时速度慢、显存占用大,原因是自注意力机制的时间和空间复杂度都与序列长度成平方关系。已有的近似注意力方法试图解决这一问题,但往往没能带来实际的速度提升(wall-clock speedup,即真实挂钟时间加速)。

我们提出 FlashAttention——一种 IO 感知的精确注意力算法。它通过分块(tiling)技术,减少 GPU 高带宽显存(HBM)和片上 SRAM 之间的内存读写次数。

我们的方法比现有基线更快地训练 Transformer:在 BERT-large(序列长度 512)上比 MLPerf 1.1 训练速度记录提升 15%;在 GPT-2(序列长度 1K)上实现 3 倍加速;在长程任务(序列长度 1K-4K)上实现 2.4 倍加速。

新手讲解

摘要揭示了三个关键点:

  1. 问题:标准注意力是 O(N²) 的,N 是序列长度。序列翻倍,显存和时间变成 4 倍。

  2. 为什么近似方法不够好:许多论文提出"近似注意力"来降低理论复杂度,但在实际 GPU 上反而没有更快——因为它们没有解决真正的瓶颈(HBM 读写),只是减少了浮点运算数,而浮点运算本来就不是瓶颈。

  3. FlashAttention 的突破:它是精确的(exact),不牺牲精度;速度快是因为减少了 HBM 读写,而不是减少了计算量。

"wall-clock speedup"(挂钟加速)这个词很重要:它指现实中用秒表测量的真实加速,而非理论 FLOPs 减少。这正是论文强调 IO 感知的原因——以前的方法理论上减少了计算,但实际并不更快。


Introduction(引言)

第一段:问题背景

原文关键句:

"Transformers have emerged as the most widely used architecture in applications such as natural language processing and image classification. Transformers are large and slow, primarily because of self-attention... As a result, there is significant interest in making attention faster and more memory-efficient."

翻译:
Transformer 已成为自然语言处理和图像分类等应用中最广泛使用的架构。Transformer 体积大、速度慢,主要原因是自注意力机制……因此,让注意力更快、更节省显存成为了研究热点。

新手讲解:
这一段在说:注意力机制是 Transformer 的核心,也是它的速度瓶颈。每次前向/反向传播都要花大量时间在注意力计算上,尤其是长序列。


第二段:近似方法的局限

原文关键句:

"Many approximate attention methods have proposed to reduce the compute and memory requirements of attention... However, none of these methods achieve practical wall-clock speedup versus standard attention on common hardware for typical sequence lengths (up to 2K)."

翻译:
许多近似注意力方法提出减少注意力的计算量和显存需求……然而,这些方法中没有一种在常见硬件上、对典型序列长度(最多 2K)实现了相对标准注意力的实际速度提升。

新手讲解:
这是论文的"立论依据":近似方法(如 Linformer、Performer 等)理论上减少了 FLOPs(浮点运算次数),但实际上没有变快。为什么?因为它们减少的不是瓶颈,真正的瓶颈是 HBM 读写次数,这些方法并没有减少甚至可能增加了 IO 操作。


第三段:论文核心洞见

原文关键句:

"We propose to make attention algorithms IO-aware—accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM."

翻译:
我们提出让注意力算法具备 IO 感知能力——即在算法设计中考虑 GPU 不同层级内存之间的读写代价。我们提出 FlashAttention,一种 IO 感知的精确注意力算法,通过分块技术减少 GPU 高带宽显存(HBM)和片上 SRAM 之间的读写次数。

新手讲解:

"IO-aware"(IO 感知) 是本文最核心的概念。它的含义是:
算法设计不仅要考虑"要做多少次加减乘除",还要考虑"要从哪里取数据、取几次"。

就像优化工厂,你不仅要优化工人的操作步骤,还要优化物料的搬运路线。

"exact"(精确) 也很重要:FlashAttention 计算的是与标准注意力完全相同的数值结果(在数值精度范围内),不是近似。这意味着:
- 不会改变模型的数学行为
- 可以直接替换现有代码中的注意力计算
- 所有使用标准注意力训练出的模型质量保证不变


第四段:主要贡献总结

原文关键句:

"Our main goal is to avoid reading and writing the attention matrix to and from HBM... This requires us to (i) compute the softmax reduction without access to the whole input and (ii) not store the large intermediate attention matrix for the backward pass."

翻译:
我们的主要目标是避免将注意力矩阵写入 HBM 再从 HBM 读出……这要求我们:(i) 在无法访问完整输入的情况下计算 softmax 归一化;(ii) 在反向传播中不存储大型中间注意力矩阵。

新手讲解:
这段话点出了实现 FlashAttention 需要克服的两个技术难题:

难题一:分块计算 softmax
标准 softmax 公式 softmax(x_i) = exp(x_i) / Σ exp(x_j) 需要知道所有 j 的值才能算分母。
如果我们只看一块数据,怎么算完整的 softmax?
→ 解决方案:在线 softmax(Online Softmax),后面详讲。

难题二:反向传播不存储中间矩阵
标准实现在前向传播时把注意力矩阵 P(N×N 大小)存起来,反向传播时用。
但 N×N 矩阵正是显存爆炸的来源(N=1K 时就是 100万个浮点数)。
→ 解决方案:重计算(Recomputation),反向传播时临时重算,不提前存。


Section 2:背景(标准注意力的问题)

2.1 注意力机制的标准算法

原文关键句:

"Given input sequences Q, K, V ∈ R^(N×d), we want to compute the attention output O ∈ R^(N×d): S = QK^T ∈ R^(N×N), P = softmax(S) ∈ R^(N×N), O = PV ∈ R^(N×d)."

翻译与解析:

标准注意力的三步计算:

输入:Q(查询矩阵)、K(键矩阵)、V(值矩阵),形状均为 N×d
  N = 序列长度
  d = 每个注意力头的维度(通常 64 或 128)

第一步:S = Q × Kᵀ         → 形状 N×N(注意力得分矩阵)
第二步:P = softmax(S)      → 形状 N×N(注意力权重矩阵)
第三步:O = P × V           → 形状 N×d(最终输出)

新手讲解:

N×N 矩阵是问题所在。假设序列长度 N=2048(仅 2K):
- S 矩阵大小:2048 × 2048 × 4 字节(float32)= 16 MB
- P 矩阵大小:同上 = 16 MB

这两个矩阵在标准实现中都要写到 HBM,然后再读回来。而且每一层都要重复一次,完整模型有几十层。

16 MB 听起来不大,但关键是:每次计算都要完整地读写这个矩阵,反复的 HBM IO 才是慢的根源。


2.2 HBM IO 才是瓶颈:数字说话

原文关键句:

"Standard implementations require O(N²) memory... As an example of memory-bandwidth bottleneck, consider the element-wise masking or dropout applied after computing scores S. Such an operation is memory-bound."

翻译:
标准实现需要 O(N²) 的显存……以显存带宽瓶颈为例:在计算得分 S 之后应用的逐元素掩码(masking)或 dropout 操作,这类操作是内存受限的。

新手讲解:

论文这里做了一个实际的测量:对矩阵做 softmax、dropout、masking 这些"简单操作",实际上大部分时间都花在了从 HBM 读写数据上,而不是计算本身。

这就是为什么即使减少 FLOPs,如果还是要反复读写 HBM,速度就不会提升。


Section 3:FlashAttention 算法(核心)

3.1 在线 Softmax(Online Softmax)——分块计算的数学基础

术语解释Online Softmax 指不需要一次性看到所有数据、可以逐块递增地计算 softmax 的技术。

问题:为什么 softmax 不能直接分块?

标准 softmax 公式:

softmax(x_i) = exp(x_i) / [exp(x_1) + exp(x_2) + ... + exp(x_N)]

分母需要所有 N 个元素。如果你只看其中一块(比如前 128 个元素),你不知道完整的分母是多少,算出来的值是错的。

解决方案:维护"统计量",增量合并

FlashAttention 借助了一个数学技巧:对于任意两段数据拼接,可以用以下递推公式合并 softmax 统计量:

符号定义:
- m(x) = 向量 x 中的最大值(用于数值稳定)
- ℓ(x) = Σ exp(x_i - m(x))(归一化后的指数和)

合并公式(针对 x = [x⁽¹⁾, x⁽²⁾] 两段拼接):

m([x⁽¹⁾, x⁽²⁾]) = max(m(x⁽¹⁾), m(x⁽²⁾))

ℓ([x⁽¹⁾, x⁽²⁾]) = exp(m(x⁽¹⁾) - m_new) × ℓ(x⁽¹⁾)
                  + exp(m(x⁽²⁾) - m_new) × ℓ(x⁽²⁾)

其中 m_new = m([x⁽¹⁾, x⁽²⁾])

新手讲解:

把它想象成这样一个场景:你要找全班同学中成绩最高的,并计算所有人成绩的"归一化指数和"。但成绩表是分成多页的,你每次只能看一页。

关键是:只需要记住每页的两个数(最大值和指数和),就能精确合并,不需要把所有数据都存起来。

这个公式保证:分块计算出来的 softmax 与一次性算整个序列的结果完全相同(精确)


3.2 分块 Tiling 算法(FlashAttention 前向传播)

术语解释Tiling(分块/瓦片化) 是将大矩阵切成小块,让每块能放进 SRAM 的技术。

块大小的选择

块大小 B_c = ceil(M / (4d))
块大小 B_r = min(ceil(M / (4d)), d)

其中:M = SRAM 容量,d = 注意力头维度

这个大小保证每个 Q、K、V 块能放入 SRAM。

完整前向传播算法(伪代码)

输入:Q, K, V ∈ R^(N×d)(存在 HBM 中)
SRAM 容量:M

// 步骤 1:初始化输出
O = 全零矩阵 (N×d),存于 HBM
l = 全零向量 (N),存于 HBM(记录指数和)
m = -∞向量 (N),存于 HBM(记录最大值)

// 步骤 2:把 K、V 切成 Tc 块
将 K 切成 K₁, K₂, ..., K_Tc,每块大小 B_c × d
将 V 切成 V₁, V₂, ..., V_Tc,每块大小 B_c × d

// 步骤 3:把 Q 切成 Tr 块
将 Q 切成 Q₁, Q₂, ..., Q_Tr,每块大小 B_r × d
(同样地,把 O, l, m 也对应切块)

// 步骤 4:双重循环(外层遍历 K/V 块,内层遍历 Q 块)
for j = 1 to Tc:
    从 HBM 加载 K_j, V_j 到 SRAM    // 一次 HBM 读

    for i = 1 to Tr:
        从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM    // 一次 HBM 读

        // 在 SRAM 内计算(不离开快速内存)
        S_ij = Q_i × K_jᵀ                           // 当前块的得分
        m̃_ij = row_max(S_ij)                        // 本块最大值
        P̃_ij = exp(S_ij - m̃_ij)                   // 本块指数
        ℓ̃_ij = row_sum(P̃_ij)                       // 本块指数和

        // 合并旧统计量和新统计量(在线 softmax 更新)
        m_new = max(m_i, m̃_ij)
        ℓ_new = exp(m_i - m_new) × l_i + exp(m̃_ij - m_new) × ℓ̃_ij

        // 更新输出(用修正系数调整权重)
        O_i = diag(ℓ_new)⁻¹ × [diag(l_i) × exp(m_i - m_new) × O_i
                               + exp(m̃_ij - m_new) × P̃_ij × V_j]

        // 更新统计量
        l_i = ℓ_new
        m_i = m_new

        // 把更新后的 O_i, l_i, m_i 写回 HBM    // 一次 HBM 写

// 步骤 5:最终 O 已在 HBM 中,返回
return O

新手讲解:双重循环的意义

这个算法乍看复杂,但核心逻辑非常直观:

外层循环遍历 K/V 的每一块(把"键值对"分批搬到工位旁)
内层循环遍历 Q 的每一块(把"查询"分批处理)

对于每个 (Q块, K/V块) 组合:
1. 算出这个小块的得分
2. 用在线 softmax 公式,将这块的结果合并到已有的累计输出中
3. 用修正系数保证最终结果等价于全局 softmax

关键思想:我们从不把完整的 N×N 注意力矩阵写到 HBM!
- 标准做法:先写 S(N×N)到 HBM,再读出来算 softmax,写 P(N×N)到 HBM,再读出来算 PV
- FlashAttention:S_ij 只在 SRAM 里存一瞬间,算完就丢弃,从不写到 HBM

最终结果与标准注意力完全一致,因为在线 softmax 公式在数学上等价于全局 softmax。


3.3 重计算(Recomputation)——反向传播的节显存技巧

术语解释Recomputation(重计算/梯度检查点) 指反向传播时不依赖已存储的中间结果,而是重新计算,以节省显存。

标准反向传播的问题

神经网络训练需要反向传播(backpropagation)。
反向传播需要用到前向传播中计算的中间结果。

标准注意力前向传播中产生的中间矩阵:
- S = QKᵀ:大小 N×N
- P = softmax(S):大小 N×N

如果把它们都存起来(标准做法):
- 显存占用:2 × N² × 4 字节(float32)
- N=2K 时:约 32 MB(看起来还好)
- N=8K 时:约 512 MB
- N=64K 时:约 32 GB(超出大多数 GPU 的总显存!)

FlashAttention 的解决方案

原文关键句:

"We do not store the intermediate attention matrix S and P for the backward pass. Instead, we recompute them in the backward pass by keeping in HBM the output O and the softmax normalization statistics (m, ℓ)."

翻译:
我们不在反向传播时存储中间注意力矩阵 S 和 P,而是在反向传播时重新计算它们。我们只在 HBM 中保留输出 O 和 softmax 归一化统计量(m 和 ℓ)。

新手讲解:

这个思路的关键洞察是:

  1. 存 m 和 ℓ 代价极小:它们的大小是 N(而非 N²),存储代价可忽略不计
  2. 重计算不慢:重计算 S_ij 和 P_ij 块需要一定计算量,但计算发生在 SRAM 里,很快;而如果要从 HBM 读出存好的 N×N 矩阵,IO 时间反而更长
  3. 净效益是正的:用少量额外计算,换取大量 HBM IO 的节省,总体更快且显存更省

类比
想象你做了一道复杂菜肴,本来打算把做菜的"半成品"全部装盒存冰箱(以备检查)。但冰箱放不下,且取出来还要解冻(HBM 慢)。
FlashAttention 的做法是:只记下几个关键参数(m 和 ℓ),需要时重新快速做一次(SRAM 内重算),比从冰箱取出更高效。

实际显存节省

方法 前向传播存储 反向传播需求
标准注意力 O(N²)——存 S 和 P O(N²)
FlashAttention O(N)——只存 m 和 ℓ O(N)——重计算

这使得 FlashAttention 的额外显存(除输入输出之外)仅为 O(N),与序列长度线性相关,而非平方相关。


3.4 核融合(Kernel Fusion)

术语解释Kernel Fusion(核融合) 是将多个 GPU 计算步骤合并为单个 CUDA kernel(核函数)的优化技术。

标准做法的问题

标准注意力通常用多个独立的 GPU 核函数实现,每个函数对应一个操作:

Kernel 1: S = Q × Kᵀ        → 写 S 到 HBM
Kernel 2: masking(S)         → 读 S,写 masked_S 到 HBM
Kernel 3: P = softmax(S)     → 读 masked_S,写 P 到 HBM
Kernel 4: dropout(P)         → 读 P,写 P' 到 HBM
Kernel 5: O = P' × V         → 读 P',写 O 到 HBM

每个 Kernel 之间,数据都要经过 HBM,产生 4 次多余的 HBM 读写。

FlashAttention 的做法:一个 Kernel 搞定

原文关键句:

"FlashAttention fuses all the attention operations (matrix multiply, softmax, optional masking and dropout, matrix multiply) into one CUDA kernel."

翻译:
FlashAttention 将所有注意力操作(矩阵乘法、softmax、可选的掩码和 dropout、矩阵乘法)融合进一个 CUDA 核函数。

实现效果:

Fused Kernel: 
  从 HBM 读 Q, K, V(一次)
  在 SRAM 内完成:QKᵀ计算 → masking → softmax → dropout → PV计算
  把 O 写到 HBM(一次)

新手讲解:

想象流水线生产:
- 标准做法:每道工序在不同的厂房,产品在厂房间用卡车运输(HBM)
- 核融合:所有工序在同一个厂房内流水线完成,原料进去,成品出来,中间无需运输

核融合的好处不只是减少 IO,还避免了多个 Kernel 启动的额外开销(每次启动 CUDA Kernel 本身就有时间成本)。


3.5 IO 复杂度分析(定理)

Theorem 2(原文)

"FlashAttention requires Θ(N²d²M⁻¹) HBM accesses, where N is the sequence length, d is the head dimension, and M is the size of SRAM."

对比:

算法 HBM 访问次数
标准注意力 Θ(Nd + N²)
FlashAttention Θ(N²d²M⁻¹)

代入实际数值(A100 GPU):
- N = 1024,d = 64,M = 100KB = 100 × 1024 字节 ≈ 25,600 个 float32

标准注意力:Θ(1024×64 + 1024²) ≈ 1,114,112 次访问
FlashAttention:Θ(1024² × 64² / 25600) ≈ 167,772 次访问
比率 ≈ 6.6 倍减少

论文报告实际可达到约 9 倍的 HBM 访问减少。

Proposition 3(下界定理)

"No exact attention algorithm can achieve asymptotically better HBM accesses across all SRAM sizes M."

翻译:
不存在任何精确注意力算法能在所有 SRAM 大小 M 上渐近地优于 FlashAttention 的 HBM 访问复杂度。

新手讲解:
这是一个最优性定理:FlashAttention 的 IO 复杂度在渐近意义下是最优的——没有人能做得(渐近意义上)更好。这是一个强有力的理论保证。


3.6 完整算法设计总结

至此可以将 FlashAttention 的所有设计思路串联起来:

核心问题:标准注意力频繁读写 HBM,而 HBM 速度慢

解决思路链:
  ① 分块(Tiling):把大矩阵切小,让小块能放入 SRAM
         ↓ 但是...
  ② softmax 需要看全部数据,分块后怎么算?
         ↓ 解决方案:
  ③ 在线 softmax:用递推公式增量合并 softmax 统计量
         ↓ 但反向传播还需要中间矩阵...
  ④ 重计算(Recomputation):只存 (m, ℓ) 统计量,反向传播时重算
         ↓ 多步操作还是有 Kernel 切换开销...
  ⑤ 核融合(Kernel Fusion):所有操作合入一个 CUDA Kernel

最终效果:
  - HBM 访问减少最多 9×
  - 显存从 O(N²) 降至 O(N)
  - 计算结果与标准注意力完全相同(精确)

Section 4:实验结果

4.1 速度对比

原文关键句:

"FlashAttention achieves 7.6× speedup over the PyTorch implementation of attention, 3× over Megatron, and 2.4× over Triton. FlashAttention runs in 1.77ms vs. PyTorch's 13.5ms (sequence length 2048, head dimension 64)."

翻译:
FlashAttention 相比 PyTorch 注意力实现实现了 7.6 倍加速,相比 Megatron 3 倍,相比 Triton 2.4 倍。在序列长度 2048、头维度 64 时,FlashAttention 运行 1.77ms,PyTorch 标准实现需 13.5ms。

新手讲解:

这里说的是注意力计算本身的加速(不是整个模型),因此数字更显著。对于端到端的模型训练,加速幅度会有所降低(因为注意力只是模型的一部分),但依然可观:
- BERT-large 端到端训练:+15%
- GPT-2 端到端训练:3× 加速


4.2 显存对比

关键数据:

序列长度 标准注意力显存 FlashAttention 显存 节省比例
512 ~1 GB ~0.5 GB ~2×
1K ~4 GB ~0.5 GB ~8×
2K ~16 GB ~0.5 GB ~32×
4K OOM(显存溢出) ~0.5 GB 使得长序列成为可能

FlashAttention 的显存是 O(N) 线性增长,而标准注意力是 O(N²),因此序列越长优势越明显。

新手讲解:
这张表格最直观地说明了 FlashAttention 的价值:在序列长度 4K 时,标准注意力已经 OOM(out of memory,显存不足),而 FlashAttention 还游刃有余。这正是为什么 FlashAttention 能够训练更长序列的原因。


4.3 长序列模型质量提升

原文关键句:

"Using FlashAttention, we train GPT-2 with context length 4K on OpenWebText, achieving 0.7 better perplexity than GPT-2 (context length 1K) trained with the same number of tokens."

"FlashAttention is the first transformer to achieve better-than-chance performance on Path-X (seq. length 16K) and Path-256 (seq. length 64K)."

翻译:
使用 FlashAttention,我们在 OpenWebText 上训练了上下文长度为 4K 的 GPT-2,比相同 token 数训练的 1K 上下文 GPT-2 困惑度(perplexity)低 0.7。

FlashAttention 是第一个在 Path-X(序列长度 16K)和 Path-256(序列长度 64K)上实现好于随机水平性能的 Transformer。

新手讲解:

困惑度(Perplexity):语言模型的评测指标,数值越低越好,代表模型对文本的预测越准确。
困惑度降低 0.7 表示:仅通过增大上下文长度(1K → 4K),不改变任何模型结构或训练数据,模型质量就得到了提升。这是 FlashAttention 的间接贡献——让更长上下文的训练变得可行。

Path-X 和 Path-256:这是极长序列的基准测试任务,标准 Transformer 因无法处理如此长的序列而表现不佳。FlashAttention 使这些任务首次对 Transformer 可行,并取得 61.4% 的准确率(随机猜测为 50%)。


4.4 块稀疏 FlashAttention(Block-Sparse FlashAttention)

原文关键句:

"Block-sparse FlashAttention achieves 2-4× speedup over even FlashAttention on long sequences of length 4K-64K."

翻译:
块稀疏 FlashAttention 在长度 4K-64K 的长序列上,相比 FlashAttention 本身再实现 2-4 倍加速。

新手讲解:
FlashAttention 还可以扩展到稀疏注意力模式:不是每个位置都需要关注所有其他位置,可以跳过大部分块。这在保持 FlashAttention IO 优势的同时,进一步减少了计算量。但这已是近似方法,不再是精确注意力。


核心思想总结

一句话总结

FlashAttention 的本质洞察是:标准注意力慢不是因为算得太多,而是因为频繁地把大矩阵写进、读出 GPU 的慢速显存(HBM)。 通过分块计算、在线 softmax 和重计算三项技术,FlashAttention 将注意力计算的数据全程保留在快速的片上 SRAM 中,大幅减少 HBM IO,同时保证结果精确。

三项核心技术的关系

目标:减少 HBM IO
  ↓
技术一:分块(Tiling)
  - 将 Q、K、V 切成小块,装入 SRAM
  - 问题:分块后 softmax 无法直接算
  ↓
技术二:在线 softmax(Online Softmax)  
  - 递推公式增量合并,分块精确计算 softmax
  - 问题:反向传播还需要 N×N 矩阵
  ↓
技术三:重计算(Recomputation)
  - 只存 O(N) 的统计量,反向传播时重算
  - 全程无需写完整 N×N 矩阵到 HBM
  ↓
加分项:核融合(Kernel Fusion)
  - 所有操作在一个 CUDA Kernel 内完成,消除 Kernel 切换开销

为什么"精确"是重要贡献

很多加速方法通过"近似"来提速(如稀疏注意力、低秩近似),这些方法:
- 可能改变模型的数学行为
- 需要在速度与精度之间做权衡
- 不能直接替换标准注意力

FlashAttention 是精确的(数值上等价于标准注意力),这意味着:
- 可以直接作为即插即用的替代品
- 不影响任何基于标准注意力的理论分析
- 已有的模型不需要修改结构即可受益

影响与意义

FlashAttention 发布后迅速成为业界标准:
- PyTorch 2.0 将其纳入官方 F.scaled_dot_product_attention
- Hugging Face Transformers 库默认支持
- GPT-4、Llama 2 等大模型训练中广泛采用
- FlashAttention-2(2023)、FlashAttention-3(2024)相继发布,进一步提升性能

它的核心贡献不只是加速,更重要的是把长序列建模变成了可行的工程任务,直接推动了 16K、32K 乃至更长上下文大模型的发展。


术语速查表

术语 英文 含义
HBM High Bandwidth Memory GPU 芯片外的高带宽显存,大(40-80GB)但相对慢(1.5-2TB/s)
SRAM Static Random Access Memory GPU 芯片内的静态随机存储,小(~192KB)但极快(~19TB/s)
IO-aware IO-Aware 在算法设计时考虑内存读写代价的设计理念
Tiling 分块/瓦片化 将大矩阵切成小块,使每块能放入 SRAM 的技术
Online Softmax 在线 softmax 无需一次看到全部数据、可增量计算 softmax 的算法
Recomputation 重计算 反向传播时不依赖存储的中间结果,而是重新计算,以节省显存
Kernel Fusion 核融合 将多个 CUDA 操作合并为一个 Kernel,消除中间 HBM 读写
Memory-bound 内存受限 操作速度受限于内存带宽而非计算单元的状态
Compute-bound 计算受限 操作速度受限于算术计算单元的状态
Wall-clock speedup 挂钟加速 用实际时钟测量的真实加速,对应理论加速(FLOPs 减少)
Perplexity 困惑度 语言模型评估指标,越低越好
FLOPs 浮点运算次数 衡量计算量的单位,FlashAttention 不减少 FLOPs
Exact attention 精确注意力 与标准注意力数学等价的实现,相对于近似注意力方法

本笔记基于 FlashAttention 原论文(arXiv:2205.14135)整理,适合零基础读者入门。
建议配合 FlashAttention GitHub 仓库(github.com/Dao-AILab/flash-attention)的代码对照阅读。