AlphaFold3 _attention
函数位于 src.models.components.primitives模块,是一个标准的注意力机制的实现,主要用于计算输入的查询 (query
)、键 (key
) 和值 (value
) 张量之间的注意力权重,并将其应用于值张量。_attention
函数被Attention类调用,实现定制化的多头注意力机制。
源代码:
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:"""A stock PyTorch implementation of the attention mechanism.Args:query:[*, H, Q, C_hidden] query tensorkey:[*, H, K/V, C_hidden] key tensorvalue:[*, H, K/V, C_value] value tensorbiases:a list of biases that broadcast to [*, H, Q, K]Returns:the resultant tensor [*, H, Q, C_value]"""# [*, H, C_hidden, K]key = permute_final_dims(key, (1, 0))# [*, H, Q, K]a = torch.matmul(query,