WGAN-GP 原理及实现
- 一、WGAN-GP 原理
- 1.1 WGAN-GP 核心原理
- 1.2 WGAN-GP 实现步骤
- 1.3 总结
- 二、WGAN-GP 实现
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 图片转GIF
一、WGAN-GP 原理
Wasserstein GAN with Gradient Penalty (WGAN-GP) 是对原始 WGAN 的改进,通过梯度惩罚(Gradient Penalty)
替代权重裁剪(Weight Clipping),解决了 WGAN 训练不稳定、权重裁剪导致梯度消失或爆炸的问题。
1.1 WGAN-GP 核心原理
(1) Wasserstein 距离(Earth-Mover 距离)
- 原始 GAN 的 JS 散度在分布不重叠时梯度消失,而 WGAN 使用 Wasserstein 距离衡量生成分布 P g P_g Pg 和真实分布 P r P_r Pr 的距离:
W ( P r , P g ) = inf γ ∼ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \sim \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x-y\|] W(Pr,Pg)=infγ∼Π(Pr,Pg)E(x,y)∼γ[∥x−y∥] - 通过 Kantorovich-Rubinstein 对偶形式,转化为:
W ( P r , P g ) = sup ∥ D ∥ L ≤ 1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] W(P_r, P_g) = \sup_{\|D\|_L \leq 1} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] W(Pr,Pg)=sup∥D∥L≤1Ex∼Pr[D(x)]−Ez∼Pz[D(G(z))],其中 D D D 是 1-Lipschitz 函数(梯度范数不超过 1)
(2) 梯度惩罚(Gradient Penalty)
- 原始 WGAN 的问题:通过权重裁剪强制判别器(Critic)满足 Lipschitz 约束,但会导致梯度不稳定或容量下降
- WGAN-GP 的改进:直接对判别器的梯度施加惩罚项,强制其梯度范数接近 1: λ ⋅ E x ^ ∼ P x ^ \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} λ⋅Ex^∼Px^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \left [(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right] [(∥∇x^D(x^)∥2−1)2]
- x ^ \hat{x} x^ 是真实数据和生成数据的随机插值点: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1−ϵ)G(z), ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵ∼U[0,1]
- λ \lambda λ 是惩罚系数(通常设为 10)
1.2 WGAN-GP 实现步骤
(1) 判别器(Critic)的损失函数
判别器的目标是最大化 Wasserstein 距离,同时满足梯度约束:
L D = E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] ⏟ Wasserstein 距离 + λ ⋅ E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] ⏟ 梯度惩罚 L_D = \underbrace{\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))]}_{\text{Wasserstein 距离}} + \underbrace{\lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]}_{\text{梯度惩罚}} LD=Wasserstein 距离 Ex∼Pr[D(x)]−Ez∼Pz[D(G(z))]+梯度惩罚 λ⋅Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]
(2) 生成器(Generator)的损失函数
生成器的目标是最小化 Wasserstein 距离: L G = − E z ∼ P z [ D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim P_z}[D(G(z))] LG=−Ez∼Pz[D(G(z))]
(3) 训练流程
- 输入:真实数据 x x x,噪声 z ∼ N ( 0 , 1 ) z \sim \mathcal{N}(0,1) z∼N(0,1)
- 生成数据: G ( z ) G(z) G(z)
- 插值采样: x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1-\epsilon) G(z) x^=ϵx+(1−ϵ)G(z), ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0,1] ϵ∼U[0,1]
- 计算梯度惩罚:
- 对插值样本 x ^ \hat{x} x^ 计算判别器输出 D ( x ^ ) D(\hat{x}) D(x^)
- 求梯度 ∇ x ^ D ( x ^ ) \nabla_{\hat{x}} D(\hat{x}) ∇x^D(x^) 并计算惩罚项
- 更新判别器:最小化 L D L_D LD
- 更新生成器:最小化 L G L_G LG(每 n critic n_{\text{critic}} ncritic 次判别器更新后更新 1 次生成器)
1.3 总结
WGAN-GP 通过梯度惩罚替代权重裁剪,显著提升了 WGAN 的训练稳定性,是生成对抗网络的重要改进之一。实际应用中需注意:
- 判别器架构设计
- 梯度惩罚的正确实现
- 学习率和训练次数的调优
二、WGAN-GP 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as npimport os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchsummary import summary# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 指定存放日志路径
writer=SummaryWriter(log_dir="./runs/wgan_gp")os.makedirs("./img/wgan_gp_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,28,28)):transform = transforms.Compose([transforms.ToTensor(), # 将图像转换为张量transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]])# 下载训练集和测试集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建 DataLoadertrain_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):"""生成器"""def __init__(self, latent_dim=100,img_shape=(1,28,28)):super(Generator,self).__init__()# 网络块def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat))layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))return layersself.model = nn.Sequential(*block(latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh() # 输出归一化到[-1,1] )def forward(self,z): # 噪声z,2维[batch_size,latent_dim]gen_img=self.model(z) gen_img=gen_img.view(gen_img.shape[0],*img_shape)return gen_img # 4维[batch_size,1,H,W]
2.4 构建判别器
class Discriminator(nn.Module):"""判别器"""def __init__(self,img_shape=(1,28,28)):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(256, 1))def forward(self,img): # 输入图片,4维[batc_size,1,H,W]img=img.view(img.shape[0], -1) pred = self.model(img)return pred # 2维[batch_size,1]
2.5 训练和保存模型
-
WGAN-GP 算法流程
-
定义梯度惩罚函数
def compute_gradient_penalty(critic, real, fake, device):batch_size = real.shape[0]epsilon = torch.rand(batch_size, 1, 1, 1).to(device) # 随机插值系数interpolates = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)critic_interpolates = critic(interpolates)# 计算梯度gradients = torch.autograd.grad(outputs=critic_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(critic_interpolates),create_graph=True,retain_graph=True,)[0]gradients = gradients.view(gradients.shape[0], -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
- 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本# WGAN的特别设置
num_iter_critic = 5
lambda_gp = 10# 设置图片形状1*28*28
img_shape = (1,28,28)# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))# 开始训练
batches_done=0
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):# 进入训练模式G.train()D.train()loop = tqdm(train_loader, desc=f"第{epoch+1}轮")for i, (real_imgs, _) in enumerate(loop):real_imgs=real_imgs.to(device) # [B,C,H,W]# -----------------# 训练判别器# -----------------# 获取噪声样本[B,latent_dim)z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device) #从正态分布中抽样# Step-1 计算判断器损失=判断真实图片损失+判断生成图片损失+惩罚项fake_imgs=G(z).detach()gradient_penalty=compute_gradient_penalty(D, real_imgs, fake_imgs, device)dis_loss=-torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs))+lambda_gp*gradient_penalty# Step-2 更新判别器参数optimizer_D.zero_grad() # 梯度清零dis_loss.backward() #反向传播,计算梯度optimizer_D.step() #更新判别器 # -----------------# 训练生成器# -----------------# 判别器每迭代 num_iter_critic 次,生成器迭代一次if i % num_iter_critic ==0 :gen_imgs=G(z).detach()# 更新生成器参数optimizer_G.zero_grad() #梯度清零gen_loss=-torch.mean(D(gen_imgs))gen_loss.backward() #反向传播,计算梯度optimizer_G.step() #更新生成器 # 更新进度条loop.set_postfix(gen_loss=f"{gen_loss:.8f}",dis_loss=f"{dis_loss:.8f}")# 每 sample_interval 次迭代保存生成样本if batches_done % sample_interval == 0:save_image(gen_imgs.data[:25], f"./img/wgan_gp_mnist/{epoch}_{i}.png", nrow=5, normalize=True)batches_done += 1print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/WGAN-GP_G.pth")
torch.save(D.state_dict(), "./model/WGAN-GP_D.pth")
2.6 图片转GIF
from PIL import Imagedef create_gif(img_dir="./img/wgan_gp_mnist", output_file="./img/wgan_gp_mnist/wgan_gp_figure.gif", duration=100):images = []img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]# 自定义排序:按 "x_y.png" 的 x 和 y 排序img_paths_sorted = sorted(img_paths,key=lambda x: (int(x.split('_')[0]), # 第一个数字(如 0_400.png 的 0)int(x.split('_')[1].split('.')[0]) # 第二个数字(如 0_400.png 的 400)))for img_file in img_paths_sorted:img = Image.open(os.path.join(img_dir, img_file))images.append(img)images[0].save(output_file, save_all=True, append_images=images[1:], duration=duration, loop=0)print(f"GIF已保存至 {output_file}")
create_gif()
