欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > PyTorch FlexAttention技术实践:基于BlockMask实现因果注意力与变长序列处理

PyTorch FlexAttention技术实践:基于BlockMask实现因果注意力与变长序列处理

2025/1/7 19:46:50 来源:https://blog.csdn.net/m0_46510245/article/details/144923336  浏览:    关键词:PyTorch FlexAttention技术实践:基于BlockMask实现因果注意力与变长序列处理

本文介绍了如何利用torch 2.5及以上版本中新引入的FlexAttention和BlockMask功能来实现因果注意力机制与填充输入的处理。

鉴于目前网络上缺乏关于FlexAttention处理填充输入序列的完整代码示例和技术讨论,本文将详细阐述一种实现方法,该方法同时涵盖了因果注意力机制的实现。

本文不会详细讨论FlexAttention的理论基础,如需了解更多技术细节,建议参考PyTorch官方博客。

环境配置

 git clone https://github.com/pytorch-labs/attention-gym.git  cd attention-gym  pip install .  cd ../

我们通过attention-gym仓库进行安装,这样可以确保组件间的兼容性,同时获取其可视化工具的使用权限。

MultiheadFlexAttention实现

为了在transformer架构中有效地使用flex_attention,需要在多头注意力模块中进行实现。

     class MultiheadFlexAttention(nn.Module):  def __init__(self, d_in, d_out, n_heads, bias=False):  """  描述:实现基于flex_attention的多头自注意力机制的PyTorch模块参数:d_in: int, 输入张量维度d_out: int, 输出张量维度n_heads: int, 注意力头数bias: bool, 是否在query、key和value计算中使用偏置项"""  super().__init__()  assert d_out % n_heads == 0, "d_out must be divisible by n_heads"  self.n_heads = n_heads  self.d_head = d_out // n_heads  self.d_out = d_out  self.in_proj = nn.Linear(d_in, 3 * d_out, bias=bias)  self.out_proj = nn.Linear(d_out, d_out)

此处定义了模型的核心参数,包括输入输出维度及线性变换层。

 def forward(self, x, block_mask):  """  描述:多头自注意力模块的前向计算过程参数:x: torch.Tensor, 输入张量,维度为(batch_size, max_seq_len, d_in)block_mask: torch.Tensor, flex_attention使用的块状掩码"""  batch_size, max_seq_len, d_in = x.shape  # 通过线性变换生成query、key、value的组合表示qkv = self.in_proj(x)  # 将qkv分解并重组为多头形式qkv = qkv.view(batch_size, max_seq_len, 3, self.n_heads, self.d_head)  # 调整张量维度以适配flex_attention的输入要求qkv = qkv.permute(2, 0, 3, 1, 4)  # 解析得到query、key、value张量queries, keys, values = qkv   # 利用flex_attention计算注意力权重attn = flex_attention(queries, keys, values, block_mask=block_mask)  # 合并多头注意力的输出attn = attn.transpose(1, 2).contiguous().view(batch_size, max_seq_len, self.d_out)  # 执行输出映射attn = self.out_proj(attn)  return attn, queries, keys

该前向传播函数的实现与PyTorch标准的MultiheadAttention类相似,主要区别在于引入了block_mask参数并采用flex_attention函数进行注意力计算。

mask_mod函数实现

FlexAttention的核心优势在于能够高效地实现和使用自定义注意力掩码,而无需编写特定的CUDA核心代码。

要使用此功能,需要将掩码定义为布尔类型张量。首先实现一个因果掩码,这是FlexAttention开发者在其官方博客中提供的基础示例。

因果掩码

 def causal(b, h, q_idx, kv_idx):  return q_idx >= kv_idx

这里的参数说明:

  • b:批次大小
  • h:注意力头数
  • q_idx:query位置索引
  • kv_idx:key/value位置索引

例如,对于序列长度为

5

的输入,

q_idx

表示为

torch.Tensor([0,1,2,3,4])

q_idx >= kv_idx

返回一个因果布尔掩码,确保注意力计算只考虑当前位置及其之前的token。

接下来将实现填充掩码来处理变长序列的填充部分。

填充掩码实现

填充掩码与因果掩码的主要区别在于其批次依赖性,即掩码值取决于每个序列中填充token的具体位置。实现时需要通过填充标记表来识别序列中应被忽略的填充token。

 def create_padding_mask(pads):  def padding(b, h, q_idx, kv_idx):  return ~pads[b, q_idx] & ~pads[b, kv_idx]  return padding
pads

是一个形状为

(batch_size, max_seq_len)

的布尔张量,填充位置标记为True,有效token位置标记为False。此

padding

mask_mod函数生成填充掩码,仅当query和key/value位置均为非填充token时才允许注意力计算。

实验设置与数据准备

在组合掩码并应用到MultiheadFlexAttention之前,需要先设置相关参数并准备实验数据。

 # 多头注意力参数配置d_in = 64  d_out = 64  n_heads = 8  # 初始化多头注意力模块mhfa = MultiheadFlexAttention(d_in, d_out, n_heads).to(device)  # 数据维度设置batch_size = 1 # 支持任意批次大小max_seq_len = 10  # 生成随机输入数据input_data = torch.randn(batch_size, max_seq_len, d_in).to(device)

接下来,对

input_data

进行修改,添加随机的末尾零填充。

 # 添加随机零填充pad = torch.zeros(1, d_in).to(device)  pad_idxs = [(b, range(torch.randint(max_seq_len//2, max_seq_len + 1, (1,)).item(), max_seq_len)) for b in range(batch_size)]  for b, idxs in pad_idxs:  input_data[b, idxs] = pad

现在需要为

padding

mask_mod函数构建填充标记表。

 # 构建填充标记掩码collapsed_input = input_data[:, :, 0] # (batch_size, max_seq_len)  pads = torch.eq(collapsed_input, 0).to(device)

注意,mask_mod函数不需要考虑

input_data

的嵌入维度,因此在创建填充标记表(

pads

)时可以将该维度压缩。

组合因果掩码和填充掩码

此时我们已具备创建综合注意力掩码所需的全部组件。

 # 构建组合掩码causal_mask = causal  padding_mask = create_padding_mask(pads)  masks = [causal, padding_mask]  combined_mask = and_masks(*masks)  causal_padding_mask = create_block_mask(combined_mask, B=batch_size, H=None, Q_LEN=max_seq_len, KV_LEN=max_seq_len, _compile=True)

在这里,我们通过torch.flex_attention提供的

and_masks

函数将

causal

padding

mask_mod函数进行组合,从而生成统一的BlockMask。

说明:开发团队建议启用

_compile_

参数可显著提升BlockMasks的生成效率,这对于批次相关的掩码处理尤其重要。

现在可以利用MultiheadFlexAttention类对

input_data

执行注意力计算,同时应用编译后的自定义注意力掩码。

 # 执行前向计算attn_output, query, key = mhfa(input_data, causal_padding_mask)

使用attention-gym提供的可视化工具来分析注意力分布。

 # 可视化第一个序列的注意力分布visualize_attention_scores(  query,  key,  mask_mod=combined_mask,  device=device,  name="causal_padding_mask",  path=Path("./causal_padding_mask.png"),  )

上图展示了包含三个填充token的序列的掩码后因果注意力分布。

从可视化结果可以观察到,填充token和未来token的注意力权重都被有效地屏蔽,验证了实现的正确性。

https://avoid.overfit.cn/post/96d77c0f872c43dd8c752b687af7babf

作者:Lucas Gomez

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com