欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > PyTorch(七)模型的保存与加载

PyTorch(七)模型的保存与加载

2025/2/24 13:01:09 来源:https://blog.csdn.net/qq_45031509/article/details/140134381  浏览:    关键词:PyTorch(七)模型的保存与加载

#d 两种保存方式比较

仅保存模型参数
优点:

  • 更加灵活,只保存模型的参数,不保存模型的结构,可以在不同的模型结构中加载参数(只要参数匹配)。
  • 文件大小通常比保存整个模型小。
  • 安全性更高,因为不直接执行pickle内容。

缺点:

  • 加载模型前需要先定义模型的结构,增加了代码量。

保存整个模型
优点:

  • 保存简单,一行代码完成。
  • 加载模型时不需要再定义模型的结构。

缺点:

  • 保存的模型依赖于具体的类定义,如果模型的结构有所改变(例如类名、层的结构等),加载时可能会出现问题。
  • 文件通常比仅保存状态字典的方式大。
  • 可能存在安全风险,因为torch.load会加载任何pickle内容。

总结:

仅保存模型的参数(状态字典)是更加推荐的方式,因为它更加灵活和安全。但是,如果你想要快速保存和加载整个模型,不担心模型结构变化或安全问题,保存整个模型也是一个可行的选择。

1 仅保存模型参数

#c 说明 保存加载方式

PyTorch保存模型的「学习参数」是通过state_dict的一个内部状态字典,使用torch.save来保存模型的学习参数。

#e 模型保存方式一

model = models.vgg16(weights='IMAGENET1K_V1')
'''
vgg16是一个非常流行的卷积神经网络,经过了大量的训练,可以识别1000个不同的对象。
weights='IMAGENET1K_V1'表示加载了在ImageNet数据集上预训练的权重。
'''
torch.save(model.state_dict(), 'model_weights.pth')#状态字典与保存路径

#e 模型加载方式一

加载模型权重,首先需要创建一个与「原始模型相同的模型实例」,然后使用load_state_dict方法加载参数。

注意:需要使用model.eval()方法将模型设置为评估模式,这将关闭Dropout和BatchNorm层。否则将会导致不一致的推理结果。

model = models.vgg16()#加载模型
model.load_state_dict(torch.load('model_weights.pth'))#加载模型权重
model.eval()#设置模型为评估模式

2 保存整个模型

#c 说明 保存整个模型

在加载模型权重时,需要首先实例化模型类,因为模型类定义了网络的结构。如果希望将模型类的架构与模型一起保存,那么可以传递模型本身(而不是模型的状态字典model.state_dict())给保存函数。

#e 模型保存方式二

torch.save(model, 'model.pth')#保存模型

#e 模型加载方式二

model = torch.load('model.pth')#加载模型

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词