欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > 昇思MindSpore学习笔记4-04生成式--GAN图像生成

昇思MindSpore学习笔记4-04生成式--GAN图像生成

2025/2/23 7:18:46 来源:https://blog.csdn.net/muren/article/details/140219595  浏览:    关键词:昇思MindSpore学习笔记4-04生成式--GAN图像生成

摘要:

        记录昇思MindSpore AI框架使用GAN生成式对抗网络模型生成图像的原理和实际使用方法、步骤。包括环境准备、数据集下载、数据加载、隐码、构建生成器与判别器、模型训练、模型推理等。

一、概念

GAN生成式对抗网络模型

(Generative Adversarial Networks)

生成式机器学习模型

复杂分布上无监督学习

        两个模型共同组成

                生成器Generative Model

                        输出模拟训练图像

                判别器Discriminative Model

                        判断生成器输出图像的真实性

        生成器和判别器互相博弈学习

        核心

                通过对抗同时训练生成模型和判别模型

        博弈平衡点

                生成的模拟图像和训练数据图像分布完全一致时

                判别器拥有50%的真假判断置信度

x                  1×28×28图像数据

 D(x)            判别器判定的真实概率

                   二分类

                   x ϵ 训练数据 D(x)-->1 

                   x ϵ 生成器 D(x)-->0 

z                 提取标准正态分布隐码(隐向量)

G(z)           生成器函数

                        隐码(隐向量)z映射到数据空间。

                 目标

                        高斯分布的随机噪声z --> 真实数据分布P_{data}(x)

网络参数θ        训练网络找寻θPG(x;\theta )--> P_{data}(x)

D(G(z))           判定生成器G 生成模拟图像的真实概率

logD(x)           生成器参数

log(1−D(G(z)))         判别器参数

GAN损失函数:

\underset{G}{min}\underset{D}{max}V(D,G)=E_{x\sim P_{data}}(x)[logD(x)]+E_{z\sim p_z(z)}[log(1-D(G(z)))]\underset{.}{}

博弈平衡点PG(x;\theta )=P_{data}(x),随机猜测

生成器和判别器博弈过程:

1.训练开始生成器随机生成数据分布

2.判别器通过求取梯度和损失函数优化网络

判定靠近真实数据分布的为1

判定靠近生成器生成数据分布的为0

3.生成器通过优化,生成更贴近真实数据分布的数据

4.生成器生成数据和真实数据达到相同分布

判别器输出为1/2

上图中

        蓝色虚线表示判别器

        黑色虚线表示真实数据分布

        绿色实线表示生成器生成的模拟数据分布

        z表示隐码

        G(z)表示生成的模拟图像

二、环境准备

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore

输出:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

三、数据集

1.数据集简介

NIST数据集

        MNIST手写数字数据集

                70000张手写数字图片

                        60000张训练样本

                        10000张测试样本

                尺寸归一化: 图片大小28*28

                单通道

                中心化处理

2.数据集下载

安装download包

        pip install download

download接口下载

自动解压到当前目录

数据集目录结构:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test├─ t10k-images-idx3-ubyte└─ t10k-labels-idx1-ubyte

数据下载代码:

# 数据下载
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)

输出:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 116MB/s]
Extracting zip file...
Successfully downloaded / unzipped to .

3.数据加载

MindSpore.dataset.MnistDatase接口

读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理。

import numpy as np
import mindspore.dataset as ds
​
batch_size = 64
latent_size = 100  # 隐码的长度
​
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')
​
def data_load(dataset):
dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, 
python_multiprocessing=False,num_samples=10000)# 数据增强mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"), 
np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])
​# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)
​return mnist_ds
​
mnist_ds = data_load(train_dataset)
​
iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)

输出:

Iter size: 156

4.数据集可视化

create_dict_iterator函数

        数据转换成字典迭代器

matplotlib模块

        可视化

import matplotlib.pyplot as plt
​
data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()

输出:

5.隐码构造

每轮训练迭代

输入固定高斯分布隐码test_noise到生成器

评估生成器生成图像的效果

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
​
# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)

四、模型构建

判别器和生成器中采用

全连接网络架构

 ReLU 激活函数

1.生成器

生成器 

Generator

        功能

                将隐码映射到数据空间

                图像数据映射到与真实图像大小相同的灰度图像(或 RGB 彩色图像)

五层 Dense 全连接层

        BatchNorm1d 批归一化层

         ReLU 激活层配对

        Tanh 函数

        输出 [-1,1]范围内的数据。

注意静态图模式下实例化生成器之后修改参数名称

from mindspore import nn
import mindspore.ops as ops
​
img_size = 28  # 训练图像长(宽)
​
class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())
​def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))
​
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

2.判别器

判别器

Discriminator 

        二分类网络模型

        输出判定图像的真实概

系列 Dense 层

LeakyReLU 层

Sigmoid 激活函数

输出 [0, 1]范围内的数据,

        真实概率

注意静态图模式下实例化判别器之后修改参数名称

 # 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]
​def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)
​
net_d = Discriminator()
net_d.update_parameters_name('discriminator')

3.损失函数和优化器

损失函数

        MindSpore.nn.BCELoss二进制交叉熵损失函数

生成器优化器

        Adam优化器

判别器优化器

        Adam优化器

lr = 0.0002  # 学习率
​
# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')
​
# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

五、模型训练

训练判别器

        提高判别图像真伪的概率

        提高随机梯度来更新判别器

        最大化logD(x)+log(1-D(G(z)))的值

训练生成器

        产生更好的模拟图像

        最小化 log(1-D(G(z))) 

分别获取训练损失

每轮迭代测试

        批量推送隐码到生成器

        跟踪生成器 Generator 的训练效果

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint
​
total_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小
​
# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'
​
checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
[10]:# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g
​
# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d
​
# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())
​
def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)
​return loss_d, loss_g
​
# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))
​
# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)
​
net_g.set_train()
net_d.set_train()
​
# 储存生成器和判别器loss
losses_g, losses_d = [], []
​
for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")
​end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))
​losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())
​# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)
​# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

输出:

time of epoch 1 is 86.37s
time of epoch 2 is 7.11s
time of epoch 3 is 6.94s
time of epoch 4 is 7.22s
time of epoch 5 is 7.28s
time of epoch 6 is 6.98s
time of epoch 7 is 7.33s
time of epoch 8 is 7.06s
time of epoch 9 is 7.06s
time of epoch 10 is 7.25s
time of epoch 11 is 7.03s
time of epoch 12 is 7.29s

六、效果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出:

显示训练过程中通过隐向量生成的图像。

import cv2
import matplotlib.animation as animation
​
# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):plt.axis("off")show_list.append([plt.imshow(image_list[epoch], cmap='gray')])
​
ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)

输出:

训练次数增多,图像质量越好

 Epoch>100,生成的手写数字图片数据集

七、模型推理

加载生成器网络模型参数文件来生成图像:

import mindspore as ms
​
# test_ckpt = './result/checkpoints/Generator199.ckpt'
​
# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):fig.add_subplot(5, 5, i + 1)plt.axis("off")plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

输出:

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词