【AI】FlashAttention 详解:为什么它能让大模型注意力计算又快又省显存?
一、先说结论:FlashAttention 到底是什么?
一句话概括:
FlashAttention 是一种 IO-aware 的精确注意力计算算法,它通过分块计算和在线 Softmax,避免显式存储巨大的注意力矩阵,从而大幅减少 GPU 显存读写。
这里有三个关键词:
- IO-aware:它关注的不只是浮点计算量 FLOPs,而是 GPU 显存和高速缓存之间的数据搬运成本。
- 精确注意力:它算出来的结果和标准 Attention 数学上等价,不是近似、不剪枝、不稀疏化。
- 不保存完整注意力矩阵:标准 Attention 会生成一个 \( N \times N \) 的注意力矩阵,FlashAttention 避免把它完整写入显存。
很多人第一次听到 FlashAttention,会以为它是“更聪明的 Attention 结构”。其实不是。它不是新的模型结构,而是同一个公式的更高效实现方式。
二、先复习标准 Attention:慢在哪里?
Transformer 的核心注意力公式是:
其中:
- \( Q \):Query,表示“我想找什么信息”
- \( K \):Key,表示“我有什么标签”
- \( V \):Value,表示“我真正提供的信息”
- \( N \):序列长度
- \( d \):每个 head 的维度
标准实现通常分三步:
1 | 第一步:S = QK^T # 得到 N × N 的打分矩阵 |
问题就出在中间的两个矩阵:S 和 P。
如果序列长度是 4096,那么每个 head 的注意力矩阵大小是:
1 | 4096 × 4096 = 16,777,216 个元素 |
FP16 每个元素 2 字节,仅一个矩阵就是:
1 | 16,777,216 × 2 ≈ 32 MB |
注意这只是一个 batch、一个 head、一个矩阵。实际训练中还有:
- 多个 batch
- 多个 attention head
- 多层 Transformer
- 前向保存中间结果给反向传播使用
所以标准 Attention 的显存压力会迅速爆炸。
三、真正的瓶颈:不是算不动,而是搬不动
很多新手会以为 GPU 慢是因为“乘法太多”。但在 Attention 里,一个更关键的问题是:数据搬运太多。
GPU 内部大致可以理解成三层存储:
| 存储位置 | 典型容量 | 典型带宽 | 相对速度 | 类比 |
|---|---|---|---|---|
| HBM / 显存 | 数十 GB | ~1.5–3 TB/s | 慢 | 仓库 |
| SRAM / Shared Memory(每 SM) | ~100 KB 量级 | ~10+ TB/s | 快 | 操作台 |
| Register / 寄存器 | 几 KB | 极快 | 最快 | 手里拿着的工具 |
注意:HBM 容量大但带宽相对慢,SRAM 容量极小但带宽快一个数量级。这就是 FlashAttention 设计的物理基础:尽量让数据呆在 SRAM 里多算几次,少进出 HBM。
标准 Attention 的做法像这样:
1 | Q、K 从显存读入 → 计算 S → 把 S 写回显存 |
也就是说,中间的巨大矩阵 S 和 P 被反复写入、读出显存。
这就是 FlashAttention 论文里强调的核心:
Attention 的瓶颈不是只有 FLOPs,还有 HBM 和 SRAM 之间的 IO。
3.1 标准 Attention vs FlashAttention 数据流对比
下面这张图直接把两者放在一起对比,最能体现”是否反复进出 HBM”这一关键差别:
四、FlashAttention 的核心思想:不要把整张注意力表摊开
标准 Attention 像是在处理一张巨大的 Excel 表:
- 先算完整的
S = QK^T - 再对整张表做 softmax
- 再拿整张表乘以
V
FlashAttention 的想法是:
不要一次性生成整张 N×N 表,而是把 Q、K、V 切成小块,在 GPU 高速缓存里一块一块算。
可以把它想象成做饭:
- 标准 Attention:把所有菜一次性摊满整个厨房,切完、炒完、装盘,厨房爆炸。
- FlashAttention:每次只拿一小篮菜到操作台,处理完立刻合并到最终成品,不把半成品堆满厨房。
五、分块计算:Tiling
FlashAttention 会把矩阵切成 block。例如:
1 | Q 分成 Q1, Q2, Q3, ... |
然后每次只计算一个小块:
1 | Qi × Kj^T |
得到的是一个小的 attention score block,而不是完整的 N × N 矩阵。
这个思想本身不难,但真正困难在 softmax。
为什么?因为 softmax 的分母需要知道一整行所有元素:
如果你只看一小块,就不知道整行的最大值和分母是多少。FlashAttention 的关键创新就在这里:在线 Softmax。
六、在线 Softmax:边看边更新全局结果
普通 softmax 通常要先拿到整行数据,然后做三步:
1 | 1. 找到这一行的最大值 m |
FlashAttention 不能一次看完整行,只能一块一块看。因此对每一行 Query,它维护三个运行中的变量:
- \( m \):当前已经看过的所有块中的最大值(per-row)
- \( \ell \):当前 softmax 未归一化 的分母累计值
- \( O \):当前未归一化的输出累计值(行向量)
每当一个新的 K/V 块到来,FlashAttention 做如下更新:
第 1 步,算这一块的局部分数与局部最大值:
第 2 步,更新全局最大值,并把”旧累积”按新最大值重新缩放:
第 3 步,对输出 O 做同样的 rescale,然后把当前块的贡献加进来:
最后所有块都处理完后,再除以 \( \ell \) 完成归一化:
这一套公式解决了两个问题:
- 数值稳定:始终减去当前最大值 \( m_{new} \),避免 \( e^{x} \) 上溢。
- 分块等价:虽然一块一块累加,但最终结果与一次性对整行做 softmax 严格相等(不是近似)。
关键直觉:每当出现新最大值时,所有”旧的累积”都按 \( e^{m_{old}-m_{new}} \) 收缩一次,这样大家始终处在同一个”减最大值”的尺度下,可以放心相加。
下面这张图把”新块到来 → rescale 旧累积 → 合并新贡献”的过程画出来:
直觉上看,在线 softmax 就像你在分批统计全班成绩:
- 每来一批学生,更新当前最高分。
- 如果新的最高分变了,之前那批学生的相对权重也要按比例重新缩放。
- 最后得到的统计结果和一次性看完整个班一样。
七、FlashAttention 的完整流程
下面给出 FlashAttention v2 风格的伪代码(外层是 Q,内层是 K/V,对推理更友好;v1 的循环顺序与之相反,详见第十一节):
1 | for each Q_block: # 外循环:固定一段 Query |
关键点:
中间的
S和P只存在于 SRAM / register 中,用完就丢,不写回 HBM。 写回 HBM 的只有最终输出O和每一行的 softmax 统计量L = m + log(l)(反向时用得上,但只是 \( O(N) \) 大小)。
可以把 v2 的”双层循环”画成下面这张矩阵图:
八、为什么它省显存?
标准 Attention 的中间显存主要来自:
1 | S = QK^T # N × N |
训练时还要保存 P 给反向传播用。
FlashAttention 不保存完整的 S 和 P,只保存:
1 | O # N × d |
所以中间显存从近似:
1 | O(N²) |
降到:
1 | O(N) |
注意:这里说的是额外中间显存,不是 Q/K/V 本身的显存,也不是模型参数显存。
这也是为什么长序列场景下 FlashAttention 特别有价值。序列越长,\( N^2 \) 越可怕,省下来的显存越多。
九、为什么它更快?
FlashAttention 更快,不是因为数学计算量从根本上少了。它依然要算注意力,本质 FLOPs 没有消失。
它快在:
- 减少 HBM 读写:不把
S和P这种巨大矩阵写回显存。 - 充分利用 SRAM:小块数据放在高速缓存里反复使用。
- Kernel Fusion:把矩阵乘、mask、softmax、dropout、乘 V 等步骤融合在一个 CUDA kernel 里,减少中间结果落地。
- 更好的并行调度:FlashAttention-2 进一步优化了 block 划分和 warp-level 并行。
可以这样理解:
标准 Attention 是“算一步、存一步、再读回来算下一步”;FlashAttention 是“拿到操作台上一次性处理完,只把最终成品放回仓库”。
十、反向传播:为什么还要重算?
训练时,反向传播需要用到前向过程中的一些中间结果。
标准 Attention 会直接保存完整的注意力概率矩阵 P,反向时直接读出来。
FlashAttention 不保存 P,那反向怎么办?
答案是:重算(recomputation)。
前向阶段,FlashAttention 写回 HBM 的”反向所需”的中间量只有:
- 输出
O(\( N \times d \)) - 每行的 softmax 统计量
L = m + log(l)(\( O(N) \))
反向阶段,对每个 (Q_block, K_block) 对,重新加载 Q、K、V,重新算一次:
1 | S = Q · Kᵀ / sqrt(d) |
这里有个非常优雅的细节:因为 L = m + log(l),所以 \( \exp(S - L) = \exp(S - m)/l \),也就是真正的 softmax 概率。换句话说:只要保存一行一个 L 标量,就足以在反向时一次性还原整行的概率,不需要保存 N×N 的 P,也不需要保存中间 m、l。
这是一种典型的时间换空间:
| 方案 | 显存 | 计算 |
|---|---|---|
| 标准 Attention | 高 | 少一点 |
| FlashAttention | 低 | 多一点重算 |
但在 GPU 上,显存 IO 往往比多做一点计算更贵。所以整体上 FlashAttention 仍然更快、更省。
十一、FlashAttention v1、v2、v3 有什么区别?
11.1 FlashAttention v1(2022)
v1 的核心贡献是提出 IO-aware exact attention:
- 分块计算(tiling)
- 在线 softmax
- 不保存完整 attention matrix
- 显著降低 HBM 读写
它证明了:注意力优化不能只看 FLOPs,还必须看显存 IO。
特点:v1 的循环顺序是 外循环 K/V、内循环 Q,这种顺序在反向上比较自然,但前向时 Q 的 rescale 频繁、并行度受限。
11.2 FlashAttention v2(2023)
v2 在工程实现上做了系统性的重构:
- 互换循环顺序:把 Q 放到外层、K/V 放到内层,每个 Q 行只需一次 rescale,前向显著加速。
- 减少非矩阵乘法操作(更少的 rescale、除法、log),把更多 FLOPs 留给 Tensor Core 这种”会算大块矩阵乘”的硬件单元。
- 更好的并行化:在 batch、head、序列维度上同时切分,提升 GPU 占用率。
- 更广 head dimension 与 mask 支持(支持 causal、local、ALiBi 等)。
- 端到端在 A100 上比 v1 快约 2× 左右(论文给出的典型数字)。
如果说 v1 是”算法突破”,v2 就是”把 GPU 榨得更干净”。
11.3 FlashAttention v3(2024)
v3 面向 NVIDIA Hopper 架构(H100)做硬件深度适配:
- Warp-specialized producer/consumer 调度:一部分 warp 专门负责异步搬运数据(TMA),另一部分 warp 专门做矩阵乘 + softmax,让数据搬运和计算真正重叠起来。
- GEMM 与 Softmax 的指令级流水:在 register 层面交错排布,进一步隐藏 softmax 的非矩阵乘开销。
- FP8 支持:在低精度路径下做了块级缩放(block scaling),尽量保留数值精度,但 FP8 仍可能有可观的精度下降,需要业务方做评估。
- 在 H100 上 FP16 比 v2 快约 1.5–2×,FP8 接近 1.2 PFLOPs/s 量级(论文数据)。
简单理解:
1 | FlashAttention v1:想明白怎么省 IO(A100 上 ~2× 提升) |
十二、它和 Sparse Attention、Linear Attention 有什么区别?
这是非常容易混淆的点。
| 方法 | 是否精确 | 核心思路 | 结果是否等价标准 Attention |
|---|---|---|---|
| FlashAttention | 精确 | 改变计算顺序,减少 IO | 等价 |
| Sparse Attention | 近似 / 结构限制 | 只看部分 token | 不一定等价 |
| Linear Attention | 近似 / 核技巧 | 避免显式 \( QK^T \) | 通常不等价 |
| Sliding Window Attention | 结构限制 | 只看局部窗口 | 不等价完整 Attention |
FlashAttention 最重要的特点是:
它不是换了 Attention,而是把同一个 Attention 算得更聪明。
所以使用 FlashAttention 通常不需要重新训练模型结构,也不会改变模型输出的理论含义。
十三、训练、Prefill、Decode 三种场景下的收益不同
很多人只笼统地说 “FlashAttention 提速很多”,但实际上它在三种典型场景里的收益差别很大,搞清这一点能避免很多”为什么我用了没快多少”的困惑。
| 场景 | 序列长度 | Q 长度 | KV 长度 | FlashAttention 收益 |
|---|---|---|---|---|
| 训练 | 长 | 全长 | 全长 | 极大:N×N 中间矩阵全省下来 |
| 推理 prefill | 长 | 全 prompt 长度 | 全 prompt 长度 | 大:和训练前向类似 |
| 推理 decode | 长 | 1(只算新 token) | 全 KV cache | 小:瓶颈不在 N×N 而在 KV 加载 |
为什么 decode 阶段收益小?因为这时 Q 只有 1 个 token,注意力矩阵从 N×N 退化成 1×N,不再有 O(N²) 中间矩阵可省。这时真正的瓶颈是把整段 KV cache 从 HBM 读进来——属于 memory-bound 而非 compute-bound 问题。
针对这一痛点,社区提出了 FlashDecoding(2023)和 FlashDecoding++(2024):
- 把 KV 维度也切成多个 chunk,沿 KV 维度并行多个 SM 同时算
- 最后用一次 reduction 合并 partial softmax 结果
- 在长上下文 decode 上能再带来 2–4× 加速
十四、与 PagedAttention、KV Cache 的关系
经常有同学问:”vLLM 用的 PagedAttention 和 FlashAttention 是什么关系?是替代品吗?”
答案是:不是替代,而是叠加。
| 优化 | 解决的问题 | 工作层 |
|---|---|---|
| FlashAttention | 注意力计算本身的 IO/显存 | 单次 attention kernel 内部 |
| PagedAttention | KV cache 在多请求间的碎片化与浪费 | KV cache 的内存管理 |
| KV Cache | decode 时不重复算历史 token | 前向流程层面 |
| Continuous Batching | 请求级调度,提升吞吐 | 调度器层 |
你完全可以同时用 FlashAttention 内核 + PagedAttention 内存管理 + Continuous Batching 调度,三者互补,事实上 vLLM 的高性能内核就是这么组合的。
十三、什么时候收益最大?
FlashAttention 的收益和场景强相关。
13.1 收益明显的场景
- 长上下文训练:比如 4K、8K、32K、128K tokens。
- 大 batch 训练:中间 attention matrix 数量巨大。
- Decoder-only LLM:每层都有 causal attention。
- 显存吃紧:原本因为 OOM 训不起来的长序列任务。
- A100 / H100 等现代 GPU:高速 SRAM、Tensor Core、CUDA kernel fusion 优势明显。
13.2 收益不明显的场景
- 序列很短:比如 128 或 256 tokens,标准 Attention 中间矩阵不大。
- CPU 推理:FlashAttention 是 GPU kernel 优化,CPU 上意义不大。
- 模型瓶颈不在 Attention:例如 FFN/MoE 或数据加载成为瓶颈。
- 显卡太旧:某些低版本 CUDA / 旧架构 GPU 支持不佳。
十四、PyTorch 中怎么用?
现在很多框架已经把 FlashAttention 类似的优化封装进标准接口。
14.1 PyTorch SDPA
PyTorch 2.x 提供了 scaled_dot_product_attention(简称 SDPA):
1 | import torch |
PyTorch 在底层会从三种后端中自动挑选:
flash—— FlashAttention(v2/v3,看 PyTorch 版本与硬件)mem_efficient—— xFormers 风格的 memory-efficient attentionmath—— 数学回退实现(保底,会显式生成 N×N 矩阵)
要让 SDPA 真正命中 flash 后端,下面这些条件最容易踩坑:
q/k/v必须是float16或bfloat16(FP32 通常走不到 flash)- 在 CUDA 设备上、CC ≥ 8.0(Ampere 及以上)
head_dim是受限集合(v2 一般支持 ≤ 256,v1 时代仅 64/128)attn_mask必须是None或可转换成 causal/标准的形式(任意 dense mask 会回退)dropout_p必须是 0(部分版本)- 张量必须连续
可以用上下文管理器强制只允许 flash 后端,从而验证是否真的走到:
1 | from torch.nn.attention import sdpa_kernel, SDPBackend |
如果不满足条件,PyTorch 会直接报错,从而暴露你”以为命中了 flash,实际走了 math fallback”的隐形性能 bug。
14.2 Hugging Face Transformers 中的使用
很多模型可以通过参数显式启用:
1 | from transformers import AutoModelForCausalLM |
attn_implementation 常见取值:
"eager"—— 原始实现(保底)"sdpa"—— 走 PyTorch SDPA(自动挑后端)"flash_attention_2"—— 强制走 FlashAttention-2 内核(需要flash-attn包)
实际能否启用取决于:
- GPU 架构(Ampere/Hopper 通常 OK,Turing/Volta 一般不行)
- CUDA 版本与 PyTorch 版本
flash-attn包是否安装、版本是否匹配- head dimension 是否在受支持集合
- attention mask 是否兼容(自定义 dense mask 通常不支持)
14.3 直接调 flash-attn 包
最底层的用法是直接用 flash-attn 包提供的函数:
1 | from flash_attn import flash_attn_func |
直接调更可控,也能用上 v3 的特性(如 FP8、不同的 mask 类型),但需要自己处理张量布局、mask、batched variable-length 等细节,业务代码一般用 SDPA 或 Transformers 即可。
十六、常见误区
误区 1:FlashAttention 是近似算法
错误。FlashAttention 是精确 attention,数学结果与标准 attention 等价,差异主要来自浮点精度和计算顺序。
误区 2:FlashAttention 会减少理论 FLOPs
不准确。它主要减少的是显存 IO,不是从公式上消灭矩阵乘法。实际速度提升来自更少的数据搬运和更好的 kernel fusion。
误区 3:用了 FlashAttention 就一定更快
不一定。短序列、小 batch、旧 GPU、非 attention 瓶颈场景下,收益可能不明显;推理 decode 阶段尤其需要叠加 FlashDecoding 才有大收益。
误区 4:FlashAttention 能解决所有长上下文问题
不能。它降低了注意力计算的显存和 IO 压力,但完整 attention 的计算量仍然与 \( N^2 \) 有关。超长上下文还需要 RoPE 扩展、稀疏注意力、滑动窗口、KV Cache 管理等技术配合。
误区 5:FlashAttention 算出来和标准 Attention 比特一致
不是的。FlashAttention 的数学定义和标准 attention 等价,但浮点结果几乎一定有微小差异,原因有三:
- 加法顺序不同。浮点加法不满足结合律,分块累加和一次性求和会有不同的舍入。
- 指数 / 除法在 FP16/BF16 下精度本来就有限。
- v3 的 FP8 路径精度更低,需要业务做评估。
但这些误差通常远小于训练本身的浮点噪声,对模型质量没有可观察影响。
十七、面试回答模板
如果面试官问:”FlashAttention 为什么快?”可以这样回答:
FlashAttention 的核心不是改变 Attention 公式,而是改变计算方式。标准 Attention 会显式生成并保存 \( N \times N \) 的 score 和 probability 矩阵,导致大量 HBM 显存读写。FlashAttention 把 Q/K/V 分块加载到 SRAM 中,使用在线 softmax 一边计算一边更新归一化统计量,避免把完整注意力矩阵写回显存。这样额外中间显存从 \( O(N^2) \) 降到 \( O(N) \),并通过 kernel fusion 减少读写和调度开销,所以在长序列训练和大模型推理中速度更快、显存更省。
进一步可以补充:v2 把外循环换成 Q,提升前向并行度并减少非 matmul 操作;v3 在 Hopper 上做 warp specialization、把数据搬运和计算异步重叠,并支持 FP8。同时要区分场景:训练 / prefill 收益最大,decode 由于 Q 长度只有 1,需要叠加 FlashDecoding 才能有显著加速。
十八、终极总结
最后用 7 句话总结 FlashAttention:
- FlashAttention 是精确 Attention,不是近似算法。
- 它的核心优化对象是显存 IO,而不只是 FLOPs。
- 它通过分块计算和在线 Softmax,避免保存完整 \( N \times N \) 注意力矩阵。
- 它把额外中间显存从 \( O(N^2) \) 降到 \( O(N) \),长序列收益尤其明显。
- 训练 / prefill 收益最大;decode 由于 Q 只有 1,需要 FlashDecoding 才能继续加速。
- v1 → v2 → v3 是”算法 → 调度 → 硬件”的三层优化路径,分别对应 IO-aware、并行重排、Hopper 异步与 FP8。
- 它和 PagedAttention、KV Cache、Continuous Batching 是叠加关系,不是替代关系。
如果把标准 Attention 比作”把整张地图摊开在桌上找路”,那么 FlashAttention 就是”边走边看局部地图,同时记住全局方向”。它没有改变目的地,却极大减少了你搬地图、摊地图、收地图的时间。