欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 维修 > 利用pytorch对加噪堆叠自编码器在MNIST数据集进行训练和验证

利用pytorch对加噪堆叠自编码器在MNIST数据集进行训练和验证

2025/4/17 13:55:59 来源:https://blog.csdn.net/qq_60985893/article/details/147038439  浏览:    关键词:利用pytorch对加噪堆叠自编码器在MNIST数据集进行训练和验证

实现背景:

最近在复现关于使用深度学习提取特征来进行聚类的论文,其中使用到了加噪堆叠自编码器,具体实现细节请参考论文:Improved Deep Embedded Clustering with Local Structure Preservation

其中加噪堆叠自编码器涉及两个过程:

预训练:预训练过程对原始数据加噪,贪婪式地对每一层encoder和decoder进行训练,其中训练新的AE时冻结前面训练好的AE。详见:堆栈自编码器 Stacked AutoEncoder-CSDN博客

微调:在预训练完成之后使用所有AE和原始数据对整体模型进行微调。

我在网上找到了一个SAE的示范样例:python-pytorch 利用pytorch对堆叠自编码器进行训练和验证_pytoch把训练和验证写一起的代码-CSDN博客

但是这篇博客的数据集很小,如果应用到MNIST数据集时显存很容易溢出,因此我在原始的基础上进行了改进,直接上代码:

初始化数据集:

import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, random_split
# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 下载并加载 MNIST 训练数据集
original_dataset = datasets.MNIST(root='./data', train=True,download=False, transform=transform)class NoLabelDataset(Dataset):def __init__(self, original_dataset):self.original_dataset = original_datasetdef __getitem__(self, index):image, _ = self.original_dataset[index]return imagedef __len__(self):return len(self.original_dataset)# 创建不包含标签的数据集
no_label_dataset = NoLabelDataset(original_dataset)# 划分训练集和验证集
train_size = int(0.8 * len(no_label_dataset))
val_size = len(no_label_dataset) - train_size
train_dataset, val_dataset = random_split(no_label_dataset, [train_size, val_size])# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)print(f"训练集样本数量: {len(train_dataset)}")
print(f"验证集样本数量: {len(val_dataset)}")    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 定义模型和训练函数:

import torch.nn as nnclass Autoencoder(nn.Module):def __init__(self, input_size, hidden_size):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(input_size, hidden_size, kernel_size=3, stride=1, padding=1),  # 输入通道1,输出通道16nn.ReLU())self.decoder = nn.Sequential(nn.ConvTranspose2d(hidden_size, input_size, kernel_size=3, stride=1, padding=1),nn.ReLU())def forward(self, x):x = self.encoder(x)x = self.decoder(x)return xdef train_ae(models, train_loader, val_loader, num_epochs, criterion, optimizer, noise_factor, finetune):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")for epoch in range(num_epochs):# Trainingmodels[-1].train()train_loss = 0for batch_data in train_loader:optimizer.zero_grad()if len(models) != 1:batch_data = batch_data.to(device)for model in models[:-1]:with torch.no_grad():batch_data = model.encoder(batch_data)batch_data = batch_data.detach()if finetune == True:batch_data = batch_data.to(device)outputs = models[-1](batch_data)loss = criterion(outputs, batch_data)else:noisy_image = batch_data + noise_factor * torch.randn_like(batch_data)noisy_image = torch.clamp(noisy_image, 0., 1.).to(device)outputs = models[-1](noisy_image)batch_data = batch_data.to(device)loss = criterion(outputs, batch_data)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")# Validationmodels[-1].eval()val_loss = 0with torch.no_grad():for batch_data in val_loader:if len(models) != 1:batch_data = batch_data.to(device)for model in models[:-1]:batch_data = model.encoder(batch_data)batch_data = batch_data.detach()if finetune == True:batch_data = batch_data.to(device)outputs = models[-1](batch_data)loss = criterion(outputs, batch_data)else:noisy_image = batch_data + noise_factor * torch.randn_like(batch_data)noisy_image = torch.clamp(noisy_image, 0., 1.).to(device)outputs = models[-1](noisy_image)batch_data = batch_data.to(device)loss = criterion(outputs, batch_data)val_loss += loss.item()val_loss /= len(val_loader)print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")

 模型训练以及微调:

batch_size = 16
noise_factor = 0.4ae1 = Autoencoder(input_size=1, hidden_size=16).to(device)
optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
criterion = nn.MSELoss()
train_ae([ae1], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)ae2 = Autoencoder(input_size=16, hidden_size=64).to(device)
optimizer = torch.optim.Adam(ae2.parameters(), lr=0.001)
train_ae([ae1, ae2], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)ae3 = Autoencoder(input_size=64, hidden_size=128).to(device)
optimizer = torch.optim.Adam(ae3.parameters(), lr=0.001)
train_ae([ae1, ae2, ae3], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = False)class StackedAutoencoder(nn.Module):def __init__(self, ae1, ae2, ae3):super(StackedAutoencoder, self).__init__()self.encoder = nn.Sequential(ae1.encoder, ae2.encoder, ae3.encoder)self.decoder = nn.Sequential(ae3.decoder, ae2.decoder, ae1.decoder)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return xsae = StackedAutoencoder(ae1, ae2, ae3)optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
criterion = nn.MSELoss()
train_ae([sae], train_loader, val_loader, 10, criterion, optimizer, noise_factor, finetune = True)

 结果可视化:

import matplotlib.pyplot as plt
import numpy as np
dataiter = iter(val_loader)
image = next(dataiter)[1]
print(image.shape)
image = image.to(device)# 通过自编码器模型进行前向传播
with torch.no_grad():output = sae(image)noise_factor = 0.4
noisy_image = image + noise_factor * torch.randn_like(image)
noisy_image = torch.clamp(noisy_image, 0., 1.).cpu().numpy()# 将张量转换为 numpy 数组以便可视化
image = image.cpu().numpy()
output = output.cpu().numpy()# 定义一个函数来显示图片
def imshow(img):img = img * 0.3081 + 0.1307  # 反归一化npimg = img.squeeze()  # 去除单维度plt.imshow(npimg, cmap='gray')# 可视化输入和输出图片
plt.figure(figsize=(10, 5))# 显示输入图像
plt.subplot(1, 3, 1)
imshow(torch.from_numpy(image))
plt.title('Input Image')
plt.axis('off')# 显示加噪图像
plt.subplot(1, 3, 2)
imshow(torch.from_numpy(noisy_image))
plt.title('Noisy Image')
plt.axis('off')# 显示输出图像
plt.subplot(1, 3, 3)
imshow(torch.from_numpy(output))
plt.title('Output Image')
plt.axis('off')plt.savefig("预训练图.svg", dpi=300,format="svg")

 下面附上训练好的可视化图:

    版权声明:

    本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

    我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

    热搜词