4. KV Cache

介绍 KV Cache 在 LLM 推理中的作用与优化

什么是 KV Cache

  • KV Cache 是大语言模型推理阶段的关键优化技术,通过存储注意力机制中生成的中间键(K)和值(V)向量,避免后续生成步骤中重复计算

  • KV Cache 用于提升文本生成效率,仅适用于推理阶段,无法在训练中使用

  • LLM 生成文本时采用 “逐 token 生成” 模式

    • 无 KV Cache 时,每生成一个新 token,模型需重新计算整个序列(含历史 token)的 K 和 V 向量,存在大量冗余计算

    • 例如生成 “Time flies fast” 时,生成 “fast” 需重新计算 “Time”“flies” 的 KV 向量

    • KV Cache 可存储历史 KV 向量,新 token 生成仅需计算当前 token 的 KV,再结合缓存复用历史数据,避免冗余计算历史 token 的 KV 向量

KV 计算分析

  • 下图展示了在注意力计算过程中如何从 token embeddings 中计算得到 KV 向量

    KV Calculate

  • 每个输入的 token(例如,“Time”和“flies”)通过学习矩阵 $W_k$ 和 $W_v$​ 进行投影,以获得其相应的 KV 向量

  • 大语言模型一次生成一个词(或者 token),假设大语言模型生成了“fast”这个词,那么下一轮的 Prompt 就变成了“Time flies fast”

    KV Calculate Duplicate

  • 可以看出,当处理 “Time flies fast” 时,“Time” 和 “flies” 两个 token 的 KV 向量是完全重复计算的

  • 因此,KV Cache 的理念是实现一种缓存机制,用于存储之前生成的 KV 向量以供重复使用,从而避免这些不必要的重新计算

KV Cache 的计算流程

KV Cache 有无对比

  • 无 KV Cache 的计算

    Generation Step
    Input Tokens
    Computed KV

    1

    "Time"

    "Time"

    2

    "Time flies"

    "Time", "flies"

    3

    "Time flies fast"

    "Time", "flies", "fast"

  • 有 KV Cache 的计算

    Generation Step
    Input Tokens
    Computed KV
    Cached KV

    1

    "Time"

    "Time"

    -

    2

    "Time flies"

    "flies"

    "Time"

    3

    "Time flies fast"

    "fast"

    "Time", "flies"

KV Cache 的代码实现

(1)注册缓存缓冲区(MultiHeadAttention类构造函数)

  • 在多头注意力类中添加 cache_kcache_v 两个缓冲区,用于存储拼接后的 KV 向量:

(2)带use_cache参数的前向传播(MultiHeadAttention.forward)

  • 扩展 forward 方法,根据 use_cache 标志决定是否使用缓存,核心逻辑为“初始化缓存→拼接新 KV →检索缓存”:

(3)缓存重置(MultiHeadAttention.reset_cache)

  • 避免不同文本生成任务间的缓存污染(新 prompt 使用旧缓存会导致输出混乱),因此要新增重置缓存的方法:

(4)全模型 use_cache 参数传播(GPTModel 类)

  • 新增 current_pos 跟踪已缓存 token 数量,确保新 token 的位置索引连续:

  • 修改前向方法,传递use_cache参数到每个 Transformer 块,并更新位置索引:

  • 模型级缓存重置(批量重置所有 Transformer 块的缓存):

(5)带缓存的文本生成函数(generate_text_simple_cached)

  • 仅向模型输入新token(而非完整序列),结合缓存生成文本:

KV Cache 的优化

  • 优点

    • 计算效率显著提升:复杂度从 $O(n^2)$ 降至 $O (n)$,序列越长收益越明显

    • 推理速度快:减少重复计算,尤其适合长文本生成场景(如对话、文档生成)

  • 缺点

    • 内存占用线性增长:每新增一个 $token$,缓存大小增加,长序列或大模型可能耗尽 $GPU$ 内存

    • 代码复杂度提高:需处理缓存初始化、重置、位置跟踪等逻辑,增加实现成本

  • CPU 环境:KV Cache 带来显著速度提升,结合 torch.compile 后效率更高

  • GPU 环境:小模型(0.6B/1B)下,KV Cache 优势不明显,因为设备的传输 / 通信成本抵消了缓存收益,在大模型场景下缓存优势会凸显

  • 优化

    • 预分配内存(解决 torch.cat 的内存碎片问题):避免频繁拼接张量(torch.cat 会反复分配/释放内存),提前按最大序列长度分配缓存空间,确保内存使用的一致性并减少开销

    • 滑动窗口缓存(解决内存线性增长问题):仅保留最近的 window_size 个token的缓存,截断早期 token,平衡内存与上下文相关性

    • 模型外缓存(配合 torch.compile 加速):将 KV Cache 模型外部,便于使用 torch.compile 编译模型,进一步提升计算效率(尤其适用于Qwen3、Llama 3等大模型)

Last updated

Was this helpful?