欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 养生 > 【深度学习】PyTorch :调用残差网络(ResNet)

【深度学习】PyTorch :调用残差网络(ResNet)

2025/2/23 0:57:03 来源:https://blog.csdn.net/2303_80346267/article/details/145125574  浏览:    关键词:【深度学习】PyTorch :调用残差网络(ResNet)

ResNet (Residual Network) 是由 Microsoft Research 的 Kaiming He 等人在 2015 年提出的一种深度学习模型结构。它解决了随着网络深度增加而导致的梯度消失和退化问题。传统的深层网络可能由于信息难以有效传递,导致模型性能下降,而 ResNet 通过引入残差连接(skip connections),使信息可以跨层直接传递,从而缓解了这一问题。

基本原理

ResNet 的核心思想是学习残差函数而不是直接学习期望的映射函数。具体来说,假设希望学习的目标映射为 H(x) ,ResNet 让每个模块学习一个残差函数 F(x)=H(x)−x ,这样原始映射变成 H(x)=F(x)+x 。这种设计使得梯度更容易反向传播,有助于训练更深层的网络。

常见的 ResNet 结构包括 ResNet-18、ResNet-34、ResNet-50、ResNet-101 等,它们通过不同的层数适应从简单到复杂的任务需求。

导入必要的包

确保安装 PyTorch 和 torchvision:

pip install torch torchvision

在代码中导入相关模块:

import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

实例化预训练 ResNet 模型

通过 torchvision.models 获取预训练的 ResNet 模型:

# 实例化 ResNet-50 模型,并使用预训练权重
model = models.resnet50(pretrained=True)# 切换模型到计算设备(GPU 或 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

修改输出层适应新任务

如果要将 ResNet 应用于自定义分类任务,需要修改其最后的全连接层:

# 假设新任务有 10 个类别
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

数据预处理与加载

使用 torchvision.transforms 对图像数据进行预处理:

# 定义数据变换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

模型训练

定义损失函数和优化器,并进行模型训练:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print('训练完成!')

测试模型性能

在测试集上评估模型的分类准确率:

# 加载测试数据
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)# 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f'测试准确率: {accuracy:.2f}%')

总结

通过上述步骤,您可以在 PyTorch 中快速使用预训练的 ResNet 模型,并根据不同任务需求进行定制和优化。ResNet 强大的残差学习能力使其成为许多计算机视觉任务的首选模型。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词