【AI】FlashAttention 详解:为什么它能让大模型注意力计算又快又省显存?
一、先说结论:FlashAttention 到底是什么?
一句话概括:
FlashAttention 是一种 IO-aware 的精确注意力计算算法,它通过分块计算和在线 Softmax,避免显式存储巨大的注意力矩阵,从而大幅减少 GPU 显存读写。
这里有三个关键词:
- IO-aware:它关注的不只是浮点计算量 FLOPs,而是 GPU 显存和高速缓存之间的数据搬运成本。
- 精确注意力:它算出来的结果和标准 Attention 数学上等价,不是近似、不剪枝、不稀疏化。
- 不保存完整注意力矩阵:标准 Attention 会生成一个 $N \times N$ 的注意力矩阵,FlashAttention 避免把它完整写入显存。
很多人第一次听到 FlashAttention,会以为它是“更聪明的 Attention 结构”。其实不是。它不是新的模型结构,而是同一个公式的更高效实现方式。
二、先复习标准 Attention:慢在哪里?
Transformer 的核心注意力公式是:
其中:
- $Q$:Query,表示“我想找什么信息”
- $K$:Key,表示“我有什么标签”
- $V$:Value,表示“我真正提供的信息”
- $N$:序列长度
- $d$:每个 head 的维度
标准实现通常分三步:
1 | 第一步:S = QK^T # 得到 N × N 的打分矩阵 |
问题就出在中间的两个矩阵:S 和 P。
如果序列长度是 4096,那么每个 head 的注意力矩阵大小是:
1 | 4096 × 4096 = 16,777,216 个元素 |
FP16 每个元素 2 字节,仅一个矩阵就是:
1 | 16,777,216 × 2 ≈ 32 MB |
注意这只是一个 batch、一个 head、一个矩阵。实际训练中还有:
- 多个 batch
- 多个 attention head
- 多层 Transformer
- 前向保存中间结果给反向传播使用
所以标准 Attention 的显存压力会迅速爆炸。
三、真正的瓶颈:不是算不动,而是搬不动
很多新手会以为 GPU 慢是因为“乘法太多”。但在 Attention 里,一个更关键的问题是:数据搬运太多。
GPU 内部大致可以理解成三层存储:
| 存储位置 | 速度 | 容量 | 类比 |
|---|---|---|---|
| HBM / 显存 | 慢 | 大 | 仓库 |
| SRAM / Shared Memory | 快 | 小 | 操作台 |
| Register / 寄存器 | 最快 | 极小 | 手里拿着的工具 |
标准 Attention 的做法像这样:
1 | Q、K 从显存读入 → 计算 S → 把 S 写回显存 |
也就是说,中间的巨大矩阵 S 和 P 被反复写入、读出显存。
这就是 FlashAttention 论文里强调的核心:
Attention 的瓶颈不是只有 FLOPs,还有 HBM 和 SRAM 之间的 IO。
四、FlashAttention 的核心思想:不要把整张注意力表摊开
标准 Attention 像是在处理一张巨大的 Excel 表:
- 先算完整的
S = QK^T - 再对整张表做 softmax
- 再拿整张表乘以
V
FlashAttention 的想法是:
不要一次性生成整张 N×N 表,而是把 Q、K、V 切成小块,在 GPU 高速缓存里一块一块算。
可以把它想象成做饭:
- 标准 Attention:把所有菜一次性摊满整个厨房,切完、炒完、装盘,厨房爆炸。
- FlashAttention:每次只拿一小篮菜到操作台,处理完立刻合并到最终成品,不把半成品堆满厨房。
五、分块计算:Tiling
FlashAttention 会把矩阵切成 block。例如:
1 | Q 分成 Q1, Q2, Q3, ... |
然后每次只计算一个小块:
1 | Qi × Kj^T |
得到的是一个小的 attention score block,而不是完整的 N × N 矩阵。
这个思想本身不难,但真正困难在 softmax。
为什么?因为 softmax 的分母需要知道一整行所有元素:
如果你只看一小块,就不知道整行的最大值和分母是多少。FlashAttention 的关键创新就在这里:在线 Softmax。
六、在线 Softmax:边看边更新全局结果
普通 softmax 通常要先拿到整行数据,然后做三步:
1 | 1. 找到这一行的最大值 m |
FlashAttention 不能一次看完整行,只能一块一块看。因此它维护两个运行中的变量:
m:当前已经看过的所有块中的最大值l:当前 softmax 分母的累计值
当新 block 到来时,更新:
这一步非常重要。它解决了两个问题:
- 数值稳定:始终减去最大值,避免指数爆炸。
- 分块等价:虽然一块一块算,但最后结果和一次性 softmax 完全一致。
直觉上看,在线 softmax 就像你在分批统计全班成绩:
- 每来一批学生,更新当前最高分。
- 如果新的最高分变了,之前那批学生的相对权重也要重新缩放。
- 最后得到的统计结果和一次性看完整个班一样。
七、FlashAttention 的完整流程
以一个 Q block 为例,FlashAttention 大致做下面这些事:
1 | for each Q_block: |
关键点:
中间的
S_block和P_block只存在于 SRAM / register 中,用完就丢,不写回显存。
标准 Attention 需要保存完整 S 和 P;FlashAttention 只保存最终输出 O 以及少量 softmax 统计量。
八、为什么它省显存?
标准 Attention 的中间显存主要来自:
1 | S = QK^T # N × N |
训练时还要保存 P 给反向传播用。
FlashAttention 不保存完整的 S 和 P,只保存:
1 | O # N × d |
所以中间显存从近似:
1 | O(N²) |
降到:
1 | O(N) |
注意:这里说的是额外中间显存,不是 Q/K/V 本身的显存,也不是模型参数显存。
这也是为什么长序列场景下 FlashAttention 特别有价值。序列越长,$N^2$ 越可怕,省下来的显存越多。
九、为什么它更快?
FlashAttention 更快,不是因为数学计算量从根本上少了。它依然要算注意力,本质 FLOPs 没有消失。
它快在:
- 减少 HBM 读写:不把
S和P这种巨大矩阵写回显存。 - 充分利用 SRAM:小块数据放在高速缓存里反复使用。
- Kernel Fusion:把矩阵乘、mask、softmax、dropout、乘 V 等步骤融合在一个 CUDA kernel 里,减少中间结果落地。
- 更好的并行调度:FlashAttention-2 进一步优化了 block 划分和 warp-level 并行。
可以这样理解:
标准 Attention 是“算一步、存一步、再读回来算下一步”;FlashAttention 是“拿到操作台上一次性处理完,只把最终成品放回仓库”。
十、反向传播:为什么还要重算?
训练时,反向传播需要用到前向过程中的一些中间结果。
标准 Attention 会直接保存完整的注意力概率矩阵 P,反向时直接读出来。
FlashAttention 不保存 P,那反向怎么办?
答案是:重算。
反向传播时,它根据保存的 O、m、l 以及原始的 Q/K/V,重新分块计算需要的 softmax 结果。
这是一种典型的时间换空间:
| 方案 | 显存 | 计算 |
|---|---|---|
| 标准 Attention | 高 | 少一点 |
| FlashAttention | 低 | 多一点重算 |
但在 GPU 上,显存 IO 往往比多做一点计算更贵。所以整体上 FlashAttention 仍然更快、更省。
十一、FlashAttention v1、v2、v3 有什么区别?
11.1 FlashAttention v1
v1 的核心贡献是提出 IO-aware exact attention:
- 分块计算
- 在线 softmax
- 不保存完整 attention matrix
- 显著降低 HBM 读写
它证明了:注意力优化不能只看 FLOPs,还必须看显存 IO。
11.2 FlashAttention v2
v2 主要是工程层面的进一步优化:
- 更好的并行化策略
- 更少的非矩阵乘法操作
- 更高的 Tensor Core 利用率
- 支持更广泛的 head dimension
- 对长序列训练更友好
如果说 v1 是“算法突破”,v2 就是“把 GPU 榨得更干净”。
11.3 FlashAttention v3
v3 面向 Hopper 架构(如 H100)进一步优化:
- 利用 Hopper 的异步能力
- 更好地重叠数据搬运和计算
- 针对 FP8 等低精度计算优化
- 进一步提升 H100 上的吞吐
简单理解:
1 | FlashAttention v1:想明白怎么省 IO |
十二、它和 Sparse Attention、Linear Attention 有什么区别?
这是非常容易混淆的点。
| 方法 | 是否精确 | 核心思路 | 结果是否等价标准 Attention |
|---|---|---|---|
| FlashAttention | 精确 | 改变计算顺序,减少 IO | 等价 |
| Sparse Attention | 近似 / 结构限制 | 只看部分 token | 不一定等价 |
| Linear Attention | 近似 / 核技巧 | 避免显式 $QK^T$ | 通常不等价 |
| Sliding Window Attention | 结构限制 | 只看局部窗口 | 不等价完整 Attention |
FlashAttention 最重要的特点是:
它不是换了 Attention,而是把同一个 Attention 算得更聪明。
所以使用 FlashAttention 通常不需要重新训练模型结构,也不会改变模型输出的理论含义。
十三、什么时候收益最大?
FlashAttention 的收益和场景强相关。
13.1 收益明显的场景
- 长上下文训练:比如 4K、8K、32K、128K tokens。
- 大 batch 训练:中间 attention matrix 数量巨大。
- Decoder-only LLM:每层都有 causal attention。
- 显存吃紧:原本因为 OOM 训不起来的长序列任务。
- A100 / H100 等现代 GPU:高速 SRAM、Tensor Core、CUDA kernel fusion 优势明显。
13.2 收益不明显的场景
- 序列很短:比如 128 或 256 tokens,标准 Attention 中间矩阵不大。
- CPU 推理:FlashAttention 是 GPU kernel 优化,CPU 上意义不大。
- 模型瓶颈不在 Attention:例如 FFN/MoE 或数据加载成为瓶颈。
- 显卡太旧:某些低版本 CUDA / 旧架构 GPU 支持不佳。
十四、PyTorch 中怎么用?
现在很多框架已经把 FlashAttention 类似的优化封装进标准接口。
14.1 PyTorch SDPA
PyTorch 2.x 提供了 scaled_dot_product_attention:
1 | import torch |
在满足条件时,PyTorch 会自动选择更高效的 kernel(可能是 FlashAttention、memory-efficient attention 或 math fallback)。
14.2 Transformers 中的使用
Hugging Face Transformers 中,很多模型可以通过参数启用:
1 | from transformers import AutoModelForCausalLM |
实际能否启用取决于:
- GPU 架构
- CUDA 版本
- PyTorch 版本
- flash-attn 包是否安装
- head dimension 是否支持
- attention mask 形式是否兼容
十五、常见误区
误区 1:FlashAttention 是近似算法
错误。FlashAttention 是精确 attention,数学结果与标准 attention 等价,差异主要来自浮点精度和计算顺序。
误区 2:FlashAttention 会减少理论 FLOPs
不准确。它主要减少的是显存 IO,不是从公式上消灭矩阵乘法。实际速度提升来自更少的数据搬运和更好的 kernel fusion。
误区 3:用了 FlashAttention 就一定更快
不一定。短序列、小 batch、旧 GPU、非 attention 瓶颈场景下,收益可能不明显。
误区 4:FlashAttention 能解决所有长上下文问题
不能。它降低了注意力计算的显存和 IO 压力,但完整 attention 的计算量仍然与 $N^2$ 有关。超长上下文还需要 RoPE 扩展、稀疏注意力、滑动窗口、KV Cache 管理等技术配合。
十六、面试回答模板
如果面试官问:“FlashAttention 为什么快?”可以这样回答:
FlashAttention 的核心不是改变 Attention 公式,而是改变计算方式。标准 Attention 会显式生成并保存 $N \times N$ 的 score 和 probability 矩阵,导致大量 HBM 显存读写。FlashAttention 把 Q/K/V 分块加载到 SRAM 中,使用在线 softmax 一边计算一边更新归一化统计量,避免把完整注意力矩阵写回显存。这样额外中间显存从 $O(N^2)$ 降到 $O(N)$,并通过 kernel fusion 减少读写和调度开销,所以在长序列训练和大模型推理中速度更快、显存更省。
十七、终极总结
最后用 5 句话总结 FlashAttention:
- FlashAttention 是精确 Attention,不是近似算法。
- 它的核心优化对象是显存 IO,而不只是 FLOPs。
- 它通过分块计算和在线 Softmax,避免保存完整 $N \times N$ 注意力矩阵。
- 它把中间显存从 $O(N^2)$ 降到 $O(N)$,长序列收益尤其明显。
- FlashAttention v2/v3 继续围绕 GPU 并行、Tensor Core、Hopper 架构做深度优化,是现代 LLM 训练和推理的重要基础设施。
如果把标准 Attention 比作“把整张地图摊开在桌上找路”,那么 FlashAttention 就是“边走边看局部地图,同时记住全局方向”。它没有改变目的地,却极大减少了你搬地图、摊地图、收地图的时间。