Function rigid_from_3_points
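rigid_from_3_points takes three backbone atoms of each residue (N, Ca, C) and computes the rigid transformation from the residue's local coordinate frame to the global coordinate frame. It returns a rotation matrix R and a translation vector (the coordinates of Ca). Together these describe that rigid transformation.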
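The ideal-geometry part of the function (before the optional non_ideal correction) is a Gram-Schmidt construction, following the source code below. Written out as equations, it looks like this. The symbols $\mathbf{x}_N, \mathbf{x}_{C_\alpha}, \mathbf{x}_C$ for the atom positions are introduced here only for illustration, and the `eps` guards are omitted:

$$
\begin{aligned}
\mathbf{v}_1 &= \mathbf{x}_C - \mathbf{x}_{C_\alpha}, \qquad \mathbf{v}_2 = \mathbf{x}_N - \mathbf{x}_{C_\alpha},\\
\mathbf{e}_1 &= \frac{\mathbf{v}_1}{\lVert \mathbf{v}_1 \rVert}, \qquad
\mathbf{u}_2 = \mathbf{v}_2 - (\mathbf{e}_1 \cdot \mathbf{v}_2)\,\mathbf{e}_1, \qquad
\mathbf{e}_2 = \frac{\mathbf{u}_2}{\lVert \mathbf{u}_2 \rVert}, \qquad
\mathbf{e}_3 = \mathbf{e}_1 \times \mathbf{e}_2,\\
R &= \begin{bmatrix} \mathbf{e}_1 & \mathbf{e}_2 & \mathbf{e}_3 \end{bmatrix}, \qquad
\mathbf{x}_{\text{global}} = R\,\mathbf{x}_{\text{local}} + \mathbf{x}_{C_\alpha}.
\end{aligned}
$$

The columns of R are orthonormal and right-handed, so $\det R = 1$. The inverse (global-to-local) map is therefore $\mathbf{x}_{\text{local}} = R^{\top}(\mathbf{x}_{\text{global}} - \mathbf{x}_{C_\alpha})$.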
Source code
```python
import torch

# More complicated version splits error in CA-N and CA-C (giving more accurate CB position)
# It returns the rigid transformation from local frame to global frame
def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    # N, Ca, C - [B, L, 3]
    # R - [B, L, 3, 3], det(R) = 1, inv(R) = R.T, R is a rotation matrix
    B, L = N.shape[:2]

    v1 = C - Ca  # Ca -> C bond vector
    v2 = N - Ca  # Ca -> N bond vector
    e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)
    # Gram-Schmidt: remove the e1 component from v2, then normalize
    u2 = v2 - (torch.einsum("bli, bli -> bl", e1, v2)[..., None] * e1)
    e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)
    # Stack e1, e2, e3 as columns
    R = torch.cat([e1[..., None], e2[..., None], e3[..., None]], dim=-1)  # [B, L, 3, 3] - rotation matrix

    if non_ideal:
        # cos_ideal_NCAC (a tensor holding the cosine of the ideal N-CA-C angle)
        # is a module-level constant defined elsewhere in the original source
        v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps)
        cosref = torch.sum(e1 * v2, dim=-1)  # cosine of current N-CA-C bond angle
        costgt = cos_ideal_NCAC.item()
        cos2del = torch.clamp(
            cosref * costgt + torch.sqrt((1 - cosref * cosref) * (1 - costgt * costgt) + eps),
            min=-1.0,
            max=1.0,
        )
        cosdel = torch.sqrt(0.5 * (1 + cos2del) + eps)
        sindel = torch.sign(costgt - cosref) * torch.sqrt(1 - 0.5 * (1 + cos2del) + eps)
        # Rp: rotation about the frame's e3 axis by the half-angle delta
        Rp = torch.eye(3, device=N.device).repeat(B, L, 1, 1)
        Rp[:, :, 0, 0] = cosdel
        Rp[:, :, 0, 1] = -sindel
        Rp[:, :, 1, 0] = sindel
        Rp[:, :, 1, 1] = cosdel
        R = torch.einsum("blij,bljk->blik", R, Rp)

    return R, Ca
```
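Below is a hypothetical, dependency-free sketch for a single residue that lets you check the geometry without PyTorch. The names `frame_from_3_points` and `to_global` are made up for illustration. The sketch follows the same Gram-Schmidt steps as the function above. When `cos_ideal` is passed, it applies the same half-angle rotation as the `non_ideal=True` branch. The caller must supply the ideal cosine because this post does not show the value of `cos_ideal_NCAC`. Unlike the original, the sketch is unbatched and drops the `eps` terms inside the square roots.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vec = List[float]
Mat = List[List[float]]  # 3x3, row-major: Mat[i][j] is row i, column j


def _sub(a: Sequence[float], b: Sequence[float]) -> Vec:
    return [a[i] - b[i] for i in range(3)]


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(a[i] * b[i] for i in range(3))


def _unit(a: Sequence[float], eps: float) -> Vec:
    # Same guard as the original: divide by (norm + eps)
    norm = math.sqrt(_dot(a, a)) + eps
    return [x / norm for x in a]


def _cross(a: Sequence[float], b: Sequence[float]) -> Vec:
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def frame_from_3_points(n: Sequence[float], ca: Sequence[float], c: Sequence[float],
                        cos_ideal: Optional[float] = None,
                        eps: float = 1e-8) -> Tuple[Mat, Vec]:
    """Hypothetical single-residue sketch of rigid_from_3_points.

    Returns (R, t) with R's columns = (e1, e2, e3) and t = Ca, so that
    x_global = R @ x_local + t. Passing cos_ideal mimics non_ideal=True.
    """
    v1 = _sub(c, ca)  # Ca -> C
    v2 = _sub(n, ca)  # Ca -> N
    e1 = _unit(v1, eps)
    proj = _dot(e1, v2)
    e2 = _unit([v2[i] - proj * e1[i] for i in range(3)], eps)  # Gram-Schmidt step
    e3 = _cross(e1, e2)

    if cos_ideal is not None:
        cosref = _dot(e1, _unit(v2, eps))  # cos of the current N-Ca-C angle
        # cos(theta_ref - theta_ideal) via the angle-difference identity
        cos2del = cosref * cos_ideal + math.sqrt(
            max(0.0, (1 - cosref * cosref) * (1 - cos_ideal * cos_ideal)))
        cos2del = min(1.0, max(-1.0, cos2del))
        # Half-angle formulas; sign is -1, 0 or +1 like torch.sign
        sign = (cos_ideal > cosref) - (cos_ideal < cosref)
        cosdel = math.sqrt(0.5 * (1 + cos2del))
        sindel = sign * math.sqrt(0.5 * (1 - cos2del))
        # R @ Rp with Rp a rotation about e3: only e1 and e2 change
        e1, e2 = ([cosdel * e1[i] + sindel * e2[i] for i in range(3)],
                  [-sindel * e1[i] + cosdel * e2[i] for i in range(3)])

    R = [[e1[i], e2[i], e3[i]] for i in range(3)]  # columns are e1, e2, e3
    return R, list(ca)


def to_global(R: Mat, t: Sequence[float], x_local: Sequence[float]) -> Vec:
    """Apply the rigid transform: R @ x_local + t."""
    return [_dot(R[i], x_local) + t[i] for i in range(3)]
```

Here is why the rotation splits the error. Let theta_ref be the current N-Ca-C angle and theta_ideal the ideal one. Then cos2del equals cos(theta_ref - theta_ideal), so the frame is rotated about e3 by delta = (theta_ref - theta_ideal)/2. After the rotation, C sits at angle -delta from the new x-axis and N sits at theta_ideal + delta. Both bond directions therefore deviate from the ideal geometry by the same |delta|. This appears to be what the source comment means by splitting the error between CA-N and CA-C.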
Code walkthrough
Input parameters
- N, Ca, C: