欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 游戏 > RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

2024/10/24 3:18:19 来源:https://blog.csdn.net/sinat_41942180/article/details/143018935  浏览:    关键词:RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

任务类型:回归
用途:在 `RootNeighboursDataset` 中,任务是给定一棵有根树,预测根节点的邻居中度数为 6 的那些节点的特征平均值。因此,模型需要基于根节点周围的结构,找出度为 6 的邻居,并计算它们特征的平均值。这属于回归问题,因为目标是预测连续值(特征的平均值)。

from helpers.dataset_classes.root_neighbours_dataset import RootNeighboursDataset

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensor


class RootNeighboursDataset(object):
    """Synthetic node-regression dataset built from random depth-2 rooted trees.

    Each tree has a root (node 0), a set of 1-hop neighbours, and leaf
    children hanging off every 1-hop neighbour. The first ``num_hubs``
    1-hop neighbours ("hubs") each get exactly ``HUB_NEIGHBORS`` children,
    so together with their edge to the root they have degree
    ``HUB_NEIGHBORS + 1`` (6 with the default constants). The regression
    target stored on the root is the mean feature vector of those hubs;
    every other node's target row is zero.

    A *component* is three such trees (train / val / test) merged into a
    single graph, with one boolean mask per split that selects only the
    root of the corresponding tree. The full dataset is
    ``NUM_COMPONENTS`` components merged into one ``Batch``.
    """

    def __init__(self, seed: int, print_flag: bool = False):
        """Build the whole dataset eagerly.

        Args:
            seed: Seed for the private ``torch.Generator`` that drives all
                random draws, so the dataset is reproducible per seed.
            print_flag: Stored on the instance as ``plot_flag`` (the
                attribute name is kept for backward compatibility); it is
                not read anywhere inside this class.
        """
        super().__init__()
        self.seed = seed
        self.plot_flag = print_flag
        self.generator = torch.Generator().manual_seed(seed)
        self.constants_dict = self.initialize_constants()
        self._data = self.create_data()

    def get(self) -> Data:
        """Return the dataset that was generated in ``__init__``."""
        return self._data

    def create_data(self) -> Data:
        """Generate ``NUM_COMPONENTS`` components and merge them into one batch."""
        # Each component already bundles its own train, val and test trees.
        components = [
            self.generate_component()
            for _ in range(self.constants_dict['NUM_COMPONENTS'])
        ]
        return Batch.from_data_list(components)

    def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:
        """Build split masks that select only the root of each fold's tree.

        Args:
            num_nodes_per_fold: Node counts of the train, val and test trees,
                in that order.

        Returns:
            ``(train_mask, val_mask, test_mask)``: boolean tensors whose
            length is the total node count, each with exactly one ``True``
            entry.
        """
        num_nodes = sum(num_nodes_per_fold)
        train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)

        # The trees are concatenated in fold order and each tree's root is
        # its first node, so a root sits at the cumulative node offset.
        val_root = num_nodes_per_fold[0]
        test_root = val_root + num_nodes_per_fold[1]
        train_mask[0] = True
        val_mask[val_root] = True
        test_mask[test_root] = True
        return train_mask, val_mask, test_mask

    def generate_component(self) -> Data:
        """Generate one train/val/test triple of trees as a single graph.

        Fold 0 (train) is generated with ``eval=False``; folds 1 and 2
        (val, test) are generated with ``eval=True``.
        """
        data_per_fold, num_nodes_per_fold = [], []
        for fold_idx in range(3):
            data = self.generate_fold(eval=(fold_idx != 0))
            num_nodes_per_fold.append(data.x.shape[0])
            data_per_fold.append(data)

        train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)
        batch = Batch.from_data_list(data_per_fold)
        return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y,
                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    def initialize_constants(self) -> Dict[str, int]:
        """Return the fixed size parameters of the generator."""
        return {
            'NUM_COMPONENTS': 1000,
            'MAX_HUBS': 3,
            'MAX_1HOP_NEIGHBORS': 10,
            'ADD_HUBS': 2,
            'HUB_NEIGHBORS': 5,
            'MAX_2HOP_NEIGHBORS': 3,
            'NUM_FEATURES': 5,
        }

    def generate_fold(self, eval: bool) -> Data:
        """Generate a single random depth-2 tree with features and labels.

        Args:
            eval: When true, ``ADD_HUBS`` extra hubs are added on top of the
                randomly drawn hub count and the minimum number of 1-hop
                neighbours is raised by the same amount. (The name shadows
                the builtin but is part of the keyword interface.)

        Returns:
            A ``Data`` object with node features ``x``, an undirected
            ``edge_index`` (both directions stored) and targets ``y`` that
            are zero everywhere except row 0, which holds the mean feature
            vector of the hub nodes.
        """
        # Read the constants stored at construction time so this method and
        # create_data() share a single source of truth.
        constants = self.constants_dict
        max_hubs = constants['MAX_HUBS']
        max_1hop_neighbors = constants['MAX_1HOP_NEIGHBORS']
        extra_hubs = constants['ADD_HUBS']
        hub_neighbors = constants['HUB_NEIGHBORS']
        max_2hop_neighbors = constants['MAX_2HOP_NEIGHBORS']
        num_features = constants['NUM_FEATURES']
        assert max_hubs + extra_hubs <= max_1hop_neighbors, \
            'MAX_HUBS + ADD_HUBS must not exceed MAX_1HOP_NEIGHBORS'

        # NOTE: the order of the random draws below determines the generated
        # data for a given seed; do not reorder them.
        add_hubs = extra_hubs if eval else 0
        num_hubs = torch.randint(1, max_hubs + 1, size=(1,),
                                 generator=self.generator).item() + add_hubs
        num_1hop_neighbors = torch.randint(max_hubs + add_hubs, max_1hop_neighbors + 1, size=(1,),
                                           generator=self.generator).item()
        assert num_hubs <= num_1hop_neighbors, 'more hubs than 1-hop neighbours'

        # torch.randint's upper bound is exclusive, so a non-hub neighbour
        # gets between 1 and MAX_2HOP_NEIGHBORS - 1 children. With the
        # default constants that keeps its degree well below a hub's.
        non_hub_children = torch.randint(1, max_2hop_neighbors,
                                         size=(num_1hop_neighbors - num_hubs,),
                                         generator=self.generator).tolist()
        # Hubs come first, so they are the first `num_hubs` 1-hop neighbours.
        children_per_neighbor = [hub_neighbors] * num_hubs + non_hub_children

        # Edges between every 1-hop neighbour and its own children.
        num_nodes = 1  # root node is 0
        idx_1hop_neighbors = []
        list_edge_index = []
        for num_children in children_per_neighbor:
            idx_1hop_neighbors.append(num_nodes)
            if num_children > 0:
                # Star centred on local index 0 with leaves 1..num_children,
                # shifted so the centre lands on this neighbour's global id.
                star_edge_index = torch.tensor([[0] * num_children,
                                                list(range(1, num_children + 1))])
                list_edge_index.append(star_edge_index + num_nodes)
            num_nodes += num_children + 1

        # Edges between the root and every 1-hop neighbour.
        idx_0hop = torch.tensor([0] * num_1hop_neighbors)
        idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)
        hubs = idx_1hop_neighbors[:num_hubs]
        list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))
        edge_index = torch.cat(list_edge_index, dim=1)

        # Make the graph undirected by also storing every reversed edge.
        edge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)
        edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)

        # Node features drawn uniformly from [-2, 2).
        x = 4 * torch.rand(size=(num_nodes, num_features), generator=self.generator) - 2

        # Only the root carries a non-zero target: the mean hub feature vector.
        y = torch.zeros_like(x)
        y[0] = torch.mean(x[hubs], dim=0)
        return Data(x=x, edge_index=edge_index, y=y)


if __name__ == '__main__':
    data = RootNeighboursDataset(seed=0, print_flag=True)

这个 RootNeighboursDataset 类通过随机生成的树状图数据来模拟一种节点关系,并基于图结构生成特征和标签。代码使用了 PyTorch 和 PyTorch Geometric 的功能来处理图数据。下面逐块详细解释该代码实现:

1. RootNeighboursDataset 类构造器

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensorclass RootNeighboursDataset(object):def __init__(self, seed: int, print_flag: bool = False):super().__init__()self.seed &#

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com