欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > AF3 BaseTriangleMultiplicativeUpdate类解读

AF3 BaseTriangleMultiplicativeUpdate类解读

2025/2/21 3:25:07 来源:https://blog.csdn.net/qq_27390023/article/details/145133990  浏览:    关键词:AF3 BaseTriangleMultiplicativeUpdate类解读

BaseTriangleMultiplicativeUpdate 类是一个抽象基类 (ABC),用于实现 AlphaFold 相关算法(具体为算法 11 和 12)。它的主要功能是通过三角形乘法更新成对表示张量(pairwise representation tensor)。

源代码:


from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethodimport torch
import torch.nn as nn
from torch.nn import LayerNorm
from src.models.components.primitives import Linear
from src.utils.chunk_utils import chunk_layer
from src.utils.tensor_utils import add, permute_final_dimsclass BaseTriangleMultiplicativeUpdate(nn.Module, ABC):"""Implements Algorithms 11 and 12."""@abstractmethoddef __init__(self, c_z, c_hidden, _outgoing):"""Args:c_z:Input channel dimensionc:Hidden channel dimension"""super(BaseTriangleMultiplicativeUpdate, self).__init__()self.c_z = c_zself.c_hidden = c_hiddenself._outgoing = _outgoingself.linear_g = Linear(self.c_z, self.c_z, init="gating")self.linear_z = Linear(self.c_hidden, self.c_z, init="final")self.layer_norm_in = LayerNorm(self.c_z)self.layer_norm_out = LayerNorm(self.c_hidden)self.sigmoid = nn.Sigmoid()def _combine_projections(self,a: torch.Tensor,b: torch.Tensor,_inplace_chunk_size: Optional[int] = None) -> torch.Tensor:if self._outgoing:a = permute_final_dims(a, (2, 0, 1))b = permute_final_dims(b, (2, 1, 0))else:a = permute_final_dims(a, (2, 1, 0))b = permute_final_dims(b, (2, 0, 1))if _inplace_chunk_size is not None:# To be replaced by torch vmapfor i in range(0, a.shape[-3], _inplace_chunk_size):a_chunk = a[..., i: i + _inplace_chunk_size, :, :]b_chunk = b[..., i: i + _inplace_chunk_size, :, :]a[..., i: i + _inplace_chunk_size, :, :] = (torch.matmul(a_chunk,b_chunk,))p = aelse:p = torch.matmul(a, b)return permute_final_dims(p, (1, 2, 0))@abstractmethoddef forward(self,z: torch.Tensor,mask: Optional[torch.Tensor] = None,inplace_safe: bool = False,_add_with_inplace: bool = False) -> torch.Tensor:"""Args:x:[*, N_res, N_res, C_z] x tensormask:[*, N_res, N_res] x maskReturns:[*, N_res, N_res, C_z] output tensor"""pass

代码解读:

1. 构造方法(__init__
def __init__(self, c_z, c_hidden, _outgoing):super(BaseTriangleMultiplicativeUpdate, self).__init__()self.c_z = c_zself.c_hidden = c_hiddenself._outgoing = _outgoingself.linear_g = Linear(self.c_z, self.c_z, init="gating")self.linear_z = Linear(self.c_hidden, self.c_z, init="final")self.layer_norm_in = LayerNorm(self.c_z)self.layer_norm_out = LayerNorm(self.c_hidden)self.sigmoid = nn.Sigmoid()
  • c_z: 输入的通道维度。
  • c_hidden: 隐藏层的通道维度。
  • _outgoing: 一个布尔标志,用于指示操作的方向(用于三角形乘法:出发或抵达方向)。
  • 主要组件:
    • linear_g 和 linear_z:线性层,用于特征变换。
    • layer_norm_in 和 layer_norm_out:层归一化,用于提高训练稳定性。
    • sigmoid:对门控张量进行非线性变换。

2. _combine_projections 方法
def _combine_projections(self, a, b, _inplace_chunk_size=None):if self._outgoing:a = permute_final_dims(a, (2, 0, 1))b = permute_final_dims(b, (2, 1, 0))else:a = permute_final_dims(a, (2, 1, 0))b = permute_final_dims(b, (2, 0, 1))if _inplace_chunk_size is not None:for i in range(0, a.shape[-3], _inplace_chunk_size):a_chunk = a[..., i: i + _inplace_chunk_size, :, :]b_chunk = b[..., i: i + _inplace_chunk_size, :, :]a[..., i: i + _inplace_chunk_size, :, :] = (torch.matmul(a_chunk, b_chunk))p = aelse:p = torch.matmul(a, b)return permute_final_dims(p, (1, 2, 0))
  • permute_final_dims:交换张量的最后几个维度以适配矩阵乘法。
  • torch.matmul:计算 a 和 b 的矩阵乘法,用于更新表示。
  • _inplace_chunk_size:支持按块处理较大的张量以节省内存。
  • 返回值:重新排列维度后的结果张量 p

3. forward 方法(抽象)
@abstractmethod
def forward(self, z, mask=None, inplace_safe=False, _add_with_inplace=False):pass
  • 参数说明
    • z: [*, N_res, N_res, C_z],输入的成对表示张量。
    • mask: 可选的掩码张量,用于处理不需要更新的部分。
    • inplace_safe:是否安全地进行原地操作。
    • _add_with_inplace:是否在加法中使用原地操作。
  • 作用:具体的三角形乘法更新逻辑将在子类中实现。

作用与意义

  1. 三角形乘法更新

    • 用于更新成对表示张量,捕捉蛋白质序列中残基之间的几何和物理关系。
    • 根据 _outgoing,可以处理三角形结构中不同方向的信息流。
  2. 特征提取与整合

    • 通过线性变换(linear_g 和 linear_z)和矩阵乘法,整合隐藏层的高阶特征。
  3. 应用场景

    • 主要在 AlphaFold 中用于蛋白质结构预测,帮助建模残基之间的复杂相互作用。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词