欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > 如何计算kv cache的缓存大小

如何计算kv cache的缓存大小

2024/12/22 17:10:29 来源:https://blog.csdn.net/2301_79093491/article/details/144196901  浏览:    关键词:如何计算kv cache的缓存大小

符号定义

首先,定义一些符号:

( B ):批大小(Batch Size)
( L ):序列长度(Sequence Length),在您的问题中,( L = 1 )
( N ):Transformer 层数(Number of Transformer Layers)
( H ):注意力头数(Number of Attention Heads)
( D ):每个注意力头的维度(Dimension per Head),即 ( D = Hidden Size / H D = \text{Hidden Size} / H D=Hidden Size/H)
( S ):数据类型大小(Size of Data Type),以字节为单位。例如:
FP32(32位浮点数):( S = 4 ) 字节
FP16(16位浮点数):( S = 2 ) 字节

KV 缓存的内存计算

对于每一层的多头注意力机制,我们需要存储 键(Key)值(Value) 的缓存。对于每一层,键和值的缓存大小计算如下:

键缓存(Key Cache)大小:

Size Key = B × H × L × D × S \text{Size}_{\text{Key}} = B \times H \times L \times D \times S SizeKey=B×H×L×D×S

值缓存(Value Cache)大小:

Size Value = B × H × L × D × S \text{Size}_{\text{Value}} = B \times H \times L \times D \times S SizeValue=B×H×L×D×S

因此,每一层的 KV 缓存总大小为:

SizeKV per layer = SizeKey + Size Value = 2 × B × H × L × D × S \text{Size}{\text{KV per layer}} = \text{Size}{\text{Key}} + \text{Size}_{\text{Value}} = 2 \times B \times H \times L \times D \times S SizeKV per layer=SizeKey+SizeValue=2×B×H×L×D×S

由于模型有 ( N ) 层,因此 总的 KV 缓存大小为:

Total SizeKV = N × SizeKV per layer = 2 × B × H × L × D × N × S \text{Total Size}{\text{KV}} = N \times \text{Size}{\text{KV per layer}} = 2 \times B \times H \times L \times D \times N \times S Total SizeKV=N×SizeKV per layer=2×B×H×L×D×N×S

具体示例计算

假设以下参数:

批大小:( B = 1 )

序列长度:( L = 1 ) (即 token 数为 1)

层数:( N = 12 ) (例如,一个小型的 Transformer)

隐藏层尺寸:( Hidden Size = 768 \text{Hidden Size} = 768 Hidden Size=768 )

注意力头数:( H = 12 )

每个头的维度:

D = Hidden Size H = 768 12 = 64 D = \frac{\text{Hidden Size}}{H} = \frac{768}{12} = 64 D=HHidden Size=12768=64

数据类型:FP32,( S = 4 ) 字节

现在,我们计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 12 × 64 × 4 = 2 × 12 × 64 × 4 = 6144 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 6144\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×12×1×64×4 =2×1×12×1×64×4 =2×12×64×4 =2×12×64×4 =6144 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer  = 12 × 6144 = 73728 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 12 \times 6144 \ &= 73728\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =12×6144 =73728 字节

即大约 72 KB。

虽然这个数字看起来不大,但在大型模型中,参数会显著增大。例如,考虑一个具有以下参数的大型模型:

层数:( N = 96 )

隐藏层尺寸:( Hidden Size = 12288 \text{Hidden Size} = 12288 Hidden Size=12288)

注意力头数:( H = 96 )

每个头的维度:

D = 12288 96 = 128 D = \frac{12288}{96} = 128 D=9612288=128

数据类型:FP16,( S = 2 ) 字节

计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 96 × 1 × 128 × 2 = 2 × 96 × 128 × 2 = 2 × 96 × 128 × 2 = 49 , 152 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 96 \times 1 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 49,152\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×96×1×128×2 =2×96×128×2 =2×96×128×2 =49,152 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer  = 96 × 49 , 152 = 4 , 719 , 616 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 96 \times 49,152 \ &= 4,719,616\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =96×49,152 =4,719,616 字节

即大约 4.5 MB。

注意事项

模型规模的影响: 可以看到,随着 层数 ( N )、隐藏层尺寸 和 注意力头数 ( H ) 的增加,KV 缓存的内存需求会显著增长。

序列长度的影响: 虽然在 ( L = 1 ) 时,序列长度对内存影响较小,但在生成长序列时,( L ) 会增加,导致 KV 缓存内存占用线性增长。

数据类型的影响: 使用 FP16 可以将内存占用减少一半,但对于大型模型,内存需求仍然很高。

总结

即使 token 数为 1,由于模型的层数、注意力头数、每个头的维度等参数较大,KV 缓存仍然需要消耗较大的内存。

通过以上公式,可以直观地看到各个参数对 KV 缓存内存占用的影响,从而理解为什么在处理单个 token 时仍需要大的内存。

优化建议

减少模型规模: 降低 ( N )、( H ) 或 ( D ) 的值。
使用半精度: 采用 FP16 或更低精度的数据类型。
批量大小优化: 确保 ( B ) 仅为需要的最小值。
序列长度控制: 在可能的情况下,限制生成序列的最大长度 ( L )。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com