欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 国际 > 【模型学习之路】TopK池化,全局池化

【模型学习之路】TopK池化,全局池化

2025/2/6 23:35:27 来源:https://blog.csdn.net/wwl412095144/article/details/144066527  浏览:    关键词:【模型学习之路】TopK池化,全局池化

来学学图卷积中的池化操作

目录

DataBatch

Dense Batching

Dynamic Batching

DataBatch

存取操作

TopKPooling

GAP/GMP

一个例子

后话


DataBatch

当进行图级别的任务时,首先的任务是把多个图合成一个batch。

在Transformer中,一个句子的维度是【单词数,词向量长度】。在一个batch内,batch_size个长度相同的句子(长度短了就做padding)的维度是【句子数,单词数,词向量长度】。

这里,在图任务中得到batch有两种策略。

Dense Batching

一个batch有batch_size个图,第i个图的x的特征维度为m_{i}f,那么先:

m = max(m_{1}, m_{2}, ..., m_{batchsize})

把所有的图做padding,然后合到一起,那么最后数据的维度就是【batch_size, m, f】。

这种方式通常用于需要固定大小输入的场景,例如某些图神经网络的实现或者特定的并行计算框架。

Dynamic Batching

这是PyG默认的批处理方式,它不要求所有图具有相同数量的节点。在这种模式下,每个图的节点特征被拼接在一起,形成一个大的特征矩阵【M,f】,其中:

M = \sum_{i=1}^{batchsize}m_{i}

同时,会有一个batch向量,它是一个长度为M的一维Tensor,记录每个节点属于哪个图。

DataBatch

前面提到过,Data对象是PyG数据的基本单元。我们先生成一个一个Data对象的list:

import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoaderdata_lst = [Data(x=torch.randint(0, 2, (5, 3)), edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),y=torch.randint(0, 1, (5,)))for _ in range(1000)]

重写Dataset,然后将list[Data]转化为Dataset:

class MyDataset(Dataset):def __init__(self, data_lst):super(MyDataset, self).__init__()self.data_lst = data_lstdef __len__(self):return len(self.data_lst)def __getitem__(self, idx):return self.data_lst[idx]dataset = MyDataset(data_lst)
dataset# output
MyDataset(1000)

进一步做成Dataloader:

dataloader = DataLoader(dataset, batch_size=32, follow_batch=['x'], shuffle=True)
first_batch = list(dataloader)[0]
first_batch# output
DataBatch(x=[160, 3], x_batch=[160], x_ptr=[33], edge_index=[2, 128], y=[160], batch=[160], ptr=[33])

x,y,edge_index都是由多个图拼接而成。x_batch就是用来记录每个节点属于哪个图。ptr用于记录每个图的位置信息(不用过多关注),大小正好是batch_size + 1,记录每个图的终点和起点。

不指定follow_batch=['x'],就没有了ptr,模型就会认为这是一个由很多图拼起来的一个大图,而不是视为很多图。这里不必深究,指定一下follow_batch就好了。

存取操作

可以继承重写PyG中一些与数据相关的类,做到存取的效果,不过有些难度可以看看这个:【图神经网络工具】PyTorch Geometric Tutorial 之Data Handling - 知乎

也可以看看这个的15~19集:5-数据集创建函数介绍_哔哩哔哩_bilibili

我们实现一个简单的存取方法:

from torch_geometric.data import Batch
batch = Batch.from_data_list(data_lst)
batch# output
DataBatch(x=[5000, 3], edge_index=[2, 4000], y=[5000], batch=[5000], ptr=[1001])

可以看到,和我们Dataloader取出来的东西一样,都是DataBatch对象。然后我们把它存起来:

torch.save(batch, 'batch.pt')loaded_batch = torch.load('batch.pt', map_location='cpu', weights_only=False)
data_lst = loaded_batch.to_data_list()

TopKPooling

先端上官方文档:

torch_geometric.nn.pool.TopKPooling — pytorch_geometric documentation

再端上一张网上随便一找就能看到的图:

p是要学习的参数。y的维度是(M, 1),计算出每一个点的“重要性”。除以二范数是为了标准化。

然后选取M个点中k个最重要的

根据这个topk,在X以及A中挑出对应的k个,得到,相应的邻接矩阵也只保留剩下的边之间的关系。

最后,由于y’本身记录了“重要性”的信息,那就把重要性加权到X中:

  

仅发表一下个人意见,出于归一化的想法,感觉用softmax挺合适:

 

好,搞定。

一个小问题,在做这个pool操作时,会不会导致某一个图的所有节点全部消失?

并不会,因为TopK是独立地在每个图中做topk操作。

GAP/GMP

global_mean_pool(GAP)和global_max_pool(GMP)是两种常用的全局池化(global pooling)操作,它们用于将整个图的信息聚合为一个固定大小的向量。

全局平均池化(GAP)操作将图中所有节点的特征向量求平均。简单说来就是,每一个图表示为自己所有节点求平均得到的向量。

全局最大池化(GMP)操作将图中所有节点的特征向量进行逐元素的最大值操作。简单来说就是,对于每一个图,拿出自己所有的节点,拿到每个特征的最大值,组成一个向量。

So,在维度上,都会有这样的特征:【M, f】-> 【batch_size,f】

这俩是两种常用的全局池化操作,它们用于将图中所有节点的特征聚合为一个全局特征向量。这两种操作通常在图神经网络的最后阶段使用,以便将图级别的表示用于图分类或其他下游任务。

一个例子

用PyG写个一个神经网络模型。

import torch
import torch.nn as nn
from torch_geometric.nn import TopKPooling, SAGEConv
from torch_geometric.nn import global_mean_pool as gap
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()torch.manual_seed(114514)self.conv1 = SAGEConv(128, 128)self.pool1 = TopKPooling(128, ratio=0.8)self.conv2 = SAGEConv(128, 128)self.pool2 = TopKPooling(128, ratio=0.8)self.conv3 = SAGEConv(128, 128)self.pool3 = TopKPooling(128, ratio=0.8)self.embed = nn.Embedding(100, 128)self.lin = nn.Sequential(nn.Linear(128, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 64),nn.ReLU(),nn.Dropout(0.5),nn.Linear(64, 1), )self.bn = nn.BatchNorm1d(128)self.bn2 = nn.BatchNorm1d(64)def forward(self, data):"""x: [M, 1]edge_index: [2, e]batch: [M]"""x, edge_index, batch = data.x, data.edge_index, data.batchx = x.squeeze(1)  # [M, 1] -> [M]  # 这里是大坑!在github评论区逛了一圈,还好一个老外和我一样的错误x = self.embed(x)  # [M] -> [M, 128]  x = self.conv1(x, edge_index)  # [M, 128]x = F.relu(x)x, edge_index, _, batch, *_ = self.pool1(x, edge_index, None, batch)  # [0.8*M, 128]x1 = gap(x, batch)  # [batch, 128]x = self.conv2(x, edge_index)  # [0.8*M, 128]x = F.relu(x)x, edge_index, _, batch, *_ = self.pool2(x, edge_index, None, batch)  # [0.8*0.8*M, 128]x2 = gap(x, batch)  # [batch, 128]x = self.conv3(x, edge_index)  # [0.8*0.8*M, 128]x = F.relu(x)x, edge_index, _, batch, *_ = self.pool3(x, edge_index, None, batch)  # [0.8*0.8*0.8*M, 128]x3 = gap(x, batch)  # [batch, 128]out = x1 + x2 + x3  # [batch, 128]out = self.lin(out)  # [batch, 1]out = out.squeeze(1)  # [batch]out = F.sigmoid(out)return out

这个网络架构的设计意图是利用图卷积层提取局部图结构特征,通过池化层进行降采样以捕捉更全局的信息,然后通过全连接层和激活函数进行特征融合和分类。这种架构在图分类、节点分类等任务中很常见。

后话

代码中的SAGEConv是什么?它是众多卷积方式的一种。

PyG文档上有大量卷积层、池化层的类。确实,路漫漫其修远兮!

这个文章上有很多的卷积层和池化层的讲解,看看能不能在未来的时间里都弄懂它们的原理:转载 | 一文遍览GNN卷积与池化的代表模型 - 知乎

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com