InputEmbedder 是 AlphaFold3 中的输入嵌入器模块，用于初始化单体表示（single representation, $s_{\text{init}}$）和对表示（pair representation, $z_{\text{init}}$）。
源代码：
class InputEmbedder(nn.Module):"""Input embedder for AlphaFold3 that initializes the single and pair representations."""def __init__(self,c_token: int = 384,c_atom: int = 128,c_atompair: int = 16,c_trunk_pair: int = 128,):super(InputEmbedder, self).__init__()# InputFeatureEmbedder for the s_inputs representationself.input_feature_embedder = InputFeatureEmbedder(c_token=c_token,c_atom=c_atom,c_atompair=c_atompair,c_trunk_pair=c_trunk_pair)# Projectionsself.linear_single = LinearNoBias(c_token, c_token)self.linear_proj_i = LinearNoBias(c_token, c_trunk_pair)self.linear_proj_j = LinearNoBias(c_token, c_trunk_pair)# self.linear_bonds = LinearNoBias(1, c_trunk_pair)# Relative position encodingself.relpos = RelativePositionEncoding(c_pair=c_trunk_pair)def forward(self,features: Dict[str, Tensor],inplace_safe: bool = False,) -> Tuple[Tensor, Tensor, Tensor]:"""Args:features:Dictionary containing the following input features:"ref_pos" ([*, N_atoms, 3]):atom positions in the reference conformers, witha random rotation and translation applied. Atom positions in Angstroms."ref_charge" ([*, N_atoms]):Charge for each atom in the reference conformer."ref_mask" ([*, N_atoms]):Mask indicating which atom slots are used in the reference