欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 文化 > 【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式

【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式

2025/3/16 23:31:38 来源:https://blog.csdn.net/sxf1061700625/article/details/145310388  浏览:    关键词:【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

只讲几个注意事项:

1、graph.formats() 函数可以查看graph格式,也可以指定graph格式。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)# 查看格式
g.formats()
# => {'created': ['coo'], 'not created': ['csr', 'csc']}# 指定一种格式
csr_g = g.formats('csr')
csr_g.formats()
# => {'created': ['csr'], 'not created': []}# 指定多种格式
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo', 'csr'], 'not created': []}

2、在调用 formats(['coo', 'csr']) 时,如果当前图的格式与指定格式没有交集,DGL 会按照 coo -> csr -> csc 的顺序选择一种格式创建。因此,如果图在反序列化后没有 CSR 格式,调用 formats(['coo', 'csr']) 可能只会创建 COO 格式。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)# 假设只有一种格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': []}

3、上述第2点,虽然没有指定格式,但是可以通过graph.create_formats_来显式创建。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)# 假设只有一种coo格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': ['csr']}# 显式创建格式
new_g.create_formats_()
print(new_g.formats())
# => {'created': ['coo', 'csr'], 'not created': []}

4、使用 pickle 对 DGL 图对象进行序列化和反序列化后,图的存储格式可能会丢失或被重置为 COO 格式。

import dgl
import pickle# 创建一个图并设置多种格式
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g = g.formats(['coo', 'csr', 'csc'])# 使用 pickle 保存
with open('graph.pkl', 'wb') as f:pickle.dump(g, f)# 使用 pickle 加载
with open('graph.pkl', 'rb') as f:loaded_g = pickle.load(f)# 检查加载后的格式
print(loaded_g.formats())  # 可能会丢失某些格式

5、可以考虑使用 DGL 提供的保存dgl.save_graphs和加载dgl.load_graphs方法,这些方法能够更好地处理图的内部状态,包括稀疏格式。

# 保存图
dgl.save_graphs("graph.bin", [graph])# 加载图
loaded_graphs, _ = dgl.load_graphs("graph.bin")
graph = loaded_graphs[0]

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词