文章目录
- 前言
- 一、掩蔽 Softmax 操作
- 1.1 sequence_mask
- 1.2 masked_softmax
- 1.3 测试代码
- 二、加性注意力 (Additive Attention)
- 2.1 实现解析
- 2.2 测试代码
- 三、点积注意力 (Dot Product Attention)
- 3.1 实现解析
- 3.2 测试代码
- 四、可视化注意力权重
- 4.1 可视化点积注意力的权重
- 总结
前言
在深度学习领域,注意力机制(Attention Mechanism)已经成为许多模型的核心组件,尤其是在自然语言处理(NLP)和计算机视觉任务中。注意力机制的核心思想是通过计算查询(Query)与键(Key)之间的相关性,动态地为值(Value)分配权重,从而聚焦于最重要的信息。本篇博客将通过 PyTorch 代码,深入探讨注意力汇聚(Attention Pooling)的两种常见评分函数:加性注意力(Additive Attention)和点积注意力(Dot Product Attention)。我们将从代码实现入手,逐步解析其原理,并通过可视化工具展示注意力权重的分布。
本文的目标读者是对深度学习有一定基础、希望通过代码理解注意力机制的实现细节的开发者。所有代码均基于 PyTorch,并在 Jupyter Notebook 中运行和测试。让我们开始吧!
一、掩蔽 Softmax 操作
在注意力机制中,掩蔽 Softmax(Masked Softmax)是一个关键步骤,用于确保模型只关注序列中的有效部分,避免对填充(padding)数据产生影响。我们先来看两个核心函数的实现:sequence_mask
和 masked_softmax
。
1.1 sequence_mask
sequence_mask
函数用于在序列中屏蔽不相关的项。它接收输入序列张量 X
、有效长度张量 valid_len
,并将无效位置替换为指定值(默认值为 0)。
import torch
import torch.nn as nndef sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项参数:X: 输入序列张量,维度 [batch_size, maxlen]valid_len: 有效长度张量,维度 [batch_size]value: 填充值,标量,默认为0返回:X: 屏蔽后的序列张量,维度 [batch_size, maxlen]Defined in :numref:`sec_seq2seq_decoder`"""# 获取序列的最大长度,维度为标量maxlen = X.size(1)# 创建掩码矩阵# torch.arange(maxlen): 生成 [0, 1, ..., maxlen-1] 的序列,维度 [maxlen]# [None, :] 将其扩展为 [1, maxlen]# valid_len[:, None] 将 [batch_size] 扩展为 [batch_size, 1]# 比较结果 mask 维度为 [batch_size, maxlen]mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]# 使用掩码将 X 中无效位置设为 value# ~mask 为反向掩码,选择需要填充的位置X[~mask] = valuereturn X
这个函数的工作原理是:
- 通过
torch.arange(maxlen)
生成一个从 0 到maxlen-1
的序列,并扩展为与批量大小匹配的形状。 - 使用广播机制,将
valid_len
与生成的序列比较,生成布尔掩码mask
。 - 根据掩码,将无效位置(即超出有效长度的部分)替换为
value
。
1.2 masked_softmax
masked_softmax
函数在 Softmax 操作中加入掩蔽机制,确保无效位置的注意力权重为 0。
def masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作参数:X: 三维张量 (batch_size, seq_len, feature_dim)valid_lens: 一维张量 (batch_size,) 或二维张量 (batch_size, seq_len),表示有效长度返回:经过masked softmax处理的张量 (batch_size, seq_len, feature_dim)"""if valid_lens is None:# 当没有指定有效长度时,直接执行标准softmaxreturn nn.functional.softmax(X, dim=-1)else:shape = X.shape # shape: (batch_size, seq_len, feature_dim)if valid_lens.dim() == 1:# 将一维的valid_lens重复扩展到与X的第二维匹配valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:# 将二维的valid_lens展平为一维valid_lens = valid_lens.reshape(-1)# 在最后一轴上对被掩蔽的元素使用非常大的负值替换,使其softmax输出为0X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)# 执