欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 美食 > 【深度学习】计算机视觉(CV)-图像生成-生成对抗网络(GANs, Generative Adversarial Networks)

【深度学习】计算机视觉(CV)-图像生成-生成对抗网络(GANs, Generative Adversarial Networks)

2025/2/23 6:36:58 来源:https://blog.csdn.net/IT_ORACLE/article/details/145719949  浏览:    关键词:【深度学习】计算机视觉(CV)-图像生成-生成对抗网络(GANs, Generative Adversarial Networks)

生成对抗网络(GANs)Ian Goodfellow 在 2014 年提出的一种深度生成模型,主要用于生成逼真的数据,如图像、音乐、文本等。GANs 采用博弈论思想,让两个神经网络(生成器 G判别器 D)相互对抗,在不断竞争中提高数据的生成质量。


1. GANs 的核心思想

GANs 由 两个神经网络 组成:

  • 生成器(Generator, G)

    • 输入随机噪声 z,生成与真实数据类似的样本 G(z)
    • 目标:欺骗判别器,让它认为生成的样本是真实数据
  • 判别器(Discriminator, D)

    • 输入数据(真实数据 x 或生成数据 G(z))
    • 目标:判断输入数据是真实的(1)还是生成的(0)

两者进行博弈(Adversarial Training)

  • G 尽力欺骗 D
  • D 试图正确分类

最终,G 生成的数据会越来越逼真,D 也会变得更强。


2. GANs 的训练过程

Step 1: 随机生成噪声

  • 生成器 G 以 随机噪声 z(通常是正态分布) 作为输入,生成假数据:

                                                    G(z)→生成假数据G(z)

Step 2: 判别器判定真假

  • 判别器 D 接收 真实数据 x 和假数据 G(z),并计算它们的真假概率:
                                                    D(x)→1(真实数据)
                                                    D(G(z))→0(假数据)

Step 3: 计算损失

  • 判别器损失(Binary Cross Entropy Loss)

    L_D = - \mathbb{E}_{x \sim P_{data}} [\log D(x)] - \mathbb{E}_{z \sim P_z} [\log (1 - D(G(z)))]

    目的是让 D(x) 预测 1,D(G(z)) 预测 0。

  • 生成器损失(让 G 欺骗 D)

    L_G = - \mathbb{E}_{z \sim P_z} [\log D(G(z))]

    目的是让 D(G(z)) 预测 1(以为是假数据)。

Step 4: 交替优化

  1. 更新 D:固定 G,训练 D,使其能正确分类真假数据。
  2. 更新 G:固定 D,训练 G,使 D 无法区分真假数据。

最终,G 生成的数据将会越来越接近真实数据。


3. GANs 的数学原理

GANs 本质上是在优化一个极小极大问题(Minimax Game)

\min_G \max_D V(D, G)

其中目标函数为:

V(D, G) = \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{z \sim P_z}[\log (1 - D(G(z)))]

  • 判别器 D:最大化 V(D,G) 以最准确区分真假数据。
  • 生成器 G:最小化 V(D,G) 以欺骗 D,使得假数据尽可能像真实数据。

4. 经典 GANs 变体

DCGAN(深度卷积 GAN)

  • 替换全连接层使用 CNN 提高图像生成质量。
  • 采用 Leaky ReLU 代替 ReLU,提高梯度流动。

WGAN(Wasserstein GAN)

  • 解决**模式崩溃(Mode Collapse)**问题。
  • 采用 Wasserstein 距离 代替 JS 散度,提高训练稳定性。

WGAN-GP(带梯度惩罚的 WGAN)

  • 解决 WGAN 训练中的梯度消失问题,提高稳定性。

Conditional GAN(cGAN)

  • 让 GAN 按类别生成特定类型的图片(例如手写数字、动漫头像)。
  • 在输入中添加 类别标签,使 G 生成特定类别数据。

CycleGAN(循环一致性 GAN)

  • 无需成对数据,即可进行风格转换(如 照片风格转换黑白图像上色)。

5. 代码实现(PyTorch)

生成手写数字(DCGAN)

import torch.nn as nn
import torch.optim as optim# 生成器类定义
# 该类用于生成图像,继承自nn.Module
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 生成器模型定义,使用反卷积(转置卷积)层逐步上采样,最终生成与真实图像大小相同的输出self.model = nn.Sequential(nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # 输入维度为100,输出维度为512,卷积核大小为4,步长为1,padding为0nn.BatchNorm2d(512),  # 应用Batch Normalizationnn.ReLU(True),  # 应用ReLU激活函数nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),nn.Tanh()  # 应用Tanh激活函数# 输出维度为1,表示生成的图像像素值范围在-1到1之间)# 前向传播函数def forward(self, x):return self.model(x)# 判别器类定义
# 该类用于判别图像真假,继承自nn.Module
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 判别器模型定义,使用卷积层逐步下采样,最终输出一个标量概率值self.model = nn.Sequential(nn.Conv2d(1, 128, 4, 2, 1, bias=False),  # 输入维度为1,输出维度为128,卷积核大小为4,步长为2,padding为1nn.LeakyReLU(0.2, inplace=True),  # 应用Leaky ReLU激活函数nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # 输入维度为128,输出维度为256,卷积核大小为4,步长为2,padding为1nn.BatchNorm2d(256),  # 应用Batch Normalizationnn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, 4, 2, 1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid()  # 应用Sigmoid激活函数)# 前向传播函数def forward(self, x):return self.model(x)# 初始化生成器和判别器模型
G = Generator()
D = Discriminator()# 初始化生成器和判别器的优化器,使用Adam优化算法
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 生成器优化器
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))  # 判别器优化器# 打印生成器模型结构
print(G)

 运行结果

Generator((model): Sequential((0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU(inplace=True)(9): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(10): Tanh())
)

使用 DCGAN 训练后,可以生成逼真的手写数字!


6. GANs 的应用

  • 图像生成:动漫头像、3D 人脸建模
  • 风格转换:黑白照片上色、照片 → 画风转换(如 CycleGAN)
  • 医学影像:生成 MRI、CT 图像,提高医疗影像质量
  • 文本生成:ChatGPT、文本补全
  • 数据增强:生成样本数据,提高模型鲁棒性

7. 总结

GANs 通过 G 和 D 互相对抗,提高生成数据的质量
训练 GANs 可能存在模式崩溃、梯度消失等问题
多个 GAN 变体(DCGAN、WGAN、cGAN)解决不同任务需求
广泛应用于图像生成、风格转换、数据增强等领域

GANs 仍然是 AI 生成模型的重要技术之一,未来可能结合 Transformer 进行更多创新!

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词