The AtomAttentionDecoder class in AlphaFold3
It is designed to expand per-token representations back into per-atom representations, while modeling atoms and their pairwise relationships through cross-attention. This design lets the network capture complex atom-level interactions in biomolecular modeling.
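To make the "per-token to per-atom" broadcasting step concrete, here is a minimal sketch. It is not the AlphaFold3 implementation; the shapes and the tok_idx tensor (mapping each atom to the index of its parent token) are illustrative assumptions.

import torch

# Minimal sketch of broadcasting per-token activations to per-atom activations.
# Shapes and the tok_idx mapping are illustrative assumptions, not the actual
# AlphaFold3 interface.
batch, n_tokens, n_atoms, c_token = 2, 5, 20, 64

a_token = torch.randn(batch, n_tokens, c_token)         # per-token representation
tok_idx = torch.randint(0, n_tokens, (batch, n_atoms))   # parent token index of each atom

# Gather each atom's parent-token activation: result is [batch, n_atoms, c_token]
a_atom = torch.gather(
    a_token, 1, tok_idx.unsqueeze(-1).expand(-1, -1, c_token)
)
print(a_atom.shape)  # torch.Size([2, 20, 64])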
Source code:
import torch.nn as nn


class AtomAttentionDecoder(nn.Module):
    """AtomAttentionDecoder that broadcasts per-token activations to per-atom activations."""

    def __init__(
        self,
        c_token: int,
        c_atom: int = 128,
        c_atompair: int = 16,
        no_blocks: int = 3,
        no_heads: int = 8,
        dropout=0.0,
        n_queries: int = 32,
        n_keys: int = 128,
        clear_cache_between_blocks: bool = False,
    ):
        """Initialize the AtomAttentionDecoder module.

        Args:
            c_token:
                The number of channels for the token representation.
            c_atom:
                The number of channels for the atom representation. Defaults to 128.
            c_atompair:
                The number of channels for the atom pair representation. Defaults to 16.
            no_blocks:
                Number of blocks.
            no_heads:
                Number of parallel attention heads. Note that c_atom will be split across no_heads
                (i.e. each head will have dimension c_atom // no_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            n_queries:
                The size of the atom window. Defaults to 32.
            n_keys:
                Number of atoms each atom attends to in local sequence space. Defaults to 128.
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation.
        """
        super().__init__()
        self.c_token = c_token
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.num_blocks = no_blocks
        self.num_heads = no_heads
        self.dropout = dropout
        self.n_queries = n_queries
        self.n_keys = n_keys
        self.clear_cache_between_blocks = clear_cache_between_blocks
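As a quick sanity check, the constructor above can be exercised on its own. The argument values here are illustrative, not constants taken from the AlphaFold3 configuration.

# Illustrative instantiation; c_token=384 is an assumed value for demonstration only.
decoder = AtomAttentionDecoder(
    c_token=384,
    c_atom=128,
    c_atompair=16,
    no_blocks=3,
    no_heads=8,
)
print(decoder.num_blocks, decoder.num_heads, decoder.n_queries, decoder.n_keys)
# 3 8 32 128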