高效推理(Efficient Inference)
通过量化、缓存优化、解码加速等手段,降低大模型推理的显存占用和延迟
通过分块计算、KV 共享、稀疏模式等手段优化 Transformer 注意力的速度和内存瓶颈。
内容摘要
注意力优化是一组围绕 Transformer 自注意力(Self-Attention)机制做加速和省内存的技术集合。它不改变模型的基本结构,而是在"怎样算注意力"这一步做文章——用更聪明的内存访问方式、更少的 KV 缓存(KV Cache)、更稀疏的计算模式,让大模型在同样的硬件上跑得更快、吃的显存更少。
注意力优化是一组围绕 Transformer 自注意力(Self-Attention)机制做加速和省内存的技术集合。它不改变模型的基本结构,而是在"怎样算注意力"这一步做文章——用更聪明的内存访问方式、更少的 KV 缓存(KV Cache)、更稀疏的计算模式,让大模型在同样的硬件上跑得更快、吃的显存更少。
标准自注意力的计算复杂度是 $O(N^2)$(N 为序列长度)。当序列长度从 2K 增长到 128K,注意力矩阵的大小从 400 万膨胀到 160 亿——既装不进显存,也算不过来。注意力优化就是要打破这个二次方的枷锁。
目前主流的优化路线有三条:第一条是 IO 感知优化(以 Flash Attention 为代表),不改变计算量但大幅减少显存读写;第二条是 KV 缓存压缩(以 MQA、GQA、MLA 为代表),减少需要存储的 Key/Value 数量;第三条是 稀疏/局部注意力(以滑动窗口、Longformer 为代表),让每个 Token 只关注一部分邻居而不是全序列。这三条路线互不排斥,可以组合使用。
| 优化路线 | 代表技术 | 核心思路 |
|---|---|---|
| IO 感知优化 | Flash Attention 系列 | 把大矩阵分块,在 GPU 片上高速缓存(SRAM)中完成计算,减少对慢速显存(HBM)的读写 |
| KV 缓存压缩 | MQA / GQA / MLA | 让多个查询头共享同一组 Key-Value,或用低秩压缩把 KV 缓存变小 |
| 稀疏注意力 | 滑动窗口 / Longformer / BigBird | 每个 Token 只关注局部窗口或少量全局 Token,把 $O(N^2)$ 降到 $O(N)$ |
| 推理内存管理 | PagedAttention(vLLM) | 把 KV 缓存像操作系统管理内存页一样按需分配,消除碎片浪费 |
标准实现会在 GPU 全局显存(HBM,High Bandwidth Memory,高带宽内存)中生成一整个 $N \times N$ 的注意力矩阵,然后再做 softmax 和矩阵乘法。HBM 虽然带宽大(约 2-3 TB/s),但相比 GPU 片上的 SRAM(约 19 TB/s 带宽,但容量只有几十 MB)还是慢很多。Flash Attention 的核心见解:瓶颈不是计算量,而是内存搬运量。它把 Q、K、V 分成小块,每块都在 SRAM 里完成全部计算,算完直接写回结果,不生成完整的 $N \times N$ 中间矩阵。
在自回归推理中,模型每生成一个 Token 都要读取之前所有 Token 的 Key 和 Value(即 KV Cache)。当序列很长、batch size 很大时,KV Cache 占用的显存比模型参数本身还多。
让每个 Token 不再关注全部 Token,而是只关注一个"子集":
vLLM 团队在 2023 年提出 PagedAttention,借鉴操作系统虚拟内存的分页机制管理 KV Cache。传统方式为每个请求预分配一大块连续显存,但因生成长度未知而大量浪费(实测利用率仅 20-40%)。PagedAttention 把 KV Cache 切成固定大小的"页"(block),按需分配,内存浪费从 60-80% 降到不足 4%,吞吐量提升 2-4 倍。
以 Flash Attention 为例,说明"IO 感知优化"是如何工作的。
标准注意力的计算公式:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
直接算需要在显存中构造 $N \times N$ 的矩阵 $S = QK^T$,对 N=4096、d=64 而言,这个矩阵占 64 MB(FP16),对 N=128K 则需要 64 GB。
Flash Attention 的做法是分块 + 在线 softmax:
结果:计算量不变(仍然是 $O(N^2 d)$ 浮点运算),但 HBM 读写量从 $O(N^2 + Nd)$ 降到 $O(N^2 d^2 / M)$(M 为 SRAM 大小)。由于 GPU 是"带宽受限"的,减少读写量直接转化为速度提升。
GQA 的原理更直观:标准 MHA 中每个注意力头有独立的 $K_h$ 和 $V_h$,GQA 则把 h 个头分成 G 组,同组内的头共享同一份 K 和 V。推理时 KV Cache 大小从 $2 \times h \times d \times N$ 降到 $2 \times G \times d \times N$,当 G 远小于 h 时节省显著。
图中标红节点是标准注意力的瓶颈——在 HBM 中构造和读写 $N \times N$ 矩阵。标绿节点是 Flash Attention 的关键改进——在 SRAM 中完成计算,避免中间矩阵落地。下方展示了 KV 缓存从 MHA 到 GQA/MQA/MLA 的逐步压缩路径。
# 伪代码:Flash Attention 分块计算核心逻辑
# 说明 IO 感知优化的关键步骤,非完整实现
def flash_attention_pseudocode(Q, K, V, block_size):
"""
Q, K, V: 形状 (N, d)
block_size: 分块大小 Br/Bc
"""
N, d = Q.shape
O = zeros(N, d) # 输出矩阵
l = zeros(N) # softmax 分母累积
m = full(N, -inf) # 当前最大值(用于数值稳定)
# 外层循环:遍历 K, V 的块
for j in range(0, N, block_size):
Kj = K[j:j+block_size] # 从 HBM 加载一块 K
Vj = V[j:j+block_size] # 从 HBM 加载一块 V
# 内层循环:遍历 Q 的块
for i in range(0, N, block_size):
Qi = Q[i:i+block_size] # 从 HBM 加载一块 Q
# ---- 以下全部在 SRAM 中完成 ----
Sij = Qi @ Kj.T / sqrt(d) # 小块注意力分数
m_new = max(m[i:i+block_size], Sij.max(axis=1))
P = exp(Sij - m_new[:, None]) # 数值稳定的 softmax 分子
# 在线更新 softmax 累积量
l_new = exp(m[i:i+block_size] - m_new) * l[i:i+block_size] + P.sum(axis=1)
O[i:i+block_size] = (exp(m[i:i+block_size] - m_new)[:, None]
* O[i:i+block_size] + P @ Vj)
m[i:i+block_size] = m_new
l[i:i+block_size] = l_new
# 最终归一化
O = O / l[:, None]
return O
伪代码展示了 Flash Attention 的两层循环结构和在线 softmax 更新逻辑。m 维护当前最大值用于数值稳定,l 累积 softmax 分母。整个过程不生成完整的 $N \times N$ 矩阵,每个小块算完即合并到输出。
# GQA 的最小示意:分组共享 K, V
# 基于 PyTorch 语法(torch >= 2.0)
import torch
batch, seq_len, num_heads, head_dim = 1, 256, 8, 64
num_kv_groups = 2 # GQA: 8 个查询头分成 2 组,每组 4 个头共享一份 K,V
Q = torch.randn(batch, seq_len, num_heads, head_dim) # 8 个独立的查询头
K = torch.randn(batch, seq_len, num_kv_groups, head_dim) # 只有 2 组 K
V = torch.randn(batch, seq_len, num_kv_groups, head_dim) # 只有 2 组 V
# 将 K, V 扩展(repeat)到与 Q 相同的头数
heads_per_group = num_heads // num_kv_groups # = 4
K_expanded = K.repeat_interleave(heads_per_group, dim=2) # (1, 256, 8, 64)
V_expanded = V.repeat_interleave(heads_per_group, dim=2)
# 标准注意力计算(简化版)
scores = torch.einsum("bshd,bthd->bhst", Q, K_expanded) / (head_dim ** 0.5)
attn = scores.softmax(dim=-1)
output = torch.einsum("bhst,bthd->bshd", attn, V_expanded)
# KV Cache 对比
mha_cache = 2 * num_heads * head_dim * seq_len # MHA: 2×8×64×256 = 262144
gqa_cache = 2 * num_kv_groups * head_dim * seq_len # GQA: 2×2×64×256 = 65536
print(f"MHA KV Cache: {mha_cache} 个参数")
print(f"GQA KV Cache: {gqa_cache} 个参数")
print(f"GQA 节省: {(1 - gqa_cache/mha_cache)*100:.0f}%")
# 输出: GQA 节省: 75%
GQA 示例中,8 个查询头只需要 2 组 K-V,推理时 KV Cache 缩小 75%。repeat_interleave 操作在计算时将 K-V 广播到每个查询头,不额外占用缓存空间。
| 概念 | 与注意力优化的区别 | 更适合关注的重点 |
|---|---|---|
| 模型量化(Quantization) | 量化压缩的是模型权重和激活值的精度(如 FP16→INT8),注意力优化改的是注意力计算本身的流程和内存模式 | 量化关注"用更少的位数表示权重",注意力优化关注"用更少的内存搬运和更少的 KV 存储完成注意力计算" |
| 模型剪枝(Pruning) | 剪枝删掉不重要的权重或结构(如整层、整头),注意力优化不删除结构,而是改变注意力的计算和存储方式 | 剪枝是"让模型变小",注意力优化是"让同样大的模型跑得更快" |
| KV Cache | KV Cache 是推理时缓存历史 Key-Value 的机制本身,注意力优化(如 GQA、MLA、PagedAttention)是对 KV Cache 做压缩或管理的技术 | KV Cache 是"问题背景",注意力优化是"解决方案" |
| 模型并行(Model Parallelism) | 模型并行把模型切分到多张 GPU 上,注意力优化在单张 GPU 内部改善注意力计算效率 | 两者互补——注意力优化减少单卡开销,模型并行解决单卡放不下的问题 |
核心区别:
| 常见误区 | 正确理解 |
|---|---|
| Flash Attention 改变了注意力的计算结果 | Flash Attention 是精确实现(exact attention),输出与标准注意力在数学上完全一致。它只改变计算的执行顺序和内存访问模式,不做任何近似 |
| GQA 和 MQA 只是"缩水版"的 MHA,效果一定更差 | GQA 在多数基准测试上接近 MHA 的质量,LLaMA 2 70B 采用 GQA 后推理速度大幅提升且质量损失极小。MLA 甚至声称质量优于 MHA |
| 稀疏注意力一定会丢失远距离信息 | BigBird 的随机连接在概率上保证了远距离信息流通。此外层间堆叠效应使滑动窗口的有效感受野远大于单层窗口大小 |
| PagedAttention 只是 vLLM 的内部实现细节 | PagedAttention 的分页思想已被 TensorRT-LLM、HuggingFace TGI 等主流推理框架采纳,是当前 LLM 推理服务的标准技术 |
参考答案:
Flash Attention 的加速来自减少 GPU 显存(HBM)的读写次数,而非减少浮点运算量。标准实现需要把 $N \times N$ 的注意力矩阵在 HBM 中来回读写多次,而 Flash Attention 通过分块计算让每个小块在高速的 SRAM 中完成全部运算,只在最终写回结果时访问 HBM。由于 GPU 的性能瓶颈通常在内存带宽而非算力,减少读写就直接转化为时间缩短。
参考答案:
设注意力头数为 h、头维度为 d、序列长度为 N:
选型建议:对质量要求高且 KV Cache 不是瓶颈时用 MHA;追求推理速度且可接受少量质量损失时用 GQA(目前最主流);极端长序列场景用 MLA(需 DeepSeek 生态);MQA 更多出现在早期模型中,新模型一般优先选 GQA。
参考答案:
70B 模型典型配置:80 层、64 个注意力头、头维度 128。
MHA 的 KV Cache = $2 \times 80 \times 64 \times 128 \times 32768 \times 2$(FP16 每参数 2 字节)= 约 84 GB,单卡根本放不下(还没算模型参数本身约 140 GB)。
必须组合使用:(1) 模型量化(INT4 量化后模型约 35GB);(2) GQA(LLaMA 2 70B 用 8 组 GQA,KV Cache 降到 $2 \times 80 \times 8 \times 128 \times 32768 \times 2 \approx$ 10.5 GB);(3) Flash Attention(避免 $32K \times 32K$ 的中间矩阵占满显存);(4) PagedAttention(消除 KV Cache 的碎片浪费)。四者组合后模型参数(35GB)+ KV Cache(约 10GB)+ 激活等开销 ≈ 约 55GB,可以在 80GB A100 上运行。
优先展示同分类且标签更接近的内容,方便继续串联学习。
通过量化、缓存优化、解码加速等手段,降低大模型推理的显存占用和延迟
从 RNN 到 Transformer 再到 MoE,理解现代大模型架构的演进脉络
Transformer 中让模型理解词序的关键机制,从正弦编码到 RoPE、ALiBi 的演进与对比。