游乐游手机版
首页/AI教程/文章详情

大模型入门从MHA到GQA一次讲清KV Cache省显存原理

时间:2026-06-02 07:25
大模型推理中,键值缓存随序列长度和批次增长消耗大量显存。分组查询注意力通过减少键值头数量直接压缩缓存体积。例如同样配置下,多头注意力需4吉字节,分组查询注意力(键值头为8)仅需1吉字节,在效果与效率间取得平衡。

大模型入门:从 MHA 到 GQA,一次讲清 KV Cache 为什么能省显存

许多初次部署本地大模型的开发者,通常会以为显存主要被模型参数占用。这固然没错——一个 7B 模型即便使用 FP16 精度,参数部分也需要十几 GB 级别的显存。

但进入真实推理阶段后,你会发现另一项资源消耗增长更快。提示词越长,KV Cache 占用的空间越大;批次越大,KV Cache 膨胀得越厉害;上下文窗口越长,KV Cache 需求量也越大;并发请求越多,KV Cache 的管理也越复杂。

模型参数在加载时就基本固定了,而 KV Cache 是在生成过程中随请求数量、序列长度和批次大小持续动态增长的。这也正是服务端推理框架认真管理 KV Cache 的原因——vLLM 的 PagedAttention、Hugging Face 的 DynamicCache / StaticCache / QuantizedCache,本质上都在解决同一个核心问题:如何让历史 K/V 既能被快速读取,又不至于撑爆显存。

GQA 恰好处于这个问题的中心。

一句话概括:GQA 通过减少 KV Head 的数量,直接压缩了 KV Cache 的体积。

1. 先回忆:KV Cache 到底缓存了什么

Decoder-only 大模型推理通常分为两个阶段:

阶段输入主要动作
Prefill完整 prompt一次性计算 prompt 各层的 K/V,并写入缓存
Decode当前新 token只计算新 token 的 Q/K/V,并用新 Q 查询历史 K/V

Hugging Face 的缓存文档也强调过:自回归生成是一个 token 一个 token 逐步预测的,KV Cache 会保存过去 token 在注意力层中的 K/V,后续 token 可以复用这些缓存,避免重复计算。

上一篇文章中,我们使用的 MHA 张量形状为:

q.shape == [batch, num_heads, seq_len, head_dim]
k.shape == [batch, num_heads, seq_len, head_dim]
v.shape == [batch, num_heads, seq_len, head_dim]

每一层需要缓存历史 token 的 kv

past_k.shape == [batch, num_heads, past_len, head_dim]
past_v.shape == [batch, num_heads, past_len, head_dim]

注意,这里缓存的是每一层的 K/V。一个 32 层的模型,就会有 32 份这样的缓存。因此 KV Cache 的显存可以大致估算为:

KV Cache bytes = batch_size * seq_len * num_layers * 2 * num_kv_heads * head_dim * bytes_per_element

公式中的 2 代表 K 和 V 两份数据。最容易被忽略的参数是 num_kv_heads

在 MHA 中:num_kv_heads = num_query_heads。而在 GQA 中:num_kv_heads < num_query_heads。这正是 GQA 能够节省显存的关键入口。

2. 用一组数字算清楚

假设一个简化配置:

batch_size = 1
seq_len = 8192
num_layers = 32
num_query_heads = 32
head_dim = 128
dtype = fp16  # 2 bytes

如果是传统 MHA:num_kv_heads = 32,KV Cache 大约为:

1 * 8192 * 32 * 2 * 32 * 128 * 2 bytes = 4 GiB

如果换成 GQA,假设 num_kv_heads = 8,KV Cache 大约为:

1 * 8192 * 32 * 2 * 8 * 128 * 2 bytes = 1 GiB

同样是 32 个 Query Head、同样的上下文长度,仅仅把 KV Head 从 32 降到 8,缓存就变成了原来的四分之一。

换成 MQA:num_kv_heads = 1,KV Cache 会进一步降低到:

128 MiB

这只是一个教学估算,实际框架还受到 allocator、block size、padding、并发调度、量化和 kernel 实现等因素的影响。但作为面试和工程理解,这个公式已经足够抓住核心要点。

3. MHA、MQA、GQA 的区别

可以用一张表先记住:

结构Query HeadKV Head直觉
MHA多个和 Query 一样多每个 Q head 独享一组 K/V
MQA多个1 个所有 Q head 共享同一组 K/V
GQA多个介于 1 和 Query Head 之间一组 Q head 共享一组 K/V

假设:num_query_heads = 32num_kv_heads = 8,则 group_size = num_query_heads // num_kv_heads = 4

那么 GQA 的含义是:前 4 个 Q head(0123)共享一个 KV Head,接下来 4 个(4567)共享下一个,以此类推。它不像 MQA 那样把所有 Query Head 都压缩到同一个 KV Head 上,也不像 MHA 那样每个 Query Head 都保留独立的 K/V。

GQA 原论文的动机也正在于此:MQA 可以显著提升 decoder 推理速度,但可能带来质量下降;GQA 使用介于 1 和 Query Head 数之间的 KV Head 数量,在效果和推理效率之间取得折中。

4. 张量形状怎么变

MHA 的投影通常是:

q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_q_heads * head_dim
v_proj: hidden_dim -> num_q_heads * head_dim

GQA 的投影变成:

q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_kv_heads * head_dim
v_proj: hidden_dim -> num_kv_heads * head_dim

也就是说,Q 仍然有很多头,K/V 变少了。

假设:batch = 2seq_len = 5num_q_heads = 32num_kv_heads = 8head_dim = 128,那么:

q.shape == [2, 32, 5, 128]
k.shape == [2, 8, 5, 128]
v.shape == [2, 8, 5, 128]

但在 attention 计算时,q @ k.transpose(-2, -1) 要求 head 维度能够对齐。一种教学实现是把 K/V 按组展开:

k_expanded.shape == [2, 32, 5, 128]
v_expanded.shape == [2, 32, 5, 128]

PyTorch 的 scaled_dot_product_attention(enable_gqa=True) 文档中也展示了类似逻辑:启用 GQA 时,会根据 Query Head 和 KV Head 的比例对 key/value 做 repeat_interleave。但要注意,真实的高性能实现不一定真的物理复制 K/V。服务端推理更关心 cache 布局、访存效率和 kernel 的实现方式。

5. 手写一个最小 GQA

下面这份代码只保留核心逻辑,适合面试讲解时使用:

  • Q Head 数可以大于 KV Head 数;
  • KV Head 必须能整除 Query Head;
  • K/V 先按较少的 head 存储;
  • 计算 attention 前按组展开;
  • cache 里只缓存较少的 KV Head。
import math
import torch
from torch import nn

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: [B, H_kv, T, D]
    if n_rep == 1:
        return x
    batch, num_kv_heads, seq_len, head_dim = x.shape
    x = x[:, :, None, :, :]
    x = x.expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

class GroupedQueryAttention(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 dropout: float = 0.0,
                 ):
        super().__init__()
        assert hidden_dim % num_q_heads == 0
        assert num_q_heads % num_kv_heads == 0
        self.hidden_dim = hidden_dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_dim // num_q_heads
        self.num_groups = num_q_heads // num_kv_heads

        self.q_proj = nn.Linear(hidden_dim, num_q_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(num_q_heads * self.head_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, num_heads, self.head_dim)
        return x.transpose(1, 2)  # [B, H, T, D]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch, heads, seq_len, head_dim = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, heads * head_dim)

    def forward(self,
                x: torch.Tensor,
                attn_mask: torch.Tensor | None = None,
                past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
                use_cache: bool = False,
                ):
        q = self._split_heads(self.q_proj(x), self.num_q_heads)
        k = self._split_heads(self.k_proj(x), self.num_kv_heads)
        v = self._split_heads(self.v_proj(x), self.num_kv_heads)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_key_value = (k, v) if use_cache else None

        k_for_attn = repeat_kv(k, self.num_groups)
        v_for_attn = repeat_kv(v, self.num_groups)

        scores = q @ k_for_attn.transpose(-2, -1)
        scores = scores / math.sqrt(self.head_dim)
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        out = weights @ v_for_attn
        out = self._merge_heads(out)
        out = self.o_proj(out)
        return out, weights, present_key_value

测试一下形状:

x = torch.randn(2, 5, 4096)
gqa = GroupedQueryAttention(
    hidden_dim=4096,
    num_q_heads=32,
    num_kv_heads=8,
)
out, weights, cache = gqa(x, use_cache=True)
print(out.shape)       # [2, 5, 4096]
print(weights.shape)   # [2, 32, 5, 5]
print(cache[0].shape)  # [2, 8, 5, 128]
print(cache[1].shape)  # [2, 8, 5, 128]

关键点在最后两行。注意力权重仍然是 32 个 Query Head:weights.shape == [2, 32, 5, 5],但缓存里只有 8 个 KV Head:cache[0].shape == [2, 8, 5, 128]cache[1].shape == [2, 8, 5, 128]。这就是 GQA 在 KV Cache 上节省显存的直观体现。

6. 用 PyTorch 接口怎么写

PyTorch 的 torch.nn.functional.scaled_dot_product_attention 已经提供了 enable_gqa 参数。

一个最小示例:

import torch
import torch.nn.functional as F

query = torch.randn(2, 32, 5, 128, device="cuda", dtype=torch.float16)
key = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)
value = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True,
    enable_gqa=True,
)
print(out.shape)  # [2, 32, 5, 128]

官方文档中有两个重要约束:

number_of_heads_query % number_of_heads_key_value == 0
number_of_heads_key == number_of_heads_value

也就是说:

  • Query Head 数必须能被 KV Head 数整除;
  • Key Head 数与 Value Head 数必须相同;
  • enable_gqa 目前仍是实验特性,后端支持和张量类型有限制。

还有一个容易踩坑的地方:PyTorch 这个函数中布尔 attn_mask 的语义与某些 MHA 接口中的 padding mask 语义相反。scaled_dot_product_attentionTrue 表示参与 attention,迁移代码时需要小心。

7. 为什么 GQA 主要影响推理

如果只做一次完整 forward,并且不使用 KV Cache,GQA 对峰值显存的影响不如在 KV Cache 场景下那么明显。

真正的收益集中在自回归 decode 阶段:

每一步都要读取历史 K/V
历史越长,读取量越大
并发越高,缓存越多
KV Head 越少,缓存越小

Hugging Face 的优化文档也提到,减少 KV 向量的数量,只有在使用 KV Cache 的自回归解码场景中才特别有意义,因为 decode 阶段会反复读取历史 K/V,内存带宽很容易成为瓶颈。

所以可以这样理解:

场景GQA 价值
训练全序列并行不是主要优化目标
Prefill可以减少写入 cache 的 K/V 体积
Decode最关键,减少每步读取的历史 K/V 量
长上下文服务价值更明显
高并发服务价值更明显

这也是为什么讲解 GQA 时,不能只盯着 attention 公式,要把它放回到推理服务的 KV Cache 场景中去看。

8. 和 vLLM、PagedAttention 有什么关系

GQA 解决的是:单个 token 的 K/V 体积更小。PagedAttention 解决的是:大量 token 的 K/V 如何更高效地组织和管理。二者不属于同一层优化,但会共同影响推理效率。

vLLM 的 PagedAttention 文档中提到,key/value cache 会被拆分成 block,每个 block 存储固定数量 token 的 cache。这样做的目标是用更适合服务端调度的方式来管理 KV Cache,而不是把每个请求都当作一大段连续显存处理。

可以将它们放在同一张图中:

GQA:减少每个 token 的 KV 体积
PagedAttention:管理大量 token 的 KV 存放方式
Quantized Cache:降低每个元素的字节数
Offloaded Cache:将部分 cache 放到 CPU

如果只看单次模型结构,GQA 像是 attention 结构的变化。但如果从推理系统角度看,GQA 是 KV Cache 成本控制中的一个重要环节。

9. 常见坑

坑 1:只改 num_kv_heads,忘了改投影层输出维度

GQA 中 Q/K/V 的 projection 输出维度不同:

q_proj -> num_q_heads * head_dim
k_proj -> num_kv_heads * head_dim
v_proj -> num_kv_heads * head_dim

如果仍然将 K/V 投影到 num_q_heads * head_dim,cache 就没有节省下来。

坑 2:num_q_heads 不能整除 num_kv_heads

GQA 要按组共享 K/V,所以通常要求:num_q_heads % num_kv_heads == 0,否则每组 Query Head 无法均匀映射到 KV Head。

坑 3:把 repeat 后的 K/V 当作 cache 存储

教学代码为了便于理解,会在 attention 前做 repeat_kv。但 cache 里应该保留较少的 KV Head:cache_k.shape == [B, H_kv, T, D]。如果将展开后的 K/V 存进去:cache_k.shape == [B, H_q, T, D],显存占用又回到了 MHA 级别。

坑 4:只计算 cache 容量,不关注内存带宽

KV Cache 不仅占用显存。Decode 每一步都需要读取历史 K/V,所以内存带宽也会成为瓶颈。GQA 的价值不仅仅是少存,也包括少读。

坑 5:把 GQA 当作无损替换

GQA 是效果和效率之间的折中方案。GQA 原论文的结论是,GQA 相比 MQA 更能保留 MHA 的质量,同时接近 MQA 的速度收益。但具体效果仍然取决于模型结构、训练方式、上采样策略和任务类型。工程上不要把结构变化理解成“免费优化”——它通常是在模型设计或训练阶段就确定好的。

10. 面试怎么讲

如果面试官问:“GQA 和 MHA 有什么区别?”

可以这样回答:GQA 的核心区别在于,Query Head 的数量多于 Key/Value Head 的数量,多个 Query Head 会共享同一组 K/V。而 MHA 里每个 Query Head 都有独立的 K/V。

如果继续问:“为什么能省显存?”

可以接着答:因为 KV Cache 的大小与 num_kv_heads 直接成正比。在相同的 Query Head 数量和序列长度下,GQA 只需缓存更少的 K/V Head,因此显存占用更小。

如果问:“GQA、MQA 怎么区分?”

可以答:MQA 是所有 Query Head 共享一个 KV Head,极端节省显存但可能损失效果;GQA 是折中方案,将 Query Head 分成若干组,每组共享一个 KV Head。

如果问:“代码里最容易错在哪里?”

可以答:最容易错的是投影层的输出维度改错,以及 cache 里无意中存储了展开后的 K/V。核心约束是 num_q_heads % num_kv_heads == 0

11. 一张速记表

问题关键回答
GQA 改了什么?Query Head 多,KV Head 少
为什么能省显存?KV Cache 大小与 num_kv_heads 成正比
MHA 的 KV Head 数?通常等于 Query Head 数
MQA 的 KV Head 数?1 个
GQA 的 KV Head 数?介于 1 和 Query Head 数之间
代码核心约束?num_q_heads % num_kv_heads == 0
cache 里存什么?未展开的 K/V,形状是 [B, H_kv, T, D]
attention 前做什么?把 K/V 按组映射到 Query Head
最适合讲的场景?长上下文、自回归 decode、高并发推理
PyTorch 接口?scaled_dot_product_attention(..., enable_gqa=True)

总结

GQA 可以用三句话记住:

  1. MHA 中每个 Query Head 通常有自己的 K/V,KV Cache 按 Query Head 数增长。
  2. GQA 让一组 Query Head 共享较少的 K/V Head,KV Cache 按 KV Head 数增长。
  3. 它的主要价值体现在自回归推理,尤其是长上下文和高并发服务场景中。

所以,学习 GQA 不要只记住一个缩写。真正要记住的是这条线索:

MHA 张量形状 -> KV Cache 显存公式 -> KV Head 数量 -> Decode 访存压力 -> GQA

这条线索讲清楚了,GQA、MQA、KV Cache、长上下文推理优化就能串联起来。

参考资料

  • Joshua Ainslie et al.:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
    arxiv.org/abs/2305.13…
  • PyTorch:torch.nn.functional.scaled_dot_product_attention
    docs.pytorch.org/docs/main/g…
  • Hugging Face Transformers:Caching
    huggingface.co/docs/transf…
  • Hugging Face Transformers:KV cache strategies
    huggingface.co/docs/transf…
  • Hugging Face Transformers:Optimizing LLMs for Speed and Memory
    huggingface.co/docs/transf…
  • vLLM:Paged Attention
    docs.vllm.ai/en/latest/d…
来源:https://juejin.cn/post/7644783891102580778
上一篇AI导航网站产品榜怎么样 下一篇VocalRemover.co在线人声分离工具
本站内容用于信息整理与展示,如有侵权或内容问题请及时联系处理。

相关推荐

补充同频道和同主题内容,方便继续浏览更多相关内容。

同类最新

继续查看同栏目最近更新的文章。

更多
2026实测解析GPT-5.5模型能力详解与国内合规使用规范
AI教程 · 2026-06-03

2026实测解析GPT-5.5模型能力详解与国内合规使用规范

2026年,AI大模型迎来了又一次迭代升级。GPT-5 5凭借在多模态精细化处理能力上的跨越式突破,正逐步成为职场办公、内容创作、代码开发以及数据优化等领域的核心生产力工具。然而,对国内多数用户而言,当前仍面临不少现实难题:渠道杂乱、合规边界模糊、账号频繁被封、数据泄露风险——各类非正规镜像站、共享

分时操作系统和实时操作系统的主要区别
AI教程 · 2026-06-03

分时操作系统和实时操作系统的主要区别

分时操作系统和实时操作系统区别 ?️ 操作系统家族里,有两类系统经常被放在一起比较:分时操作系统和实时操作系统。它们虽然都叫“操作系统”,但设计哲学、工作机制和应用场景可以说是天差地别。一个追求“公平共享”,一个追求“确定性响应”。这篇文章打算从定义、核心机制、调度策略、实际应用等维度,把这两者的本

企业AI智能体从零搭建实战踩坑经验全记录
AI教程 · 2026-06-03

企业AI智能体从零搭建实战踩坑经验全记录

去年开始用腾讯云智能体开发平台(ADP)跑了几个企业项目,从最基础的客服Bot一路干到多Agent协同系统,中间踩的坑不少,但积累下来的经验价值也相当可观。这篇文章就聊聊实际落地过程里的那些关键节点和教训,给同样在腾讯云上折腾AI Agent的朋友做个参考。为什么选腾讯云ADP而不是从零搭建做第一个

Selenium自动化测试入门:从环境搭建到首个可维护用例
AI教程 · 2026-06-03

Selenium自动化测试入门:从环境搭建到首个可维护用例

Selenium 入门的核心不在于记住多少 API,而在于把三件事想清楚:环境别装错版本、等待机制别用 sleep、用例结构别写成流水账。下面按照“装环境 → 跑通第一个脚本 → 理解等待 → 选对定位器 → 拆成 Page Object”的顺序走一遍,每一步都附上代码,踩过的坑直接标出来。 Sel

专业表格魔法师 QoderWork CN 让脏数据秒变仪表盘神器
AI教程 · 2026-06-03

专业表格魔法师 QoderWork CN 让脏数据秒变仪表盘神器

使用案例 今天聊聊怎么用阿里巴巴的 QoderWork CN 桌面应用智能体,把 Excel 里那堆乱糟糟的原始数据清洗干净,再做成可视化的看板。整个过程基本不需要写代码,全靠自然语言对话就能搞定。下面就用一个实际案例,把操作步骤拆开来讲。 步骤一:安装并注册 QoderWork CN 账号 先到