欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > 【拜读】Tensor Product Attention Is All You Need姚期智团队开源TPA兼容RoPE位置编码

【拜读】Tensor Product Attention Is All You Need姚期智团队开源TPA兼容RoPE位置编码

2025/2/22 14:58:31 来源:https://blog.csdn.net/qq_36603091/article/details/145777452  浏览:    关键词:【拜读】Tensor Product Attention Is All You Need姚期智团队开源TPA兼容RoPE位置编码

在这里插入图片描述
姚期智团队开源新型注意力:张量积注意力(Tensor Product Attention,TPA)。有点像一种「动态的LoRA」,核心思路在于利用张量分解来压缩注意力机制中的 Q、K、V 表示,同时保留上下文信息,减少内存开销。另外巧妙地兼容了RoPE,论文中还证明了流行的MHA、MQA、GQA都是TPA的特殊情况,用一个框架统一了现代注意力设计,解决MLA压缩了KV缓存但与RoPE位置编码不兼容的问题。
张量积注意力(Tensor Product Attention, TPA),用于解决语言模型在处理长序列时的内存开销问题。通过上下文张量分解来表示查询、键和值,从而在推理时显著减少了KV缓存的大小。实验结果表明,TPA在保持模型性能的同时,显著降低了内存开销,能够处理更长的序列上下文。此外,TPA与旋转位置嵌入(RoPE)兼容,便于在现代大型语言模型架构中应用。总体而言,TPA提供了一种灵活且内存高效的替代方案,推动了现代语言模型的可扩展性。

核心机制总结

具体来说,

  1. 张量分解:首先,TPA使用张量分解来表示查询(Q)、键(K)和值(V),从而在推理时显著减少KV缓存的大小。通过将表示分解为上下文低秩分量(contextual factorization),TPA实现了比标准多头注意力(MHA)低一个数量级的内存开销,同时降低了预训练验证损失(困惑度)并提高了下游性能。

    • 张量积投影:通过两组线性层(A 投影和 B 投影)对输入进行变换,这是张量积注意力的核心操作之一。
    • 缓存机制:使用缓存来存储 K 和 V 的值,以便在后续计算中使用,这有助于提高计算效率。
    • 旋转位置嵌入:应用旋转位置嵌入来处理序列中的位置信息。
    • 注意力计算:通过矩阵乘法计算注意力分数,并使用 softmax 进行归一化,最后得到输出。
      在这里插入图片描述
  2. 与RoPE的兼容性:TPA与旋转位置嵌入(RoPE)天然兼容,可以直接替代多头注意力(MHA)层,便于在现代大型语言模型架构(如LLaMA和Gemma)中应用。

  3. 公式描述:具体来说,TPA的查询、键和值的分解公式如下:

Q t = 1 R Q ∑ r = 1 R Q a r Q ( x t ) ⊗ b r Q ( x t ) Q_{t}=\frac{1}{R_{Q}}\sum_{r=1}^{R_{Q}} a_{r}^{Q}\left(x_{t}\right)\otimes b_{r}^{Q}\left(x_{t}\right) Qt=RQ1r=1RQarQ(xt)brQ(xt)

K t = 1 R K ∑ r = 1 R K a r K ( x t ) ⊗ b r K ( x t ) K_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}} a_{r}^{K}\left(x_{t}\right)\otimes b_{r}^{K}\left(x_{t}\right) Kt=RK1r=1RKarK(xt)brK(xt)

V t = 1 R V ∑ r = 1 R V a r V ( x t ) ⊗ b r V ( x t ) V_{t}=\frac{1}{R_{V}}\sum_{r=1}^{R_{V}} a_{r}^{V}\left(x_{t}\right)\otimes b_{r}^{V}\left(x_{t}\right) Vt=RV1r=1RVarV(xt)brV(xt)

其中, a r Q ( x t ) a_{r}^{Q}(x_{t}) arQ(xt), b r Q ( x t ) b_{r}^{Q}(x_{t}) brQ(xt)是查询的因子, a r K ( x t ) a_{r}^{K}(x_{t}) arK(xt) b r K ( x t ) b_{r}^{K}(x_{t}) brK(xt) 是键的因子, a r V ( x t ) a_{r}^{V}(x_{t}) arV(xt) b r V ( x t ) b_{r}^{V}(x_{t}) brV(xt) 是值的因子。

实验设计

  1. 数据集:实验在FineWeb-Edu 100B数据集上进行,该数据集包含1000亿个训练令牌和10亿个验证令牌。
  2. 模型对比:实验中将T6与基线Llama架构(使用SwiGLU激活和RoPE嵌入)以及Llama变体(将多头注意力替换为多查询注意力MQA、分组查询注意力GQA或多头潜在注意力MLA)进行对比。
  3. 训练设置:实验采用nanoGPT训练配置,使用AdamW优化器,学习率由余弦退火调度器管理,训练阶段分别为2000步预热和全局批量大小为480。

结果与分析

  1. 训练和验证曲线:中等规模(353M)、大规模(773M)和超大规模(1.5B)模型的训练和验证损失曲线显示,TPA及其简化变体TPA-KVonly的收敛速度与基线MHA、MQA、GQA和MLA相当或更快,并且在整个训练过程中保持了较低的验证损失。在这里插入图片描述
    he training loss, validation loss, and validation perplexity of medium-size (353M) models(learning rate 3 × 10−4) and different attention mechanisms on the FineWeb-Edu 100B datase

  2. 验证困惑度:中等规模和大规模模型的验证困惑度曲线显示,TPA和TPA-KVonly在大多数配置下在整个训练过程中保持了较低的困惑度。预训练结束时,TPA基线的困惑度最低。在这里插入图片描述

  3. 下游评估:在标准基准上的零样本和两样本评估结果显示,中等规模模型中,TPA的平均准确率为51.41%,高于MHA的50.11%、MQA的50.44%和MLA的50.13%。大规模模型中,TPA-KVonly的平均准确率为53.52%,而超大规模模型中,TPA-KVonly的平均准确率为55.03%。

优点与创新

  1. 显著的内存效率提升:通过张量分解表示查询、键和值,显著减少了推理时的KV缓存大小,相比标准多头注意力机制(MHA)提升了10倍以上。
  2. 模型性能提升:在预训练验证损失(困惑度)和下游评估性能方面均优于现有的多头注意力、多查询注意力、分组查询注意力和多头潜在注意力等方法。
  3. 与RoPE的兼容性:TPA天然兼容旋转位置嵌入(RoPE),可以直接替代多头注意力层,便于在现代大型语言模型架构(如LLaMA和Gemma)中应用。
  4. 统一的注意力机制框架:揭示了多头注意力、多查询注意力和分组查询注意力都可以作为非上下文变体的TPA自然出现。
  5. 灵活的变体:TPA的变体包括仅分解键/值或跨标记共享基向量,展示了在平衡内存成本、计算开销和表示能力方面的多样性。

关键问题及回答

问题1:张量积注意力(TPA)是如何通过张量分解来表示查询(Q)、键(K)和值(V)的?

张量积注意力(TPA)通过将查询(Q)、键(K)和值(V)分解为多个低秩张量的和来表示。具体来说,每个头的查询、键和值被分解为多个低秩张量的和:

Q t = 1 R Q ∑ r = 1 R Q a r Q ( x t ) ⊗ b r Q ( x t ) Q_{t}=\frac{1}{R_{Q}}\sum_{r=1}^{R_{Q}} a_{r}^{Q}\left(x_{t}\right)\otimes b_{r}^{Q}\left(x_{t}\right) Qt=RQ1r=1RQarQ(xt)brQ(xt)

K t = 1 R K ∑ r = 1 R K a r K ( x t ) ⊗ b r K ( x t ) K_{t}=\frac{1}{R_{K}}\sum_{r=1}^{R_{K}} a_{r}^{K}\left(x_{t}\right)\otimes b_{r}^{K}\left(x_{t}\right) Kt=RK1r=1RKarK(xt)brK(xt)

V t = 1 R V ∑ r = 1 R V a r V ( x t ) ⊗ b r V ( x t ) V_{t}=\frac{1}{R_{V}}\sum_{r=1}^{R_{V}} a_{r}^{V}\left(x_{t}\right)\otimes b_{r}^{V}\left(x_{t}\right) Vt=RV1r=1RVarV(xt)brV(xt)

其中, a r Q a_{r}^{Q} arQ, a r K a_{r}^{K} arK, a r V a_{r}^{V} arV b r Q b_{r}^{Q} brQ, b r K b_{r}^{K} brK, b r V b_{r}^{V} brV 是可学习的参数矩阵, x t x_{t} xt 是第t个标记的隐藏状态向量。通过这种张量分解,TPA能够显著减少KV缓存的大小,同时提高表示能力。

问题2:张量积注意力(TPA)与旋转位置嵌入(RoPE)的兼容性如何?

张量积注意力(TPA)与旋转位置嵌入(RoPE)天然兼容。RoPE是一种用于编码位置信息的编码方式,能够在保持相对位置关系的同时进行旋转。TPA可以直接替换多头注意力(MHA)层,便于在现代大型语言模型(如LLaMA和Gemma)中应用。具体来说,RoPE可以通过以下公式进行预旋转:

B ~ K ( x t ) ⟵ RoPE ⁡ t ( B K ( x t ) ) \widetilde{B}_{K}\left(x_{t}\right)\longleftarrow\operatorname{RoPE}_{t}\left(B_{K}\left(x_{t}\right)\right) B K(xt)RoPEt(BK(xt))

这样,每个键在缓存之前就已经旋转,从而避免了在解码时显式进行旋转操作,加速了自回归推理过程。

问题3:张量积注意力(TPA)在实验中的性能如何?

张量积注意力(TPA)在实验中表现出色。具体来说,在FineWeb-Edu 100B数据集上的中型(353M)、大型(773M)和XL(1.5B)模型的训练和验证损失曲线显示,TPA及其简化变体TPA-KVonly收敛速度与基线MHA、MQA、GQA和MLA相当或更快,且最终损失更低。验证困惑度曲线也表明,TPA和TPA-KVonly在训练过程中始终优于MHA、MQA、GQA和MLA,并在预训练结束时达到最低的困惑度。

在下游评估中,TPA和TPA-KVonly在中型和大型模型上均表现出色。例如,中型模型在零样本情况下的平均准确率达到51.41%,在两样本情况下的平均准确率达到53.12%。这些结果表明,TPA在各种基准测试中均优于现有的多头注意力、多查询注意力和分组查询注意力机制,解决了语言模型在处理长序列时的内存开销问题。

代码

张量积注意力(TPA)机制的核心代码主要实现在TPA类中,下面对其核心代码进行详细解读。

类定义与初始化

class TPA(nn.Module):def __init__(self, args: ModelArgs):super().__init__()# 若未指定 n_kv_heads,则使用 n_heads 的值self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsself.n_heads = args.n_heads# 若 head_dim 大于 0 则使用其值,否则通过计算得到self.head_dim = args.head_dim if args.head_dim > 0 else args.dim // args.n_headsself.n_head = args.n_headsself.q_rank = args.q_rankself.rank = args.rankself.dim = args.dimself.using_groupnorm = args.using_groupnorm# 定义 A 投影的线性层,用于 Q、K、Vself.W_A_q = nn.Linear(args.dim, self.n_head * self.q_rank, bias=False)self.W_A_k = nn.Linear(args.dim, self.n_head * self.rank, bias=False)self.W_A_v = nn.Linear(args.dim, self.n_head * self.rank, bias=False)# 定义 B 投影的线性层,用于 Q、K、Vself.W_B_q = nn.Linear(args.dim, self.q_rank * self.head_dim, bias=False)self.W_B_k = nn.Linear(args.dim, self.rank * self.head_dim, bias=False)self.W_B_v = nn.Linear(args.dim, self.rank * self.head_dim, bias=False)# 初始化缓存,用于存储 K 和 V 的值self.cache_kA = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads, self.rank,)).cuda()self.cache_vA = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_heads, self.rank,)).cuda()self.cache_kB = torch.zeros((args.max_batch_size, args.max_seq_len, self.rank, self.head_dim,)).cuda()self.cache_vB = torch.zeros((args.max_batch_size, args.max_seq_len, self.rank, self.head_dim,)).cuda()self.reset_parameters()if self.using_groupnorm:self.subln = T6GroupNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

在初始化函数中,首先接收一个ModelArgs类型的参数args,然后设置一些必要的超参数,如头的数量、秩等。接着定义了两组线性层,分别用于 A 投影和 B 投影。同时,还初始化了缓存用于存储 K 和 V 的值,以便在后续计算中使用。最后,调用reset_parameters方法对权重进行初始化,并根据using_groupnorm参数决定是否使用组归一化。

权重初始化

    def reset_parameters(self, args):# 将 W_A_q 的权重进行变形,然后使用 Xavier 均匀初始化W_A_q_tensor = self.W_A_q.weight.view(self.dim, self.n_head, self.q_rank)nn.init.xavier_uniform_(W_A_q_tensor)self.W_A_q.weight.data = W_A_q_tensor.view_as(self.W_A_q.weight)# 对 W_A_k 和 W_A_v 做同样的操作W_A_k_tensor = self.W_A_k.weight.view(self.dim, self.n_head, self.rank)nn.init.xavier_uniform_(W_A_k_tensor)self.W_A_k.weight.data = W_A_k_tensor.view_as(self.W_A_k.weight)W_A_v_tensor = self.W_A_v.weight.view(self.dim, self.n_head, self.rank)nn.init.xavier_uniform_(W_A_v_tensor)self.W_A_v.weight.data = W_A_v_tensor.view_as(self.W_A_v.weight)# 对 B 投影的权重做同样的操作W_B_q_tensor = self.W_B_q.weight.view(self.dim, self.q_rank, self.head_dim)nn.init.xavier_uniform_(W_B_q_tensor)self.W_B_q.weight.data = W_B_q_tensor.view_as(self.W_B_q.weight)W_B_k_tensor = self.W_B_k.weight.view(self.dim, self.rank, self.head_dim)nn.init.xavier_uniform_(W_B_k_tensor)self.W_B_k.weight.data = W_B_k_tensor.view_as(self.W_B_k.weight)W_B_v_tensor = self.W_B_v.weight.view(self.dim, self.rank, self.head_dim)nn.init.xavier_uniform_(W_B_v_tensor)self.W_B_v.weight.data = W_B_v_tensor.view_as(self.W_B_v.weight)

reset_parameters方法用于对线性层的权重进行初始化,采用 Xavier 均匀初始化方法,这有助于提高模型的训练稳定性。

前向传播

    def forward(self,x: torch.Tensor,start_pos: int,freqs_cis: torch.Tensor,mask: Optional[torch.Tensor],):bsz, seqlen, _ = x.shape# 计算 A 投影的 Q、K、VA_q = self.W_A_q(x).view(bsz, seqlen, self.n_head, self.q_rank)A_k = self.W_A_k(x).view(bsz, seqlen, self.n_head, self.rank)A_v = self.W_A_v(x).view(bsz, seqlen, self.n_head, self.rank)# 计算 B 投影的 Q、K、VB_q = self.W_B_q(x).view(bsz, seqlen, self.q_rank, self.head_dim)B_k = self.W_B_k(x).view(bsz, seqlen, self.rank, self.head_dim)B_v = self.W_B_v(x).view(bsz, seqlen, self.rank, self.head_dim)# 缓存 A_k 和 A_vself.cache_kA = self.cache_kA.to(A_k)self.cache_vA = self.cache_vA.to(A_v)self.cache_kA[:bsz, start_pos : start_pos + seqlen] = A_kself.cache_vA[:bsz, start_pos : start_pos + seqlen] = A_vA_k = self.cache_kA[:bsz, : start_pos + seqlen]A_v = self.cache_vA[:bsz, : start_pos + seqlen]# 缓存 B_k 和 B_vself.cache_kB = self.cache_kB.to(B_k)self.cache_vB = self.cache_vB.to(B_v)self.cache_kB[:bsz, start_pos : start_pos + seqlen] = B_kself.cache_vB[:bsz, start_pos : start_pos + seqlen] = B_vB_k = self.cache_kB[:bsz, : start_pos + seqlen]B_v = self.cache_vB[:bsz, : start_pos + seqlen]# 重塑 A_q、A_k、A_vA_q = A_q.view(bsz * seqlen, self.n_head, self.q_rank)A_k = A_k.view(bsz * seqlen, self.n_head, self.rank)A_v = A_v.view(bsz * seqlen, self.n_head, self.rank)# 重塑 B_q、B_k、B_vB_q = B_q.view(bsz * seqlen, self.q_rank, self.head_dim)B_k = B_k.view(bsz * seqlen, self.rank, self.head_dim)B_v = B_v.view(bsz * seqlen, self.rank, self.head_dim)# 计算 q、k、vq = torch.bmm(A_q, B_q).div_(self.q_rank).view(bsz, seqlen, self.n_head, self.head_dim)k = torch.bmm(A_k, B_k).div_(self.rank).view(bsz, seqlen, self.n_head, self.head_dim)v = torch.bmm(A_v, B_v).div_(self.rank).view(bsz, seqlen, self.n_head, self.head_dim)# 应用旋转位置嵌入q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)# 计算注意力分数k = k.transpose(1, 2) scores = torch.matmul(q.transpose(1, 2), k.transpose(2, 3)) / math.sqrt(self.head_dim)if mask is not None:scores = scores + mask  scores = F.softmax(scores.float(), dim=-1).type_as(q)# 计算输出output = torch.matmul(scores, v.transpose(1, 2))  output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)return self.wo(output)

在前向传播函数中,首先获取输入的形状。然后分别计算 A 投影和 B 投影的 Q、K、V,并对其进行缓存。接着对 A 和 B 投影的结果进行重塑,通过矩阵乘法计算最终的 q、k、v。之后应用旋转位置嵌入,计算注意力分数并进行归一化,最后通过矩阵乘法得到输出。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词