背景:
softmax操作用于输出一个概率分布作为注意力权重。
在某些情况下,并非所有值都应该被纳入到注意力汇聚中。为了仅将有意义的词元作为值来获取注意力汇聚,可以指定一个有效序列长度,以便在计算softmax时过滤掉超出指定范围的位置(任何超出有效长度的位置都被掩蔽置为0),代码如下:
def masked_softmax(X, valid_lens):if valid_lens is None:return nn.dunctional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_len, shape[1])else:valid_lens = valid_lens.reshape(-1)X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)
-
代码解释:
X: 输入张量,通常是一个形状为 (batch_size, sequence_length, num_features) 的张量。
valid_lens: 有效长度,表示每个序列中有效的元素数量。
它可以是一维 (batch_size,)的张量(表示每个样本的有效长度不同)。
例如tensor([2, 3, 1])是一个一维张量,形状为 (3,)
代码中的valid_lens = torch.repeat_interleave(valid_len, shape[1])
将 valid_lens 重复 shape[1] 次,使其形状变为 (batch_size * sequence_length,),这样做是为了将每个样本的有效长度扩展到每个位置,以便后续进行掩码操作。
假如valid_lens=tensor([2, 3, 1])shape[1]为3(即第一个样本每一行的有效长度为2,第二个样本每一行有效长度为3,第三个样本每一行有效长度为1),通过代码valid_lens将会变为tensor([2, 2, 2, 3, 3, 3, 1, 1, 1])即有效长度成功扩展到每个位置。
它也可以是二维张量(即形状为 (batch_size, sequence_length)),即为每个样本的每一行指定有效长度。
valid_lens = valid_lens.reshape(-1): 将 valid_lens 展平为一维张量,形状为 (batch_size * sequence_length,)X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
主要目的是对输入张量 X 进行掩码操作,将无效位置的值替换为 -1e6,以便在 softmax 计算中忽略这些位置 -
示例