欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > CycleGAN深度学习项目

CycleGAN深度学习项目

2024/11/30 9:53:16 来源:https://blog.csdn.net/wn030416/article/details/140403675  浏览:    关键词:CycleGAN深度学习项目

远程仓库

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

运行准备

Anaconda

安装需要的库

指令

pip install pandas -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install torch==1.11.0 -i Simple Index

pip install torchvision==0.12.0 -i Simple Index

pip install dominate==2.4.0 -i Simple Index

pip install visdom==0.1.8.8 -i Simple Index

pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

运行结果

数据集

我当前使用的数据集

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

因为数据集太大,训练时间慢所以删掉了很多图片

A-副本和B-副本里面是原始的数据集

A B是我自己删了图片的数据集

如果使用其他数据集也可以训练,例如:从网上随便下载图片

运行结果

程序解读

从main.py的if __name__ == '__main__':开始看

因为程序从这里开始执行

parser = argparse.ArgumentParser(description='Train Model')
# common args
parser.add_argument('--data_root', default='horse2zebra', type=str, help='Dataset root path')
# 文件放的位置
parser.add_argument('--batch_size', default=1, type=int, help='Number of images in each mini-batch')
#每个小批量中的图像数量
parser.add_argument('--epochs', default=2, type=int, help='Number of epochs over the data to train')
# 多少轮训练
parser.add_argument('--lr', default=0.0002, type=float, help='Initial learning rate')
# 开始时学习率
parser.add_argument('--decay', default=2, type=int, help='Epoch to start linearly decaying lr to 0')
# 从第几轮开始学习率逐渐减为0
parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')
# 训练出来的保存在哪里
# args parse
args = parser.parse_args()
data_root, batch_size, epochs, lr = args.data_root, args.batch_size, args.epochs, args.lr
decay, save_root = args.decay, args.save_root# data prepare
train_data = ImageDataset(data_root, 'train')
# 训练集
print("数据")
print(train_data.__len__())
# 打印出数据集的长度
test_data = ImageDataset(data_root, 'test')
# 验证集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=8)

使用通义灵码解释

# optimizer setup
optimizer_G = Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_DA = Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_DB = Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DA = LambdaLR(optimizer_DA, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DB = LambdaLR(optimizer_DB, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))

这段代码是用于设置优化器和学习率调度器的。

首先,使用Adam优化器来初始化G_A和G_B的参数以及D_A和D_B的参数。Adam优化器是一种基于梯度的优化算法,它利用了动量和自适应学习率的特性。itertools.chain函数用于将G_A和G_B的参数组合在一起。

然后,使用LambdaLR学习率调度器来设置学习率的衰减。LambdaLR调度器使用给定的函数来计算每个迭代步骤的学习率。这里使用了一个lambda函数,它在迭代次数eiter超过decay后开始衰减学习率,衰减的速度由decay参数控制。

这些优化器和学习率调度器将用于训练生成器和判别器模型。

# training loop
results = {'train_g_loss': [], 'train_da_loss': [], 'train_db_loss': []}
if not os.path.exists(save_root):os.makedirs(save_root)
for epoch in range(1, epochs + 1):g_loss, da_loss, db_loss = train(G_A, G_B, D_A, D_B, train_loader, optimizer_G, optimizer_DA, optimizer_DB)results['train_g_loss'].append(g_loss)results['train_da_loss'].append(da_loss)results['train_db_loss'].append(db_loss)val(G_A, G_B, test_loader)lr_scheduler_G.step()lr_scheduler_DA.step()lr_scheduler_DB.step()# save statisticsdata_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))data_frame.to_csv('{}/results.csv'.format(save_root), index_label='epoch')torch.save(G_A.state_dict(), '{}/GA.pth'.format(save_root))torch.save(G_B.state_dict(), '{}/GB.pth'.format(save_root))torch.save(D_A.state_dict(), '{}/DA.pth'.format(save_root))torch.save(D_B.state_dict(), '{}/DB.pth'.format(save_root))

这段代码是一个训练循环,用于训练深度学习模型。以下是代码的详细解释:

首先,定义一个字典results,用于存储训练过程中的损失值。

检查保存模型和结果的目录save_root是否存在,如果不存在则创建该目录。

使用for循环遍历epochs次,每次迭代都会进行一次训练和验证。

在每次迭代中,调用train函数训练生成器G_A、G_B和判别器D_A、D_B,并更新损失值。

将训练过程中的损失值分别添加到results字典中对应的列表中。

调用val函数对模型进行验证。

更新生成器和判别器的学习率。

将results字典转换为DataFrame,并将其保存为CSV文件。

保存生成器和判别器的模型参数。

这个训练循环的主要目的是在给定的训练数据集上训练生成对抗网络(GAN),并保存训练过程中的损失值和模型参数。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com