欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 金融 > 完整的模型验证套路 pytorch

完整的模型验证套路 pytorch

2024/10/23 13:28:29 来源:https://blog.csdn.net/2302_79795489/article/details/143076069  浏览:    关键词:完整的模型验证套路 pytorch

利用已经训练好的模型,给它提供输入

**代码:

一、准备验证数据:

例如:准备一张图片

1、复制图片,粘贴到项目目录的test_img文件夹下

2、通过path来打开图片,生成图像对象

3、各种转换

convert("RGB"):仅保留RGB三通道

Resize:调像素大小(W、H)

ToTensor:变成张量    reshape:调整张量的形状(变成四维)

img_path="./test_img/img.png"
image=Image.open(img_path)image=image.convert("RGB") #png格式是四通道(RGB+透明度),此操作可只保留RGB三通道,防止可能出现的错误trans=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)), #先要Resize,才能使用reshapetorchvision.transforms.ToTensor()
])
image=trans(image)image=torch.reshape(image,(1,3,32,32)) #要添加batch_size
print(image.shape) #torch.Size([3, 32, 32])

 

二、引入网络模型:

torch.save(xigua1, f"better_model_gpu{i + 1}.pth")

如果以这种方式保存网络模型,则这样加载:

from model import *
xigua2=torch.load("better_model_gpu20.pth",map_location=torch.device('cpu'))
# print(xigua2)

注意:要引入定义模型类的文件(import时不写.py,只写文件名就可以了)

可以在colab网站上使用GPU训练后,把模型文件下载下来

(从gpu到cpu,要写map_location=torch.device("gpu"))

 

三、验证:

记得写上

.eval()

with torch.no_grad():

xigua2.eval()
with torch.no_grad(): #不计算梯度,节约内存、使性能更好output=xigua2(image)
print(output) #tensor([[ -2.9265, -10.1790,   0.7105,   5.0713,   2.9425,   7.8542,   3.6901, 2.0290,  -4.0338,  -5.6903]])
print(output.argmax(1)) #tensor([5])test_set=torchvision.datasets.CIFAR10(root="../dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
print(test_set.classes) #['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com