torch.Tensor.masked_fill
是 PyTorch 中用于根据给定的掩码将张量中的特定元素替换为指定值的函数。这个函数可以用于在模型中屏蔽不需要的值,通常与掩码操作(如前向掩码、反向掩码等)结合使用。
函数签名
Tensor.masked_fill(mask, value)
参数
mask
:一个与输入张量相同形状的布尔型(bool
)张量或相同维度的整型张量,True
或非零的元素表示需要替换的元素位置。value
:需要替换为的值。当掩码张量中元素为True
或非零时,原张量中对应位置的元素会被替换为这个值。
返回
返回的是一个新张量,其中根据掩码 mask
对应的位置用 value
进行替换。
使用场景
- 遮蔽(屏蔽)无效值:在序列任务中,比如当处理不等长的输入时,可以使用
masked_fill
将填充的位置(比如PAD
标记)设置为一个极端的值,如负无穷大(-inf
),避免模型关注这些位置。 - 生成注意力掩码:在 Transformer 等模型中,使用掩码来确保某些位置不会被模型关注。
示例代码
1. 基本示例
import torch# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)# 创建一个掩码张量,True 表示对应位置将被替换
mask = torch.tensor([[True, False, True], [False, True, False]])# 使用 masked_fill 替换为 -1
result = x.masked_fill(mask, -1)
print(result)
输出:
tensor([[-1., 2., -1.],[ 4., -1., 6.]])
解释:在 mask
中为 True
的位置,x
中的元素被替换为 -1
。
2. 避免关注填充(PAD)位置
在处理文本序列时,可能需要对填充的 PAD
标记位置进行掩码操作。
import torch# 假设有一个序列的注意力权重矩阵
attention_scores = torch.tensor([[0.2, 0.5, 0.3, 0.0],[0.1, 0.3, 0.6, 0.0]])# 对应的掩码,表示序列中的 PAD 位置
#pad_mask = torch.tensor([[False, False, False, True],
# [False, False, False, True]])pad_mask = attention_scores == 0.0
print(pad_mask)# 将 PAD 位置的分数替换为一个极端的小值 -inf
masked_attention_scores = attention_scores.masked_fill(pad_mask, float('-inf'))
print(masked_attention_scores)
输出:
tensor([[False, False, False, True],[False, False, False, True]])
tensor([[0.2000, 0.5000, 0.3000, -inf],[0.1000, 0.3000, 0.6000, -inf]])
解释:pad_mask
中为 True
的位置(即 PAD
位置),注意力分数被替换为 -inf
,这将确保在 softmax 操作中这些位置的权重接近于 0。
总结
masked_fill
在 PyTorch 中是非常有用的工具,能够根据掩码来灵活地屏蔽或替换特定张量中的元素。它广泛用于序列处理、注意力机制等场景,帮助模型忽略不需要的部分。