欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 时评 > 第P3周:Pytorch实现天气识别

第P3周:Pytorch实现天气识别

2025/2/22 2:23:45 来源:https://blog.csdn.net/deflag/article/details/144540546  浏览:    关键词:第P3周:Pytorch实现天气识别
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

  1. 读取天气图片,按文件夹分类
  2. 搭建CNN网络,保存网络模型并加载模型
  3. 使用保存的模型预测真实天气

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch 2.5.1

(二)具体步骤
1. 通用文件Utils.py
import torch  # 第一步:设置GPU  
def USE_GPU():  if torch.cuda.is_available():  print('CUDA is available, will use GPU')  device = torch.device("cuda")  else:  print('CUDA is not available. Will use CPU')  device = torch.device("cpu")  return device
2. 模型代码
import os  from torchinfo import summary  from Utils import USE_GPU  
import pathlib  
from PIL import Image  
import matplotlib.pyplot as plt  
import numpy as np  
import torch  
import torch.nn as nn  
import torchvision.transforms as transforms  
import torchvision  
from torchvision import datasets  device = USE_GPU()  # 导入数据  
data_dir = './data/weather_photos/'  
data_dir = pathlib.Path(data_dir)  data_paths = list(data_dir.glob('*'))  
# print(data_paths)  
classNames = [str(path).split("\\")[2] for path in data_paths]  
print(classNames)  # 查看一下图片  
image_folder = './data/weather_photos/cloudy'  
# 获取image_folder下的所有图片  
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]  
#创建matplotlib图像  
fig, axes = plt.subplots(3, 8, figsize=(16, 6))  for ax, img_file in zip(axes.flat, image_files):  img_path = os.path.join(image_folder, img_file)  img = Image.open(img_path)  ax.imshow(img)  ax.axis('off')  plt.tight_layout()  
plt.title(image_folder, loc='center')  
# plt.show()  

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

train_transforms = transforms.Compose([  transforms.Resize([224, 224]),  # 将输入图片统一resize成224大小  transforms.RandomHorizontalFlip(),  transforms.RandomVerticalFlip(),  transforms.ToTensor(),  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
])  total_data = datasets.ImageFolder(data_dir, transform=train_transforms)  
print(total_data)  # 划分数据集  
train_size = int(0.8 * len(total_data))  
test_size = len(total_data) - train_size  
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])  
print(train_size, test_size)  
print(train_dataset, test_dataset)  # 设置dataloader  
batch_size = 32  
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  for X, y in test_dl:  print("Shape of X [N, C, H, W]: ", X.shape)  print("Shape of y: ", y.shape, y.dtype)  break  # 构建CNN网络  
import torch.nn.functional as F  class Network_bn(nn.Module):  def __init__(self):  super(Network_bn, self).__init__()  self.conv1 = nn.Conv2d(3, 12, 5, 1, 0)  self.bn1 = nn.BatchNorm2d(12)  self.conv2 = nn.Conv2d(12, 12, 5, 1, 0)  self.bn2 = nn.BatchNorm2d(12)  self.pool1 = nn.MaxPool2d(2, 2)  self.conv4 = nn.Conv2d(12, 24, 5, 1, 0)  self.bn4 = nn.BatchNorm2d(24)  self.conv5 = nn.Conv2d(24, 24, 5, 1, 0)  self.bn5 = nn.BatchNorm2d(24)  self.pool2 = nn.MaxPool2d(2, 2)  self.fc1 = nn.Linear(24 * 50 * 50, len(classNames))  def forward(self, x):  x = F.relu(self.bn1(self.conv1(x)))  x = F.relu(self.bn2(self.conv2(x)))  x = self.pool1(x)  x = F.relu(self.bn4(self.conv4(x)))  x = F.relu(self.bn5(self.conv5(x)))  x = self.pool2(x)  x = x.view(-1, 24 * 50 * 50)  x = self.fc1(x)  return x  model = Network_bn().to(device)  
print(model)  
summary(model)  # 训练模型  
loss_fn = nn.CrossEntropyLoss()  
learn_rate = 1e-4  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)  # 循环训练  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset)  num_batches = len(dataloader)  train_loss, train_acc = 0, 0  for X, y in dataloader:  X, y = X.to(device), y.to(device)  pred = model(X)  loss = loss_fn(pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return  train_acc,train_loss  def test(dataloader, model, loss_fn):  size = len(dataloader.dataset)  num_batches = len(dataloader)  test_loss, test_acc = 0, 0  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  epochs = 25  
train_loss = []  
train_acc = []  
test_loss = []  
test_acc = []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}'  print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))  
print('Done')  # 结果可视化  
import matplotlib.pyplot as plt  
#隐藏警告  
import warnings  
warnings.filterwarnings("ignore")               #忽略警告信息  
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签  
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号  
plt.rcParams['figure.dpi']         = 100        #分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  
plt.subplot(1, 2, 1)  plt.plot(epochs_range, train_acc, label='Training Accuracy')  
plt.plot(epochs_range, test_acc, label='Test Accuracy')  
plt.legend(loc='lower right')  
plt.title('Training and Validation Accuracy')  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, train_loss, label='Training Loss')  
plt.plot(epochs_range, test_loss, label='Test Loss')  
plt.legend(loc='upper right')  
plt.title('Training and Validation Loss')  
plt.show()  # 保存模型  
torch.save(model, "./models/cnn-weather.pth")
3. 预测真实图片:pred.py
from pydoc import classname  from PIL import Image  
from matplotlib import pyplot as plt  
from torch import nn  from Utils import USE_GPU  
import torch  
import  torchvision.transforms as transforms  
from torchvision import datasets  
import pathlib  device = USE_GPU()  # 构建CNN网络  
import torch.nn.functional as F  class Network_bn(nn.Module):  def __init__(self):  super(Network_bn, self).__init__()  self.conv1 = nn.Conv2d(3, 12, 5, 1, 0)  self.bn1 = nn.BatchNorm2d(12)  self.conv2 = nn.Conv2d(12, 12, 5, 1, 0)  self.bn2 = nn.BatchNorm2d(12)  self.pool1 = nn.MaxPool2d(2, 2)  self.conv4 = nn.Conv2d(12, 24, 5, 1, 0)  self.bn4 = nn.BatchNorm2d(24)  self.conv5 = nn.Conv2d(24, 24, 5, 1, 0)  self.bn5 = nn.BatchNorm2d(24)  self.pool2 = nn.MaxPool2d(2, 2)  self.fc1 = nn.Linear(24 * 50 * 50, 4)  def forward(self, x):  x = F.relu(self.bn1(self.conv1(x)))  x = F.relu(self.bn2(self.conv2(x)))  x = self.pool1(x)  x = F.relu(self.bn4(self.conv4(x)))  x = F.relu(self.bn5(self.conv5(x)))  x = self.pool2(x)  x = x.view(-1, 24 * 50 * 50)  x = self.fc1(x)  return x  model = torch.load('./models/cnn-weather.pth', weights_only=False)  
model.eval()  transform = transforms.Compose([  transforms.Resize([224, 224]),  # 将输入图片统一resize成224大小  transforms.RandomHorizontalFlip(),  transforms.RandomVerticalFlip(),  transforms.ToTensor(),  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
])  className = ['cloudy', 'rain', 'shine', 'sunshine']  # 导入数据  
weather_data_directory = './mydata/weather'  
weather_data_directory = pathlib.Path(weather_data_directory)  
print(weather_data_directory)  
image_count = len(list(weather_data_directory.glob('*.jpg')))  
print("待识别天气图片数量:", image_count)  plt.figure(figsize=(5, 3))  
i = 0  
for path in weather_data_directory.glob('*.jpg'):  print(path) # 天气图片路径  image_source = Image.open(path)    # 打开图片转换成图片数据  image = transform(image_source)  image = image.unsqueeze(0)  # 增加维度  print(image.shape)  output = model(image.to(device))  pred = className[torch.argmax(output, dim=1).item()]  print(pred)  plt.subplot(2, 5, i+1)  plt.imshow(image_source)  plt.title(pred)  plt.xticks([])  plt.yticks([])  i += 1  
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
准确率80%.

(三)总结

下载一个大数据集训练一下,数据如下:

  • 晴天:10000张
  • 多云:10000张
  • 雨天:10000张
  • 大雪:10000张
  • 薄雾:10000张
  • 雷雨:10000张
    经历漫长的几个小时训练,结果:
    外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传
    外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词