get_tor_mask函数
生成一个张量 tors_mask
,用于表示蛋白质序列中各残基的二面角是否有效。tors_mask
是一个布尔张量,大小为 (B,L,10),其中 B 是批次大小,L 是序列长度,最后的维度(大小为 10)表示每个残基可能涉及的二面角状态。
源代码:
def get_tor_mask(seq, torsion_indices, mask_in=None):B, L = seq.shape[:2]tors_mask = torch.ones((B, L, 10), dtype=torch.bool, device=seq.device)tors_mask[..., 3:7] = torsion_indices[seq, :, -1] > 0tors_mask[:, 0, 1] = Falsetors_mask[:, -1, 0] = False# mask for additional anglestors_mask[:, :, 7] = seq != aa2num["GLY"]tors_mask[:, :, 8] = seq != aa2num["GLY"]tors_mask[:, :, 9] = torch.logical_and(seq != aa2num["GLY"], seq != aa2num["ALA"])tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num["UNK"])tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num["MAS"])if mask_in != None:# mask for missing atoms# chisti0 = torch.gather(mask_in, 2, torsion_indices[seq, :, 0])ti1 = torch.gather(mask_in, 2, torsion_indices[seq, :, 1])ti2 = torch.gather(mask_in, 2, torsion_indices[seq, :, 2])ti3 = torch.gather(mask_in, 2, torsion_indices[seq, :, 3])is_valid = torch.stack((ti0, ti1, ti2, ti3), dim=-2).all(dim=-1)tors_mask[..., 3:7] = torch.logical_and(tors_mask[..., 3:7], is_valid)tors_mask[:, :, 7] = torch.logical_and(tors_mask[:, :, 7], mask_in[:, :, 4]) # CB exist?t