欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > 图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)

图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)

2025/2/25 1:38:44 来源:https://blog.csdn.net/LOVEmy134611/article/details/139487511  浏览:    关键词:图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)

图神经网络实战(12)——图同构网络

    • 0. 前言
    • 1. 图同构网络原理
    • 2. 构建 GIN 模型执行图分类
      • 2.1 图分类任务
      • 2.2 PROTEINS 数据集分析
      • 2.3 构建 GIN 实现图分类
      • 2.4 GCN 与 GIN 性能差异分析
    • 3. 提升模型性能
    • 小结
    • 系列链接

0. 前言

Weisfeiler-Leman (WL) 测试提供了一个理解图神经网络 (Graph Neural Networks, GNN) 表达能力的框架,利用该框架我们比较了不同的 GNN 层,在本节中,我们将利用 WL 测试结果尝试设计比图卷积网络 (Graph Convolutional Network, GCN)、图注意力网络 (Graph Attention Networks,GAT) 和 GraphSAGE 更强大的 GNN 架构——图同构网络 (Graph Isomorphism Network, GIN)。然后,使用 PyTorch Geometric 实现 GIN 架构,并执行图分类任务。我们将在 PROTEINS 数据集上实现 GIN 架构,将比较不同的图分类模型并分析结果。

为了验证这一想法,在下一节中,我们将根据这一思想构建。

1. 图同构网络原理

在 Weisfeiler-Leman (WL) 测试一节中,我们看到之前所介绍的图神经网络 (Graph Neural Networks, GNN) (包括图卷积网络 (Graph Convolutional Network, GCN)、图注意力网络 (Graph Attention Networks,GAT) 和 GraphSAGE 等)的表达能力不如 WL 测试,这暴露出一个问题,因为区分更多图结构的能力与最终嵌入的质量有关。在本节中,我们将把理论框架转化为一种新的 GNN 架构——图同构网络 (Graph Isomorphism Network, GIN)。
GIN2018 年由 Xu 等人提出,旨在具有与 WL 测试相同的表达能力。作者将对聚合的观察归纳为两个函数:

  • 聚合 (Aggregate): 函数 f f f 选择 GNN 考虑的邻居节点
  • 组合 (Combine): 函数 ϕ ϕ ϕ 将所选节点的嵌入结合起来,生成目标节点的新嵌入

节点嵌入可以表达为以下形式:
h i ′ = ϕ ( h i , f ( h j : j ∈ N i ) ) h'_i=ϕ(h_i,f({h_j:j\in \mathcal N_i})) hi=ϕ(hi,f(hj:jNi))
GCN 中,函数 f f f 会聚合节点 i i i 的每个邻居,而 ϕ ϕ ϕ 则使用均值聚合算子。在 GraphSAGE 中,邻居采样就是函数 f f f ϕ ϕ ϕ 具有三个不同选项,包括均值聚合算子、长短期记忆 (long short-term memory, LSTM) 聚合算子和池化聚合算子。
而在图同构网络 (Graph Isomorphism Network, GIN) 中,这些函数是必须是单射的。如下图所示,单射函数 (injective function) 将不同的输入映射到不同的输出,这正是我们想要区分图结构的原因。如果函数不是单射的,那么不同的输入将得到相同的输出。在这种情况下,嵌入就变得不那么有价值了,因为它们包含的信息会更少。

单射函数

GIN 在设计这两个函数时,只是对这两个函数进行了近似。在 GAT 层中,我们学习了自注意力权重。在 GIN 中,我们可以利用通用近似定理,用一个多层感知机 (Multilayer Perceptron, MLP) 学习这两个函数:
h i ′ = M L P ( ( 1 + ɛ ) ⋅ h i + ∑ j ∈ N i h j ) h'_i=MLP((1+ɛ)\cdot h_i+\sum_{j\in \mathcal N_i}h_j) hi=MLP((1+ɛ)hi+jNihj)
其中, ɛ ɛ ɛ 是一个可学习的参数或固定标量,表示目标节点的嵌入与其邻居的嵌入相比的重要性。同时,MLP 必须具有多个层来区分特定的图结构。
现在,我们已经介绍了一个与 WL 测试具有相同表达能力的 GNN,但在此基础上,我们还可以进一步改进,将 WL 测试推广为一系列更高级别的测试,称为 k-WL 测试。与考虑单个节点不同,k-WL 测试考虑的是 k 元组节点。这意味着它们是非局部的,因为它们可以查看相距更远的节点,这也是 (k + 1) -WL 测试比 k-WL 测试(其中 k ≥ 2 )能区分更多图结构的原因。
目前已经提出了几种基于 k-WL 测试的架构,如 Morris 等人提出的 k-GNN。虽然这些架构有助于我们更好地理解 GNN 的工作原理,但与 GNNGAT 等表达能力较弱的模型相比,它们在实际应用中往往表现不佳,但它们也有各自合适的应用场景,接下来,我们将 GIN 应用于图分类以发挥其性能。

2. 构建 GIN 模型执行图分类

我们可以直接实现用于节点分类的图同构网络 (Graph Isomorphism Network, GIN)模型,但 GIN 架构对于执行图分类任务更加有效。在本节中,我们将了解如何使用全局池化技术将节点嵌入转化为图嵌入。然后,将这些技术应用于 PROTEINS 数据集,并对比 GIN 和图卷积网络 (Graph Convolutional Network, GCN)在图分类任务中的性能差异。

2.1 图分类任务

图分类是基于图神经网络 (Graph Neural Networks, GNN) 生成的节点嵌入进行的,但与节点层面的任务不同,图分类需要关注图数据的全局信息,需要对全局的信息进行融合学习,在图中通常采用全局池化( global pooling,也称图读出机制,graph-level readout function)来提取全局信息。三种简单的实现方法如下:

  • 全局均值池化 (Mean global pooling): 通过对图中每个节点的嵌入取平均值,得到图嵌入 h G h_G hG
    h G = 1 N ∑ i = 0 N h i h_G=\frac 1N\sum_{i=0}^Nh_i hG=N1i=0Nhi
  • 全局最大池化 (Max global pooling): 通过选择每个节点维度的最高值,得到图嵌入 h G h_G hG
    h G = m a x i = 0 N ( h i ) h_G=max_{i=0}^N(h_i) hG=maxi=0N(hi)
  • 全局求和池化 (Sum global pooling):通过对图中每个节点的嵌入求和,得到图嵌入 h G h_G hG
    h G = ∑ i = 0 N h i h_G=\sum_{i=0}^Nh_i hG=i=0Nhi

根据 Weisfeiler-Leman (WL) 测试可知,求和全局池化严格来说比其他两种池化技术更具表达能力。同时,要考虑所有结构信息,就必须考虑 GNN 每一层产生的嵌入,将 GNN k k k 个层中每层产生的节点嵌入求和后串联起来:
h G = ∑ i = 0 N h i 0 ∣ ∣ ⋯ ∣ ∣ ∑ i = 0 N h i k h_G=\sum_{i=0}^Nh_i^0||\cdots ||\sum_{i=0}^Nh_i^k hG=i=0Nhi0∣∣∣∣i=0Nhik
这种方法通过串联将求和运算符的表达能力与每层中存储的信息优雅的结合在一起。

2.2 PROTEINS 数据集分析

接下来,在 PROTEINS 数据集上使用图读出机制实现 GIN 模型。PROTEINS 数据集包含 1,113 个表示蛋白质的图,其中每个节点都是一个氨基酸。当两个节点之间的距离小于 0.6 纳米时,它们之间会有一条边相连。该数据集的目标是将每个蛋白质分类为酶或非酶,即二分类问题,酶可作为催化剂加速细胞内的化学反应,例如,脂肪酶可以帮助消化食物等,蛋白质的三维结构示例如下:

蛋白质三维结构

接下来,使用 PyTorch Geometric (PyG) 在 PROTEINS 数据集上构建 GIN 模型。

(1) 首先,使用 PyTorch GeometricTUDataset 类导入 PROTEINS 数据集并打印相关信息:

from torch_geometric.datasets import TUDatasetdataset = TUDataset(root='.', name='PROTEINS').shuffle()# Print information about the dataset
print(f'Dataset: {dataset}')
print('-----------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
'''
Dataset: PROTEINS(1113)
-----------------------
Number of graphs: 1113
Number of nodes: 28
Number of features: 3
Number of classes: 2
'''

(2)8:1:1 的比例将数据集拆分为训练集、验证集和测试集:

from torch_geometric.loader import DataLoader# Create training, validation, and test sets
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset   = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset  = dataset[int(len(dataset)*0.9):]print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

训练集、验证集和测试集中图的数量输出如下:

Training set   = 890 graphs
Validation set = 111 graphs
Test set       = 112 graphs

(3) 使用批大小为 64DataLoader 对象将这些数据集合转换为批数据,即每批数据最多包含 64 个图:

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=True)

(4) 打印每批数据的相关信息:

print('\nTrain loader:')
for i, batch in enumerate(train_loader):print(f' - Batch {i}: {batch}')print('\nValidation loader:')
for i, batch in enumerate(val_loader):print(f' - Batch {i}: {batch}')print('\nTest loader:')
for i, batch in enumerate(test_loader):print(f' - Batch {i}: {batch}')

批数据信息

2.3 构建 GIN 实现图分类

构建训练数据集后,开始实施 GIN 模型。首先需要考虑 GIN 层的架构,使用一个至少有两层的多层感知机 (Multilayer Perceptron, MLP) ,引入批归一化来标准化每个隐藏层的输入,用于稳定并加快训练速度。总体而言, GIN 架构如下所示:

GIN 架构

(1) 使用 PyTorch Geometric (PyG) 实现以上架构,作为对比,我们同时实现了 GCN 模型:

import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_poolclass GCN(torch.nn.Module):"""GCN"""def __init__(self, dim_h):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, dim_h)self.conv2 = GCNConv(dim_h, dim_h)self.conv3 = GCNConv(dim_h, dim_h)self.lin = Linear(dim_h, dataset.num_classes)def forward(self, x, edge_index, batch):# Node embeddings h = self.conv1(x, edge_index)h = h.relu()h = self.conv2(h, edge_index)h = h.relu()h = self.conv3(h, edge_index)# Graph-level readouthG = global_mean_pool(h, batch)# Classifierh = F.dropout(hG, p=0.5, training=self.training)h = self.lin(h)return F.log_softmax(h, dim=1)class GIN(torch.nn.Module):"""GIN"""def __init__(self, dim_h):super(GIN, self).__init__()self.conv1 = GINConv(Sequential(Linear(dataset.num_node_features, dim_h),BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))self.conv2 = GINConv(Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))self.conv3 = GINConv(Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),Linear(dim_h, dim_h), ReLU()))

PyTorch Geometric 还内置了 GINE 层,GINEGIN 层的一种改进。与 GIN 相比,GINE 的主要改进在于能够在聚合过程中考虑边特征。由于 PROTEINS 数据集没有边特征,因此本节采用经典的 GIN 模型。

(2) 要进行图分类,还需要对每一层中图上每个节点嵌入进行求和。换句话说,我们需要为每一层存储一个大小为 dim_h 的向量,本节中 dim_h3。在最终的线性层之前,添加一个大小为 3 * dim_h 的线性层,用于二分类 (data.num_classes = 2):

        self.lin1 = Linear(dim_h*3, dim_h*3)self.lin2 = Linear(dim_h*3, dataset.num_classes)

(3) 接下来,实现连接初始化层的逻辑。每一层都会产生不同的嵌入张量——h1h2h3。使用 global_add_pool() 函数对它们进行求和,然后使用 torch.cat() 将它们串联起来。这样,就得到了分类器的输入,类似一个带有 Dropout 层的普通神经网络:

    def forward(self, x, edge_index, batch):# Node embeddings h1 = self.conv1(x, edge_index)h2 = self.conv2(h1, edge_index)h3 = self.conv3(h2, edge_index)# Graph-level readouth1 = global_add_pool(h1, batch)h2 = global_add_pool(h2, batch)h3 = global_add_pool(h3, batch)# Concatenate graph embeddingsh = torch.cat((h1, h2, h3), dim=1)# Classifierh = self.lin1(h)h = h.relu()h = F.dropout(h, p=0.5, training=self.training)h = self.lin2(h)return F.log_softmax(h, dim=1)

(4) 使用批数据实现一个常规的训练循环,共训练 100epoch

def train(model, loader):criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.01)epochs = 100model.train()for epoch in range(epochs+1):total_loss = 0acc = 0val_loss = 0val_acc = 0# Train on batchesfor data in loader:optimizer.zero_grad()out = model(data.x, data.edge_index, data.batch)loss = criterion(out, data.y)total_loss += loss / len(loader)acc += accuracy(out.argmax(dim=1), data.y) / len(loader)loss.backward()optimizer.step()# Validationval_loss, val_acc = test(model, val_loader)

(5)20epoch 打印一次训练和验证准确率,并返回训练后的模型:

        # Print metrics every 20 epochsif(epoch % 20 == 0):print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc*100:.2f}%')return model

(6)test() 方法中也必须使用批处理,因为验证和测试加载器同样包含多个批数据:

@torch.no_grad()
def test(model, loader):criterion = torch.nn.CrossEntropyLoss()model.eval()loss = 0acc = 0for data in loader:out = model(data.x, data.edge_index, data.batch)loss += criterion(out, data.y) / len(loader)acc += accuracy(out.argmax(dim=1), data.y) / len(loader)return loss, acc

(7) 定义用于计算准确率的函数:

def accuracy(pred_y, y):"""Calculate accuracy."""return ((pred_y == y).sum() / len(y)).item()

(8) 实例化并训练 GCNGIN 模型:

print('GCN training')
gcn = GCN(dim_h=32)
gcn = train(gcn, train_loader)
print('GIN training')
gin = GIN(dim_h=32)
gin = train(gin, train_loader)

模型训练过程

(9) 使用测试加载器测试训练后的模型:

test_loss, test_acc = test(gcn, test_loader)
print(f'GCN test Loss: {test_loss:.2f} | GCN test Acc: {test_acc*100:.2f}%')test_loss, test_acc = test(gin, test_loader)
print(f'Gin test Loss: {test_loss:.2f} | Gin test Acc: {test_acc*100:.2f}%')# GCN test Loss: 0.56 | GCN test Acc: 65.70%
# GIN test Loss: 0.46 | GIN test Acc: 78.12%

2.4 GCN 与 GIN 性能差异分析

根据上一小节的结果可以看出,用简单的全局均值池( PyTorch Geometric 中使用 global_mean_pool() 实现)实现 GCN 执行图分类,在完全相同的设定下,进行 100 次实验的平均准确率为 72.72%(±0.73%)。这远低于 GIN 模型的平均准确率 77.57% (±1.77%)
据此,可以得出结论,GIN 架构比 GCN 更适合图分类任务。根据 WL 测试理论框架,这是因为 GCN不如 GIN 的表达能力强。换句话说,GINGCN 能区分更多的图结构,这也是它更准确的原因。可以通过可视化两种模型错误分类的图来验证这一假设。

(1) 导入 matplotlibnetworkx 库,用于绘制蛋白质结构:

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkxfig, ax = plt.subplots(4, 4)
fig.suptitle('GIN - Graph classification')

(2) 对于每个蛋白质,使用训练后的 GIN 获取最终分类结果。如果预测是正确的,将其绘制为绿色(否则绘制为红色):

for i, data in enumerate(dataset[-16:]):# Calculate color (green if correct, red otherwise)out = gin(data.x, data.edge_index, data.batch)color = "green" if out.argmax(dim=1) == data.y else "red"

(3) 为了方便起见,将蛋白质转换成 networkx 图,然后使用 nx.draw_networkx() 函数进行绘制:

    ix = np.unravel_index(i, ax.shape)ax[ix].axis('off')G = to_networkx(dataset[i], to_undirected=True)nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=10,node_color=color,width=0.8,ax=ax[ix])
plt.show()

GIN 模型分类结果的准确性如下所示:

GIN 模型性能

(4)GCN 模型重复以上过程:

fig, ax = plt.subplots(4, 4)
fig.suptitle('GCN - Graph classification')for i, data in enumerate(dataset[-16:]):# Calculate color (green if correct, red otherwise)out = gcn(data.x, data.edge_index, data.batch)color = "green" if out.argmax(dim=1) == data.y else "red"# Plot graphix = np.unravel_index(i, ax.shape)ax[ix].axis('off')G = to_networkx(dataset[i], to_undirected=True)nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=10,node_color=color,width=0.8,ax=ax[ix])
plt.show()

GCN分类结果

可以看到,GCN 模型将更多的图错误分类。要了解哪些图结构没有被充分捕捉,需要对 GIN 正确分类的每个蛋白质进行大量分析。但可以看到 GIN 也出现了不同的错误,这证明了这些模型可以互补。

3. 提升模型性能

在机器学习中,将出现不同错误的模型集成为一个更优秀的模型是一种常见技术。可以采用不同的方法,例如在最终分类的基础上训练第三个模型。为了简单起见,本节中我们将实现一种简单的模型平均技术:

(1) 首先,我们将模型设置为评估模式 (eval()),并定义变量用于存储准确率:

gcn.eval()
gin.eval()
acc_gcn = 0
acc_gin = 0
acc_ens = 0

(2) 得到每个模型的最终分类结果,然后将它们组合起来,作为集成模型的预测结果:

for data in test_loader:# Get classificationsout_gcn = gcn(data.x, data.edge_index, data.batch)out_gin = gin(data.x, data.edge_index, data.batch)out_ens = (out_gcn + out_gin)/2

(3) 计算三个模型预测的准确率:

    # Calculate accuracy scoresacc_gcn += accuracy(out_gcn.argmax(dim=1), data.y) / len(test_loader)acc_gin += accuracy(out_gin.argmax(dim=1), data.y) / len(test_loader)acc_ens += accuracy(out_ens.argmax(dim=1), data.y) / len(test_loader)

(4) 最后,打印模型的准确率:

# Print results
print(f'GCN accuracy:     {acc_gcn*100:.2f}%')
print(f'GIN accuracy:     {acc_gin*100:.2f}%')
print(f'GCN+GIN accuracy: {acc_ens*100:.2f}%')
'''
GCN accuracy:     73.70%
GIN accuracy:     78.91%
GCN+GIN accuracy: 79.43%
'''

在本节示例中,集成模型的预测结果优于其它两个模型,准确率为 79.43% (GCN73.70%GIN78.91%)。模型集成技术的准确率提升相当显著,为构建高性能模型提供了更多可能性。然而,这并不一定是普遍情况,即使在本例中,集成模型的表现也并不总是优于 GIN。可以用其他架构(如 Node2Vec )的嵌入来丰富集成模型,并观察是否能提高最终的准确率。

小结

图同构网络 (Graph Isomorphism Network, GIN) 架构受 WL 测试启发而设计的,其表达能力与 WL 测试相近,因此在严格意义上比 GCNGATGraphSAGE 更具表达能力。在本节中,将这一架构用于图分类任务,介绍了将节点嵌入融合到图嵌入中的不同方法,GIN 通过连接求和运算符和每个 GIN 层产生图嵌入,其性能明显优于通过 GCN 层获得的经典全局均值池化。最后,我们将两个模型的预测结果进行简单的集成,从而进一步提高了准确率。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词