欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > 从零开始训练Codebook:基于ViT的图像重建实践

从零开始训练Codebook:基于ViT的图像重建实践

2025/4/4 16:56:42 来源:https://blog.csdn.net/muyouhang/article/details/146986420  浏览:    关键词:从零开始训练Codebook:基于ViT的图像重建实践

完整代码在文末,可以一键运行。

在这里插入图片描述

1. 核心原理

Codebook是一种离散表征学习方法,其核心思想是将连续特征空间映射到离散的码本空间。我们的实现方案包含三个关键组件:

1.1 ViT编码器

class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")self.proj = nn.Linear(768, codebook_dim)def forward(self, x):outputs = self.vit(x).last_hidden_statepatch_embeddings = outputs[:, 1:, :]  # 移除CLS tokenreturn self.proj(patch_embeddings)
  • 使用预训练的ViT-Base模型提取图像特征
  • 移除CLS token,保留196个图像块特征
  • 线性投影调整特征维度适配Codebook

1.2 Codebook量化层

class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)def quantize(self, z):# 计算L2距离distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 最近邻查找indices = torch.argmin(distances, dim=1)return indices, self.codebook(indices)
  • 使用可学习的Embedding层存储离散码本
  • 通过L2距离计算实现最近邻查找
  • 支持EMA更新(代码中已注释部分)

1.3 ViT解码器

class ViTDecoder(nn.Module):def __init__(self):self.head = nn.Sequential(nn.ConvTranspose2d(768, 384, 4, 2, 1),nn.ReLU(),... # 更多上采样层nn.Conv2d(48, 3, 1))
  • 使用转置卷积逐步上采样
  • 最终输出224x224分辨率图像
  • 与编码器形成对称结构

2. 训练策略

2.1 多目标损失函数

total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
  • MSE Loss: 像素级重建误差
  • Perceptual Loss: VGG16特征匹配
  • Codebook Loss: 码本向量优化
  • Commitment Loss: 编码器输出稳定性

2.2 优化技巧

opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
  • 分层学习率设置
  • EMA指数平滑更新
  • 混合精度训练支持
  • 动态学习率调整

3. 完整训练流程

3.1 数据准备

transform_train = transforms.Compose([transforms.Resize(224),transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(...)
])
  • CIFAR-10数据集
  • 随机裁剪+翻转增强
  • Batch Size=4适配显存

3.2 训练监控

# TensorBoard记录
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)# 控制台日志
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

完整代码

from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()# 加载预训练ViT-Base模型self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")# 调整输出维度匹配Codebookself.proj = nn.Linear(768, codebook_dim)  # 网页2/6中的线性嵌入策略def forward(self, x):outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]patch_embeddings = outputs[:, 1:, :]     # 移除CLS tokenreturn self.proj(patch_embeddings)       # [batch, 196, 512]class Codebook(nn.Module):def __init__(self, num_embeddings=16384, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)  # 网页1的EMA更新可在此扩展def quantize(self, z):"""量化输入特征向量参数:z: 输入特征 [batch, num_patches, embedding_dim]返回:indices: 最近邻码本索引 [batch, num_patches]quantized: 量化后的特征 [batch, num_patches, embedding_dim]"""# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]return indices, quantized
class ViTDecoder(nn.Module):def __init__(self, in_dim=512):super().__init__()# 反向映射ViT的patch嵌入self.proj = nn.Linear(in_dim, 768)config = ViTConfig()config.is_decoder = True  # 网页7中的解码器模式self.transformer = ViTModel(config).encoder  self.head = nn.Sequential(# 14x14 -> 28x28nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 28x28 -> 56x56nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 56x56 -> 112x112 nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 112x112 -> 224x224nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 最终调整到3通道nn.Conv2d(48, 3, kernel_size=1))def forward(self, x):x = self.proj(x)  # [batch, 196, 768]x = self.transformer(x).last_hidden_statex = x.permute(0, 2, 1).view(-1, 768, 14, 14)  # 恢复空间布局return self.head(x)  # 输出[1, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)from torchvision import transforms
import torchvision
import torch.nn.functional as F
# 数据增强和预处理
transform_train = transforms.Compose([transforms.Resize(224),  # 调整图像尺寸适配模型transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.DataLoader(trainset, batch_size=64, shuffle=True)
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)batch_size = 4  # 增大batch size加速训练
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16# 初始化TensorBoard
writer = SummaryWriter('runs/codebook_experiment')# 改进的Codebook类(增加EMA更新)
class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)self.commitment_cost = commitment_costself.decay = decayself.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))nn.init.normal_(self.ema_w)def quantize(self, z):# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]# 新增EMA更新# if self.training:#     with torch.no_grad():#         encodings = F.one_hot(indices, self.codebook.num_embeddings).float()#         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)#         n = torch.sum(self.ema_cluster_size)#         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)#         dw = torch.matmul(encodings.t(), z_flat)#         self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)#         self.codebook.weight.data = self.ema_w / self.ema_cluster_size.unsqueeze(1)return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化组件
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # 用于感知损失# 优化器分开设置
opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}  # 更小的学习率
], lr=3e-4)# 训练循环
for epoch in range(100):avg_loss = 0start_time = time.time()  # 记录epoch开始时间for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):images = images.to(device)# 前向传播z = encoder(images)indices, quantized = codebook.quantize(z)recon = decoder(quantized)# 多目标损失计算mse_loss = F.mse_loss(recon, images)# 感知损失(VGG特征匹配)with torch.no_grad():real_features = vgg(images)recon_features = vgg(recon)percep_loss = F.mse_loss(recon_features, real_features)# Codebook相关损失commitment_loss = codebook.commitment_cost * F.mse_loss(z.detach(), quantized)codebook_loss = F.mse_loss(z, quantized.detach())# 总损失total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss# 反向传播opt.zero_grad()total_loss.backward()opt.step()# 记录数据avg_loss += total_loss.item()if batch_idx % 50 == 0:# 记录TensorBoard数据writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)writer.add_scalars('Loss/components', {'mse': mse_loss.item(),'perceptual': percep_loss.item(),'codebook': codebook_loss.item(),'commitment': commitment_loss.item()}, epoch*len(trainloader)+batch_idx)# 保存重建样本comparison = torch.cat([images[:4], recon[:4]])grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)# 打印epoch统计信息avg_loss /= len(trainloader)print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f}")# 保存模型检查点if (epoch+1) % 10 == 0:torch.save({'encoder': encoder.state_dict(),'codebook': codebook.state_dict(),'decoder': decoder.state_dict(),'opt': opt.state_dict()}, f'checkpoint_epoch{epoch+1}.pth')writer.close()

通过本实践,我们实现了从特征提取到离散表征学习的完整流程。Codebook技术可广泛应用于图像压缩、生成模型等领域,期待读者在此基础上探索更多可能性。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词