PairStack 是 AlphaFold 的核心模块之一,用于对残基对(residue-residue pair)的特征张量 z 进行迭代更新。该模块结合几何操作(三角形乘法更新,分为 outgoing 与 incoming 两个方向)、三角形注意力机制(起始节点与终止节点两种)以及 Transition 层,并配合按行/按列的 dropout,逐步建模蛋白质序列中残基之间的复杂关系。
源代码:
class PairStack(nn.Module):def __init__(self,c_z: int,c_hidden_tri_mul: int = 128,c_hidden_pair_attn: int = 32,no_heads_tri_attn: int = 4,transition_n: int = 4,pair_dropout: float = 0.25,fuse_projection_weights: bool = False,inf: float = 1e8,):super(PairStack, self).__init__()if fuse_projection_weights:self.tri_mul_out = FusedTriangleMultiplicationOutgoing(c_z,c_hidden_tri_mul,)self.tri_mul_in = FusedTriangleMultiplicationIncoming(c_z,c_hidden_tri_mul,)else:self.tri_mul_out = TriangleMultiplicationOutgoing(c_z,c_hidden_tri_mul,)self.tri_mul_in = TriangleMultiplicationIncoming(c_z,c_hidden_tri_mul,)self.tri_att_start = TriangleAttentionStartingNode(c_z,c_hidden_pair_attn,no_heads_tri_attn,inf=inf,)self.tri_att_end = TriangleAttentionEndingNode(c_z,c_hidden_pair_attn,no_heads_tri_attn,inf=inf,)self.transition = Transition(c_z,transition_n,)self.dropout_row_layer = DropoutRowwise(pair_dropout)self.dropout_col_layer = DropoutColumnwise(pair_dropout)def forward(