【AI】FlashAttention 详解:为什么它能让大模型注意力计算又快又省显存?

FlashAttention 是现代大模型训练与推理中最重要的底层优化之一。它并没有改变注意力机制的数学结果,也不是近似算法,却能显著降低显存占用、提升计算速度。本文从标准 Attention 的显存瓶颈讲起,用图解方式拆开 FlashAttention 的核心思想:IO-aware、分块计算、在线 Softmax、SRAM 复用、反向重算。看完这篇,你会真正明白它为什么快、为什么省、什么时候有效、什么时候不明显。

一、先说结论:FlashAttention 到底是什么?

一句话概括:

FlashAttention 是一种 IO-aware 的精确注意力计算算法,它通过分块计算和在线 Softmax,避免显式存储巨大的注意力矩阵,从而大幅减少 GPU 显存读写。

这里有三个关键词:

  1. IO-aware:它关注的不只是浮点计算量 FLOPs,而是 GPU 显存和高速缓存之间的数据搬运成本。
  2. 精确注意力:它算出来的结果和标准 Attention 数学上等价,不是近似、不剪枝、不稀疏化。
  3. 不保存完整注意力矩阵:标准 Attention 会生成一个 $N \times N$ 的注意力矩阵,FlashAttention 避免把它完整写入显存。

很多人第一次听到 FlashAttention,会以为它是“更聪明的 Attention 结构”。其实不是。它不是新的模型结构,而是同一个公式的更高效实现方式


二、先复习标准 Attention:慢在哪里?

Transformer 的核心注意力公式是:

$$ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$

其中:

  • $Q$:Query,表示“我想找什么信息”
  • $K$:Key,表示“我有什么标签”
  • $V$:Value,表示“我真正提供的信息”
  • $N$:序列长度
  • $d$:每个 head 的维度

标准实现通常分三步:

1
2
3
第一步:S = QK^T              # 得到 N × N 的打分矩阵
第二步:P = softmax(S) # 得到 N × N 的注意力概率矩阵
第三步:O = P V # 得到输出

问题就出在中间的两个矩阵:SP

如果序列长度是 4096,那么每个 head 的注意力矩阵大小是:

1
4096 × 4096 = 16,777,216 个元素

FP16 每个元素 2 字节,仅一个矩阵就是:

1
16,777,216 × 2 ≈ 32 MB

注意这只是一个 batch、一个 head、一个矩阵。实际训练中还有:

  • 多个 batch
  • 多个 attention head
  • 多层 Transformer
  • 前向保存中间结果给反向传播使用

所以标准 Attention 的显存压力会迅速爆炸。


三、真正的瓶颈:不是算不动,而是搬不动

很多新手会以为 GPU 慢是因为“乘法太多”。但在 Attention 里,一个更关键的问题是:数据搬运太多

GPU 内部大致可以理解成三层存储:

存储位置 速度 容量 类比
HBM / 显存 仓库
SRAM / Shared Memory 操作台
Register / 寄存器 最快 极小 手里拿着的工具

标准 Attention 的做法像这样:

1
2
3
Q、K 从显存读入 → 计算 S → 把 S 写回显存
S 从显存读入 → softmax → 把 P 写回显存
P、V 从显存读入 → 计算 O → 把 O 写回显存

也就是说,中间的巨大矩阵 SP 被反复写入、读出显存。

这就是 FlashAttention 论文里强调的核心:

Attention 的瓶颈不是只有 FLOPs,还有 HBM 和 SRAM 之间的 IO。

标准 Attention:巨大中间矩阵反复进出显存HBM 显存Q / K / V / S / PSRAM临时计算CUDA Core矩阵乘 / softmax读 Q/K写 S/P送去计算结果返回问题:S 和 P 都是 N×N,序列越长,中间矩阵越大,显存读写成本越恐怖。

四、FlashAttention 的核心思想:不要把整张注意力表摊开

标准 Attention 像是在处理一张巨大的 Excel 表:

  • 先算完整的 S = QK^T
  • 再对整张表做 softmax
  • 再拿整张表乘以 V

FlashAttention 的想法是:

不要一次性生成整张 N×N 表,而是把 Q、K、V 切成小块,在 GPU 高速缓存里一块一块算。

可以把它想象成做饭:

  • 标准 Attention:把所有菜一次性摊满整个厨房,切完、炒完、装盘,厨房爆炸。
  • FlashAttention:每次只拿一小篮菜到操作台,处理完立刻合并到最终成品,不把半成品堆满厨房。

五、分块计算:Tiling

FlashAttention 会把矩阵切成 block。例如:

1
2
3
Q 分成 Q1, Q2, Q3, ...
K 分成 K1, K2, K3, ...
V 分成 V1, V2, V3, ...

然后每次只计算一个小块:

1
Qi × Kj^T

得到的是一个小的 attention score block,而不是完整的 N × N 矩阵。

FlashAttention:按块计算,不生成完整 N×N 注意力矩阵Q blocksQ1Q2Q3...K/V blocksK1/V1K2/V2K3/V3...小块 scoreQ1K1ᵀ每次只把一个 Q block 和一个 K/V block 放进高速缓存,算完立刻合并到输出。

这个思想本身不难,但真正困难在 softmax。

为什么?因为 softmax 的分母需要知道一整行所有元素:

$$ \mathrm{softmax}(x_i)=\frac{e^{x_i}}{\sum_j e^{x_j}} $$

如果你只看一小块,就不知道整行的最大值和分母是多少。FlashAttention 的关键创新就在这里:在线 Softmax


六、在线 Softmax:边看边更新全局结果

普通 softmax 通常要先拿到整行数据,然后做三步:

1
2
3
1. 找到这一行的最大值 m
2. 计算 exp(x - m)
3. 求和并归一化

FlashAttention 不能一次看完整行,只能一块一块看。因此它维护两个运行中的变量:

  • m:当前已经看过的所有块中的最大值
  • l:当前 softmax 分母的累计值

当新 block 到来时,更新:

$$ m_{new}=\max(m_{old},m_{block}) $$
$$ l_{new}=e^{m_{old}-m_{new}}l_{old}+e^{m_{block}-m_{new}}l_{block} $$

这一步非常重要。它解决了两个问题:

  1. 数值稳定:始终减去最大值,避免指数爆炸。
  2. 分块等价:虽然一块一块算,但最后结果和一次性 softmax 完全一致。

直觉上看,在线 softmax 就像你在分批统计全班成绩:

  • 每来一批学生,更新当前最高分。
  • 如果新的最高分变了,之前那批学生的相对权重也要重新缩放。
  • 最后得到的统计结果和一次性看完整个班一样。

七、FlashAttention 的完整流程

以一个 Q block 为例,FlashAttention 大致做下面这些事:

1
2
3
4
5
6
7
8
9
10
11
12
13
for each Q_block:
初始化输出 O_block = 0
初始化 running max m = -∞
初始化 running sum l = 0

for each K_block, V_block:
1. 把 Q_block, K_block, V_block 加载到 SRAM
2. 计算小块分数 S_block = Q_block × K_block^T
3. 根据 mask 处理 causal attention
4. 用在线 softmax 更新 m 和 l
5. 把当前 block 对输出的贡献合并进 O_block

把最终 O_block 写回 HBM

关键点:

中间的 S_blockP_block 只存在于 SRAM / register 中,用完就丢,不写回显存。

标准 Attention 需要保存完整 SP;FlashAttention 只保存最终输出 O 以及少量 softmax 统计量。


八、为什么它省显存?

标准 Attention 的中间显存主要来自:

1
2
S = QK^T          # N × N
P = softmax(S) # N × N

训练时还要保存 P 给反向传播用。

FlashAttention 不保存完整的 SP,只保存:

1
2
O       # N × d
m, l # 每一行的 softmax 统计量

所以中间显存从近似:

1
O(N²)

降到:

1
O(N)

注意:这里说的是额外中间显存,不是 Q/K/V 本身的显存,也不是模型参数显存。

这也是为什么长序列场景下 FlashAttention 特别有价值。序列越长,$N^2$ 越可怕,省下来的显存越多。


九、为什么它更快?

FlashAttention 更快,不是因为数学计算量从根本上少了。它依然要算注意力,本质 FLOPs 没有消失。

它快在:

  1. 减少 HBM 读写:不把 SP 这种巨大矩阵写回显存。
  2. 充分利用 SRAM:小块数据放在高速缓存里反复使用。
  3. Kernel Fusion:把矩阵乘、mask、softmax、dropout、乘 V 等步骤融合在一个 CUDA kernel 里,减少中间结果落地。
  4. 更好的并行调度:FlashAttention-2 进一步优化了 block 划分和 warp-level 并行。

可以这样理解:

标准 Attention 是“算一步、存一步、再读回来算下一步”;FlashAttention 是“拿到操作台上一次性处理完,只把最终成品放回仓库”。


十、反向传播:为什么还要重算?

训练时,反向传播需要用到前向过程中的一些中间结果。

标准 Attention 会直接保存完整的注意力概率矩阵 P,反向时直接读出来。

FlashAttention 不保存 P,那反向怎么办?

答案是:重算

反向传播时,它根据保存的 Oml 以及原始的 Q/K/V,重新分块计算需要的 softmax 结果。

这是一种典型的时间换空间:

方案 显存 计算
标准 Attention 少一点
FlashAttention 多一点重算

但在 GPU 上,显存 IO 往往比多做一点计算更贵。所以整体上 FlashAttention 仍然更快、更省。


十一、FlashAttention v1、v2、v3 有什么区别?

11.1 FlashAttention v1

v1 的核心贡献是提出 IO-aware exact attention

  • 分块计算
  • 在线 softmax
  • 不保存完整 attention matrix
  • 显著降低 HBM 读写

它证明了:注意力优化不能只看 FLOPs,还必须看显存 IO。

11.2 FlashAttention v2

v2 主要是工程层面的进一步优化:

  • 更好的并行化策略
  • 更少的非矩阵乘法操作
  • 更高的 Tensor Core 利用率
  • 支持更广泛的 head dimension
  • 对长序列训练更友好

如果说 v1 是“算法突破”,v2 就是“把 GPU 榨得更干净”。

11.3 FlashAttention v3

v3 面向 Hopper 架构(如 H100)进一步优化:

  • 利用 Hopper 的异步能力
  • 更好地重叠数据搬运和计算
  • 针对 FP8 等低精度计算优化
  • 进一步提升 H100 上的吞吐

简单理解:

1
2
3
FlashAttention v1:想明白怎么省 IO
FlashAttention v2:在 A100 等 GPU 上跑得更满
FlashAttention v3:为 H100 / Hopper 继续深度压榨硬件

十二、它和 Sparse Attention、Linear Attention 有什么区别?

这是非常容易混淆的点。

方法 是否精确 核心思路 结果是否等价标准 Attention
FlashAttention 精确 改变计算顺序,减少 IO 等价
Sparse Attention 近似 / 结构限制 只看部分 token 不一定等价
Linear Attention 近似 / 核技巧 避免显式 $QK^T$ 通常不等价
Sliding Window Attention 结构限制 只看局部窗口 不等价完整 Attention

FlashAttention 最重要的特点是:

它不是换了 Attention,而是把同一个 Attention 算得更聪明。

所以使用 FlashAttention 通常不需要重新训练模型结构,也不会改变模型输出的理论含义。


十三、什么时候收益最大?

FlashAttention 的收益和场景强相关。

13.1 收益明显的场景

  1. 长上下文训练:比如 4K、8K、32K、128K tokens。
  2. 大 batch 训练:中间 attention matrix 数量巨大。
  3. Decoder-only LLM:每层都有 causal attention。
  4. 显存吃紧:原本因为 OOM 训不起来的长序列任务。
  5. A100 / H100 等现代 GPU:高速 SRAM、Tensor Core、CUDA kernel fusion 优势明显。

13.2 收益不明显的场景

  1. 序列很短:比如 128 或 256 tokens,标准 Attention 中间矩阵不大。
  2. CPU 推理:FlashAttention 是 GPU kernel 优化,CPU 上意义不大。
  3. 模型瓶颈不在 Attention:例如 FFN/MoE 或数据加载成为瓶颈。
  4. 显卡太旧:某些低版本 CUDA / 旧架构 GPU 支持不佳。

十四、PyTorch 中怎么用?

现在很多框架已经把 FlashAttention 类似的优化封装进标准接口。

14.1 PyTorch SDPA

PyTorch 2.x 提供了 scaled_dot_product_attention

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=True
)

在满足条件时,PyTorch 会自动选择更高效的 kernel(可能是 FlashAttention、memory-efficient attention 或 math fallback)。

14.2 Transformers 中的使用

Hugging Face Transformers 中,很多模型可以通过参数启用:

1
2
3
4
5
6
7
8
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype="auto",
attn_implementation="flash_attention_2",
device_map="auto"
)

实际能否启用取决于:

  • GPU 架构
  • CUDA 版本
  • PyTorch 版本
  • flash-attn 包是否安装
  • head dimension 是否支持
  • attention mask 形式是否兼容

十五、常见误区

误区 1:FlashAttention 是近似算法

错误。FlashAttention 是精确 attention,数学结果与标准 attention 等价,差异主要来自浮点精度和计算顺序。

误区 2:FlashAttention 会减少理论 FLOPs

不准确。它主要减少的是显存 IO,不是从公式上消灭矩阵乘法。实际速度提升来自更少的数据搬运和更好的 kernel fusion。

误区 3:用了 FlashAttention 就一定更快

不一定。短序列、小 batch、旧 GPU、非 attention 瓶颈场景下,收益可能不明显。

误区 4:FlashAttention 能解决所有长上下文问题

不能。它降低了注意力计算的显存和 IO 压力,但完整 attention 的计算量仍然与 $N^2$ 有关。超长上下文还需要 RoPE 扩展、稀疏注意力、滑动窗口、KV Cache 管理等技术配合。


十六、面试回答模板

如果面试官问:“FlashAttention 为什么快?”可以这样回答:

FlashAttention 的核心不是改变 Attention 公式,而是改变计算方式。标准 Attention 会显式生成并保存 $N \times N$ 的 score 和 probability 矩阵,导致大量 HBM 显存读写。FlashAttention 把 Q/K/V 分块加载到 SRAM 中,使用在线 softmax 一边计算一边更新归一化统计量,避免把完整注意力矩阵写回显存。这样额外中间显存从 $O(N^2)$ 降到 $O(N)$,并通过 kernel fusion 减少读写和调度开销,所以在长序列训练和大模型推理中速度更快、显存更省。


十七、终极总结

最后用 5 句话总结 FlashAttention:

  1. FlashAttention 是精确 Attention,不是近似算法。
  2. 它的核心优化对象是显存 IO,而不只是 FLOPs。
  3. 它通过分块计算和在线 Softmax,避免保存完整 $N \times N$ 注意力矩阵。
  4. 它把中间显存从 $O(N^2)$ 降到 $O(N)$,长序列收益尤其明显。
  5. FlashAttention v2/v3 继续围绕 GPU 并行、Tensor Core、Hopper 架构做深度优化,是现代 LLM 训练和推理的重要基础设施。

如果把标准 Attention 比作“把整张地图摊开在桌上找路”,那么 FlashAttention 就是“边走边看局部地图,同时记住全局方向”。它没有改变目的地,却极大减少了你搬地图、摊地图、收地图的时间。