欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 美食 > DGL库之HGTConv的使用

DGL库之HGTConv的使用

2025/2/1 2:48:29 来源:https://blog.csdn.net/m0_56878426/article/details/142830565  浏览:    关键词:DGL库之HGTConv的使用

DGL库之HGTConv的使用

  • 论文地址和异构图构建教程
  • HGTConv语法格式
  • HGTConv的使用

论文地址和异构图构建教程

论文地址:https://arxiv.org/pdf/2003.01332
异构图构建教程:异构图构建
异构图转同构图:异构图转同构图

HGTConv语法格式

dgl.nn.pytorch.conv.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes, dropout=0.2, use_norm=False)

参数说明:

  • in_size (int): 输入节点特征的大小。
  • head_size (int): 输出头的大小。输出节点特征的大小为 head_size * num_heads。
  • num_heads (int): 头的数量。输出节点特征的大小为 head_size * num_heads。
  • num_ntypes (int): 节点类型的数量。
  • num_etypes (int): 边类型的数量。
  • dropout (可选, float): dropout 比率,用于防止过拟合。
  • use_norm (可选, bool): 如果为 True,则在输出节点特征上应用层归一化。
forward(g, x, ntype, etype, *, presorted=False)

参数说明:

  • g (DGLGraph): 输入的图对象。

  • x (torch.Tensor): 一个 2D 张量,表示节点特征。其形状应为 (num_nodes, in_size),num_nodes 是节点数量,in_size 是输入特征的维度。

  • ntype (torch.Tensor): 一个 1D 整数张量,表示节点类型。其形状应为 (num_nodes,),对应每个节点的类型索引。

  • etype (torch.Tensor): 一个 1D 整数张量,表示边类型。其形状应为 (num_edges,),对应每条边的类型索引。

  • presorted (bool, 可选): 指示输入图的节点和边是否已经按照类型排序。如果输入图是预排序的,则前向传播可能会更快。通过调用 to_homogeneous()创建的图会自动满足此条件。也可以使用 reorder_graph() 方法手动重新排序节点和边。

返回值:

  • 返回的新节点特征: 返回的特征是一个 2D 张量,其形状为 (num_nodes, head_size * num_heads),表示经过HGTConv 处理后的新节点特征,返回的张量类型为 torch.Tensor。

HGTConv的使用

使用的异构图如下:
在这里插入图片描述
在使用HGTConv时,一定要使用dgl.to_homogeneous将异构图转为同构图,否则不能使用,代码如下:

import dgl
import torch
import torch.nn as nn
import dgl.nn.pytorch# 定义一个简单的异构图
def create_hetero_graph():# 定义两个类型的节点:drug(药物)和 disease(疾病)data_dict = {('drug', 'd_interacts', 'drug'): (torch.tensor([0, 1]), torch.tensor([1, 2])),  # 药物间的相互作用('drug', 'g_interacts', 'gene'): (torch.tensor([0, 1]), torch.tensor([2, 3])),  # 药物与基因间的相互作用('drug', 'treats', 'disease'): (torch.tensor([1]), torch.tensor([2]))           # 药物与疾病的关系}# 创建一个异构图hetero_graph = dgl.heterograph(data_dict)# 设置节点和边的特征hetero_graph.nodes['drug'].data['h'] = torch.ones(3, 320)  # 假设药物特征是320维的hetero_graph.nodes['disease'].data['h'] = torch.zeros(3, 320)  # 假设疾病特征是320维的hetero_graph.nodes['gene'].data['h'] = torch.ones(4, 320)  # 假设基因特征是320维的return hetero_graph# 定义一个HGT模型类
class HGTModel(nn.Module):def __init__(self, in_dim, out_dim, num_heads, num_layers, num_node_types, num_edge_types, dropout=0.2):super(HGTModel, self).__init__()# 使用 dgl.nn.pytorch.conv.HGTConv 初始化 HGT 卷积层self.layers = nn.ModuleList()  # 创建一个空的层列表for _ in range(num_layers):layer = dgl.nn.pytorch.conv.HGTConv(in_dim,  # 输入维度out_dim,  # 输出维度num_heads,  # 注意力头的数量num_node_types,  # 节点类型数量num_edge_types,  # 边类型数量dropout=dropout  # dropout比率)self.layers.append(layer)  # 将层添加到列表中def forward(self, g):with g.local_scope():  # 创建一个局部作用域,‌确保对图的操作不会影响原始图。‌for layer in self.layers:# 使用HGTConv层进行卷积操作h = layer(g, g.ndata['h'], g.ndata['_TYPE'], g.edata['_TYPE'], presorted=True)g.ndata['h'] = h  # 更新节点特征return g.ndata['h']  # 返回最后一层的节点特征# 创建一个异构图
hetero_graph = create_hetero_graph()print('异构图为:\n',hetero_graph)  # 输出异构图的信息
# 将异构图转换为同构图
homogeneous_graph = dgl.to_homogeneous(hetero_graph, ndata=['h'])
print(f"节点特征矩阵为:\n{homogeneous_graph.ndata['h']}")  # 打印节点特征的类型# 创建模型并移动到 CPU 设备
hgt_model = HGTModel(in_dim=320, out_dim=80, num_heads=4, num_layers=2,num_node_types=3, num_edge_types=3, dropout=0.3).to(torch.device('cpu'))# 前向传播
output_features = hgt_model(homogeneous_graph)print("更新后的特征:\n", output_features)  # 输出特征的形状

结果如下:
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com