KV Cache

KV Cache是一种针对Transformer-Decoder部分的注意力层的优化技术,其原理是通过缓存之前生成的KV值,提高模型的推理性能。

什么是 KV Cache

对于包含有Transformer-Decoder部分的模型(可以是Decoder-Encoder,Only Decoder),模型的推理过程如下:

给定一个输入文本,模型会根据输入的文本,开始一个token一个 token 的输出回答,每一个 token 的输出都依赖于之前的输出,直到输出终止 token

如下图例子所示:

img

上述过程的模型,我们也称之为自回归模型

在自回归模型中,因为模型一次只能生成一个token,而且每次新的预测都依赖于之前的上下文。这意味着,要预测第1000个token,你需要用到前999个token的信息。

Attention(Q,K,V)=Softmax(QKTdk)VAttention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V

这通常涉及到对这些token的表示进行一系列矩阵乘法运算,Transformer原始的计算过程是在每一次计算token的时候,都得计算重新之前token的注意力信息。如下图所示:

img

KV cache就是在这里发挥作用,通过存储之前K , V的计算结果,并在接下来的token生成时复用这些结果,从而避免重复计算,提高推理效率。如下图第二部分所示:

img

需要注意的是,KV cache只在多个token生成步骤中发生,并且仅在decoder部分进行。像BERT这样的encoder only模型,并不是生成式模型,因此不涉及KV cache。

当然,我们也可以用代码测试一下,使用了KV Cache后模型推理速度提升了多少

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").to(device)

for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
    start = time.time()
    model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)
    times.append(time.time() - start)
  print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

测试结果如下:

with KV caching: 2.088 +- 0.715 seconds
without KV caching: 18.835 +- 12.523 seconds

快了大约 9.02 倍!

为什么Q不用Cache

正常情况下,Transformer结构的Attention的计算公式步骤如下:

Attention(Q,K,V)=softmax([Q1K1TQ2K1TQ2K2T])[V1V2]=[softmaxed(Q1K1T)softmaxed()softmaxed(Q2K1T)softmaxed(Q2k2T)][V1V2]=[softmaxed(Q1K1T)×V1+0×V2softmaxed(Q2K1T)×V1+softmaxed(Q2K2T)×V2]Attention(Q,K,V)=softmax(\left[\begin {array}{c} Q_1K_1^T &-\infty \\ Q_2K_1^T &Q_2K_2^T \\ \end{array}\right])\left[\begin {array}{c} V_1 \\ V_2 \\ \end{array}\right] \\ \\ = \left[\begin {array}{c} softmaxed(Q_1K_1^T) &softmaxed(-\infty) \\ softmaxed(Q_2K_1^T) &softmaxed(Q_2k_2^T) \\ \end{array}\right]\left[\begin {array}{c} V_1 \\ V_2 \\ \end{array}\right] \\ \\ = \left[\begin {array}{c} softmaxed(Q_1K_1^T)\times V_1+ 0\times V_2 \\ softmaxed(Q_2K_1^T)\times V_1+softmaxed(Q_2K_2^T)\times V_2 \\ \end{array}\right] \\

其中-\infty表示加上了masked的注意力,softmaxedsoftmaxed表示进行了softmax计算,那么最后的结果如下:

Attention1=softmaxed(Q1K1T)×V1Attention_1=softmaxed(Q_1K_1^T)\times V_1

Attention2=softmaxed(Q2K1T)×V1+softmaxed(Q2K2T)×V2Attention_2=softmaxed(Q_2K_1^T)\times V_1+softmaxed(Q_2K_2^T)\times V_2

其中Attention1Attention_1表示Attention的第一行,Attention2Attention_2表示Attention的第二行。

因此,在Attention的每一步中,其实只需要根据QkQ_k计算 AttentionkAttention_k 就可以,之前已经计算的Attention完全不需要重新计算。但是 K 和 V 是全程参与计算的,所以这里我们需要把每一步的 K,V 缓存起来。而Q只与当前一步计算有关,所以不需要也不用去缓存Q

KV Cache存在的问题

由于KV Cache思想是以空间换时间,那么显而易见的是KV Cache所带来模型占用显存的增加。

假设输入序列的长度为ss ,输出序列的长度为nn ,transformer层数为ll,隐藏层维度$ hKVCachekv_seq_lenKVvalue,KV Cache 存储 kv\_seq\_len个 KV value,形状为 [b, head_num, kv_seq_lenkv\_seq\_len, head_dim]$,, 峰值kv_seq_len为 s+ns+n,,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 b(s+n)hl22=4blh(s+n)b(s+n)h∗l∗2∗2=4blh(s+n) 。这里第一个 2 表示 K/V cache,第二个2表示float16占 2 个 bytes。

以GPT3-175B为例,对比KV cache与模型参数占用显存的大小。模型配置如下:

模型名 参数量 层数 隐藏层维度 注意力头数
GPT-3 175B 96 12288 96

GPT3 模型占用显存大小为350GB。假设批次大小b=64b=64 ,输入序列长度$s=512 ,输出序列长度n=32 $,则KV cache 峰值占用显存为 4blh(s+n)=164,282,499,072bytes164GB4blh(s+n) = 164,282,499,072 bytes ≈ 164 GB,大约是模型参数显存的0.5倍。

优化KV Cache 的几种方法

MHA

MHA(Multi-Head Attention)也就是Transformer原始的多头注意力机制,其思想是将多个独立的注意力机制进行拼接,假设输入的行向量为[x1,x2,...,xn][x_1,x_2,...,x_n],其中xnx_ndd维。那么MHA的表现形式如下:

On=[o1,o2,...,oh]os=Attention(Qns,Kns,Vns)O_n=[o_1,o_2,...,o_h] \\ o_s=Attention(Q_n^s,K_{\leq n}^s,V_{\leq n}^s)

其中,OnO_n 表示xnx_n多头注意力后的结果,分为了hh个注意力头,每个头独立计算然后进行拼接,shs\in h,这里省略了Attention矩阵的缩放因子,实践上,每个头里的QKVQ、K、V维度一般为d/hd/h,LLAMA2-7b有d=4096d=4096h=32h=32,每个头里QKVQ、K、V的维度为128,LLAMA2-70b则是d=8192d=8192h=64h=64,每个头里QKVQ、K、V的维度为128。

MQA

MQA,即“Multi-Query Attention”,是减少KV Cache的一次非常朴素的尝试,首次提出自《Fast Transformer Decoding: One Write-Head is All You Need》,这已经是2019年的论文了,这也意味着早在LLM火热之前,减少KV Cache就已经是研究人员非常关注的一个课题了。

MQA的思路很简单,直接让所有Attention Head共享同一个K、V。

On=[o1,o2,...,oh]os=Attention(Qns,Kn,Vn)O_n=[o_1,o_2,...,o_h] \\ o_s=Attention(Q_n^s,K_{\leq n},V_{\leq n})

使用 MQA 的模型包括 PaLMGemini 等。很明显,MQA 直接将 KV Cache 减少到了原来的 1/h1/h

效果方面,目前看来大部分任务的损失都比较有限。

GQA

有人担心 MQA 对 KV Cache 的压缩太严重,以至于会影响模型的学习效率以及最终效果。为此,一个 MHA 与 MQA 之间的过渡版本 GQA(Grouped-Query Attention)应运而生,出自 2023 年 Google 的论文 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

GQA的思想也很朴素,它就是将所有Head分为gg个组(gg可以整除hh),每组共享同一对K、V用数学公式表示为

On=[o1,o2,...,oh]os=Attention(Qns,Knsg/h,Vnsg/h)O_n=[o_1,o_2,...,o_h] \\ o_s=Attention(Q_n^s,K_{\leq n}^{⌈sg/h⌉},V_{\leq n}^{⌈sg/h⌉})

其中这里的是⌈⋅⌉是向上取整符号。GQA 提供了 MHA 到 MQA 的自然过渡。当g=hg=h时就是MHA,g=1g=1时就是MQA,当1<g<h1<g<h时,它只将KV Cache压缩到g/hg/h,压缩率不如MQA,但同时也提供了更大的自由度,效果上更有保证。GQA最知名的使用者,大概是Meta开源的LLAMA2-70B,以及LLAMA3全系列,此外使用GQA的模型还有TigerBotDeepSeek-V1YiChatGLM2ChatGLM3等,相比使用MQA的模型更多

MLA

随着今年DeepSeek的爆火,DeepSeek所提出的MLA,也正是该模型便宜的关键技术之一。

MLA(Multi-head Latent Attention)是对GQA的改进。其核心是对键(Keys)和值(Values)进行低秩联合压缩,生成压缩的潜在向量,以降低推理阶段的KV缓存。并结合了RoPE,提升了模型的外推能力,进一步提升了模型的效果。

  • MLA 通过低秩压缩显著降低了 KV 缓存的大小,在 DeepSeek-V2 中,KV 缓存减少了 93.3%,推理吞吐量提升了 576%。
  • 通过部分 RoPE 保留和低秩近似,MLA 能够高效地将任意基于 MHA/GQA 的大语言模型迁移到 MLA 架构,仅需使用少量预训练数据进行微调。

参考文献