欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > cross attention交叉熵注意力机制

cross attention交叉熵注意力机制

2025/2/23 10:47:53 来源:https://blog.csdn.net/study_jiang_up/article/details/139425094  浏览:    关键词:cross attention交叉熵注意力机制

        交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

        交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 X1\epsilon R^{n*d1}  和 X2\epsilon R^{n*d2},然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为n*d2 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。

Q=X_{1} W^{Q} 和 K=V=X_{2} W^{K},则交叉注意力的计算如下:

\operatorname{CrossAttention}\left(X_{1}, X_{2}\right)=\operatorname{Softmax}\left(\frac{Q K^{T}}{\sqrt{d_{2}}}\right) V

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CrossAttention(nn.Module):def __init__(self, embed_dim, hidden_dim, num_heads):super(CrossAttention, self).__init__()self.embed_dim = embed_dimself.hidden_dim = hidden_dimself.num_heads = num_headsself.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)def forward(self, query, context):"""query: (batch_size, query_len, embed_dim)context: (batch_size, context_len, embed_dim)"""batch_size, query_len, _ = query.size()context_len = context.size(1)# Project input embeddingsquery_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)# Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)query_proj = query_proj.permute(0, 2, 1, 3)key_proj = key_proj.permute(0, 2, 1, 3)value_proj = value_proj.permute(0, 2, 1, 3)# Compute attention scoresscores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)attn_weights = F.softmax(scores, dim=-1)# Compute weighted contextcontext = torch.matmul(attn_weights, value_proj)# Concatenate heads and project outputcontext = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)output = self.out_proj(context)return output, attn_weights# Example usage:
embed_dim = 512
hidden_dim = 64
num_heads = 8cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)# Dummy data
batch_size = 2
query_len = 10
context_len = 20query = torch.randn(batch_size, query_len, embed_dim)
context = torch.randn(batch_size, context_len, embed_dim)output, attn_weights = cross_attention(query, context)
print(output.size())  # Should be (batch_size, query_len, embed_dim)
print(attn_weights.size())  # Should be (batch_size, num_heads, query_len, context_len)
  1. 类定义CrossAttention 类继承自 nn.Module,包含初始化函数 __init__ 和前向传播函数 forward
  2. 初始化
    • 定义了一些线性变换层:query_proj, key_proj, 和 value_proj,这些层将嵌入向量转换为多头注意力机制所需的维度。
    • 最终的输出通过 out_proj 再投影回原始的嵌入维度。
  3. 前向传播
    • 输入的 querycontext 分别通过线性变换层,并重新整形以适应多头注意力机制。
    • 计算注意力分数,并通过 softmax 得到注意力权重。
    • 利用注意力权重加权上下文向量,得到新的上下文表示。
    • 最后将多头的结果合并,并通过输出投影层得到最终的输出。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词