欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 创投人物 > PyTorch快速入门教程【小土堆】之网络模型的保存和读取

PyTorch快速入门教程【小土堆】之网络模型的保存和读取

2025/2/23 10:27:11 来源:https://blog.csdn.net/EnochChen_/article/details/144849747  浏览:    关键词:PyTorch快速入门教程【小土堆】之网络模型的保存和读取

视频地址网络模型的保存与读取_哔哩哔哩_bilibili

模型的保存

import torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")# #保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")# 陷阱
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()
torch.save(tudui, "tudui_method1.pthl")

模型的读取

import torch
import torchvision
from torch import nn# 方式1-》保存方式1,加载模型
model = torch.load("vgg16_method1.pth")
print(model)# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
print(vgg16)# 陷阱1
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return x# 必须写出模型才能读取,但不需要实现这个模型
model = torch.load('tudui_method1.pth')
print(model)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词