欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 高考 > 基于Pytorch深度学习图像处理基础流程框架(以ResNetGenerator为例)

基于Pytorch深度学习图像处理基础流程框架(以ResNetGenerator为例)

2024/10/25 2:30:55 来源:https://blog.csdn.net/qq_52964132/article/details/141101897  浏览:    关键词:基于Pytorch深度学习图像处理基础流程框架(以ResNetGenerator为例)

文章目录

  • - 模型搭建
    • 1. 搭建ResNetGenerator
    • 2. 网络实例化
    • 3.加载预训练模型权重文件
    • 4. 神经网络设置为评估模式
  • 预测处理
    • 1. 定义图片的预处理方法
    • 2. 导入图片
    • 3. 预处理图片
    • 4. 调用模型
    • 5. 输出结果


- 模型搭建

1. 搭建ResNetGenerator

import torch
import torch.nn as nnclass ResNetBlock(nn.Module): # <1>def __init__(self, dim):super(ResNetBlock, self).__init__()self.conv_block = self.build_conv_block(dim)def build_conv_block(self, dim):conv_block = []conv_block += [nn.ReflectionPad2d(1)]conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),nn.InstanceNorm2d(dim),nn.ReLU(True)]conv_block += [nn.ReflectionPad2d(1)]conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),nn.InstanceNorm2d(dim)]return nn.Sequential(*conv_block)def forward(self, x):out = x + self.conv_block(x) # <2>return outclass ResNetGenerator(nn.Module):def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> assert(n_blocks >= 0)super(ResNetGenerator, self).__init__()self.input_nc = input_ncself.output_nc = output_ncself.ngf = ngfmodel = [nn.ReflectionPad2d(3),nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),nn.InstanceNorm2d(ngf),nn.ReLU(True)]n_downsampling = 2for i in range(n_downsampling):mult = 2**imodel += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,stride=2, padding=1, bias=True),nn.InstanceNorm2d(ngf * mult * 2),nn.ReLU(True)]mult = 2**n_downsamplingfor i in range(n_blocks):model += [ResNetBlock(ngf * mult)]for i in range(n_downsampling):mult = 2**(n_downsampling - i)model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=2,padding=1, output_padding=1,bias=True),nn.InstanceNorm2d(int(ngf * mult / 2)),nn.ReLU(True)]model += [nn.ReflectionPad2d(3)]model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]model += [nn.Tanh()]self.model = nn.Sequential(*model)def forward(self, input): # <3>return self.model(input)

2. 网络实例化

netG = ResNetGenerator()

3.加载预训练模型权重文件

model_path = '../data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

在这里插入图片描述


4. 神经网络设置为评估模式

netG.eval()

netG.eval() 是 PyTorch 中的一个方法,用于将神经网络模型设置为评估(evaluation)模式。

  1. 关闭 Dropout 和 Batch Normalization

    • 在训练过程中,Dropout 层会随机丢弃一些神经元,以防止过拟合。Batch Normalization 层会根据每个批次的数据计算均值和方差,以稳定训练过程。
    • 在评估模式下,Dropout 层会关闭,所有神经元都会参与计算。Batch Normalization 层会使用训练过程中计算的均值和方差,而不是当前批次的数据。
  2. 确保一致性

    • 在评估模式下,模型的行为会更加一致和可预测,因为不会受到随机丢弃神经元或批次数据统计特性的影响。
  3. 推理和测试

    • 在进行模型推理或测试时,应该始终将模型设置为评估模式,以确保得到准确和稳定的结果。

预测处理

1. 定义图片的预处理方法

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize((262, 461)),  # 调整图像大小transforms.ToTensor(),          # 转换为张量transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化
])

2. 导入图片

img = Image.open("../data/p1ch2/horse.jpg")

在这里插入图片描述

3. 预处理图片

# 确保图像有3个通道
if img.mode != 'RGB':img = img.convert('RGB')img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)

4. 调用模型

out_t = (batch_out.data.squeeze() + 1.0) /2

5. 输出结果


out_t = (batch_out.data.squeeze() + 1.0) /2
out_img = transforms.ToPILImage()(out_t)
# out_img.save('../data/p1ch2/zebra.jpg')
out_img

在这里插入图片描述

【注*:该模型的作用是将图片中的马,生成为斑马】


(完)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com