欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 新车 > 从零开始搭建CLIP模型实现基于文本的图像检索

从零开始搭建CLIP模型实现基于文本的图像检索

2025/4/20 21:21:35 来源:https://blog.csdn.net/weixin_49295405/article/details/147296916  浏览:    关键词:从零开始搭建CLIP模型实现基于文本的图像检索

目录

  • CLIP原理简介
  • 代码实现
  • 参考链接

CLIP原理简介

论文链接,源码链接

CLIP模型由OpenAI在2021年提出,利用双Decoder(Dual Encoder)的架构来学习图像和文本之间的对应关系,是多模态大模型的开创之作,为后续许多高效的多模态模型的提出打下基础。CLIP是一个预训练模型(Pre-trained Model),在学习到图像–文本特征之间的关联后可以迁移到各种下游任务中,如图像分类,文本引导图像分割和目标检测,图像文本检索等。由于模型学习到的是文本语义和图像语义之间的关联,使得其zero-shot能力非常强大,根据论文中的描述,CLIP在很多数据集上zero-shot的结果甚至超越了许多训练好的模型的效果。CLIP的训练范式如下:

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/1d112d364a60434bba8dd07d42d2a1c6.png

CLIP的结构非常简单,数据集包含大量的图像文本对,图像经过图像编码器得到图像特征,文本经过文本编码器得到文本特征,将图像特征和文本特征按照数据集中的对应关系进行配对,不配对的特征给予惩罚,从上图中可以看出,我们希望矩阵中蓝色的值趋近于1,其余值趋近于0,采用对比学习的方式对模型进行训练,算法的伪代码如下:

在这里插入图片描述
从损失函数中可以看出,分别对特征对比矩阵的行和列进行交叉熵损失函数计算,并取平均得到最终的loss。图像编码器一般有两种选择:ResNet和ViT;文本编码器采用Transformer Encoder,均是各自领域中优秀的特征提取网络。
CLIP的推理范式如下:

在这里插入图片描述
在推理阶段,图像编码器中输入图像获取图像特征,文本编码器中输入文本获取文本特征,将图像特征向量和文本特征向量的转置相乘得到每张图像对每个文本的特征相似度,相似度最高的文本即描述了该图像中物体所属的类别。

代码实现

Flickr8k数据集下载,提取码:fbfz
DistilBert模型文件下载

我的运行环境:
CUDA 11.8
pytorch 2.2.2
transformers 4.44.0 # 用于从HuggingFace上加载预训练模型


数据集预览:
图片示例

图片示例

在这里插入图片描述

文本示例

由于作者的显卡算力有限,选取Flickr8k数据集进行模型训练,其中包含8k个图像文本对,其中一张图像对应5条文本。图像编码器采用ResNet50,直接从timm库中导入;文本编码器采用DistilBert,即轻量化的Bert模型,从HuggingFace上下载。闲话少说,小二,上菜!

### 模型参数配置 ###
import argparse
from dataclasses import dataclassparser = argparse.ArgumentParser(description="CLIP from zero")
parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder')  # 存放图像的文件路径
parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder')  # 存放文本的文件路径
parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight')  # 存放训练权重的文件路径
args = parser.parse_args()@dataclass
class CLIPConfig:image_path: str = args.image_dir  # 图像存放路径image_size: int = 224  # resize后的图像尺寸,便于构建Dataloadercaption_path: str = args.caption_dir  # 文本存放路径batch_size: int = 8  # 一个批次中的数据数量epochs: int = 3  # 训练世代image_encoder_model: str = "resnet50"  # 图像编码器的名称image_embedding_dim: int = 2048  # 图像特征的维度text_encoder_model: str = "distilbert-base-uncased"  # 文本编码器的名称text_embedding_dim: int = 768  # 文本特征的维度text_tokenizer: str = text_encoder_model  # 文本分词器模型的名称max_length: int = 200  # 文本编码器可输入的最长文本长度pretrained: bool = False  # 是否加载预训练好的编码器trainable: bool = True  # 在训练过程中是否更新编码器的参数temperature: float = 1.0  # 计算loss时的正则化系数proj_dim: int = 256  # 图像特征和文本特征统一后的维度dropout_rate: float = 0.1  # dropout系数,避免过拟合### 载入数据集并初始化 ###
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import pandas as pd
import cv2class CLIPDataset(Dataset):def __init__(self, config, image_path, caption_path, transforms=True):"""图片文件名和标题的长度必须相同如果一个图片对应多个标题,该图片文件名需要重复多次"""self.image_path = image_path  # 图像路径self.caption_path = caption_path  # 文本路径self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv")  # 读取文本self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)  # 载入分词器self.image_filenames = self.dataframe["image"].values  # 获取图像文件名self.captions = list(self.dataframe["caption"].values)   # 获取图像对应的描述文本self.encoded_captions = self.tokenizer(self.captions, padding=True, truncation=True, max_length=config.max_length)  # 文本分词self.transforms = transforms  # 对输入图像进行预处理def __getitem__(self, idx):  # 获取数据集中第idx个数据,其中包含图片名称和对应的标题(可能不止一个)item = {key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()}image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}")  # 获取原始图像image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transforms:image = self.get_transforms(mode="train")(image=image)["image"]  # 对图像进行预处理item["image"] = torch.tensor(image).permute(2, 0, 1).float()  # 将图片转换为tensor格式,并调整为RGB顺序item["caption"] = self.captions[idx]  # 获取标题return itemdef __len__(self):return len(self.captions)  # 获取文本长度def get_transforms(self, mode="train"):if mode == "train":return A.Compose([A.Resize(config.image_size, config.image_size, always_apply=True),  # 对图像进行resizeA.Normalize(max_pixel_value=255.0, always_apply=True)  # 对像素值进行归一化])### 图像编码器 ###
import torch.nn as nn
import timmclass ImageEncoder(nn.Module):"""图像编码器,采用ResNet50"""def __init__(self, config):super().__init__()self.model = timm.create_model(config.image_encoder_model, pretrained=config.pretrained, num_classes=0, global_pool="avg")  # 创建ResNet50for p in self.model.parameters():p.requires_grad = config.trainable  # 设置参数可训练def forward(self, x):image_encoded = self.model(x)  # 获得图像特征编码,形状为[batch_size, image_embedding_dim]return image_encoded### 文本编码器 ###
class TextEncoder(nn.Module):"""文本编码器,采用DistilBERT"""def __init__(self, config):super().__init__()if config.pretrained:self.model = DistilBertModel.from_pretrained(config.text_encoder_model)  # 导入下载好的模型文件else:self.model = DistilBertModel(DistilBertConfig())for p in self.model.parameters():p.requires_grad = config.trainable  # 设置参数可训练self.target_token_idx = 0# 提取出和图像对应的文本特征向量def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)text_encoded = output.last_hidden_state[:, self.target_token_idx, :]  # [batch_size, text_embedding_dim]return text_encoded### 投影层 (MLP) ###
class ProjectionHead(nn.Module):"""将图像编码和文本编码映射到相同维度"""def __init__(self, config, input_embedding_dim):super().__init__()self.proj = nn.Linear(input_embedding_dim, config.proj_dim)self.act_fn = nn.GELU()self.fc = nn.Linear(config.proj_dim, config.proj_dim)self.dropout = nn.Dropout(config.dropout_rate)self.layer_norm = nn.LayerNorm(config.proj_dim)def forward(self, x):x_proj = self.proj(x)x = self.act_fn(x_proj)x = self.fc(x)x = self.dropout(x)x = x + x_projx = self.layer_norm(x)return x### 定义损失函数 ###
def cross_entropy(logits, labels, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-labels * log_softmax(logits)).sum(dim=1)if reduction == 'mean':return loss.mean()else:return loss.sum()### 模型主体 ###
import torch.nn.functional as Fclass CLIP(nn.Module):def __init__(self, config):super().__init__()self.image_encoder = ImageEncoder(config)  # 实例化图像编码器self.text_encoder = TextEncoder(config)  # 实例化文本编码器self.image_proj = ProjectionHead(config, config.image_embedding_dim)  # 图像特征投影self.text_proj = ProjectionHead(config, config.text_embedding_dim)  # 文本特征投影self.temperature = config.temperaturedef forward(self, batch):image_features = self.image_encoder(batch["image"])  # 图像编码# 文本编码,tokenizer处理后的文本序列自带input_ids和attention_masktext_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])image_embeddings = self.image_proj(image_features)  # 图像特征投影text_embeddings = self.text_proj(text_features)  # 文本特征投影logits = (text_embeddings @ image_embeddings.T) / self.temperature  # tensor形状为[batch_size, batch_size]images_similarity = image_embeddings @ image_embeddings.T  # tensor形状为[batch_size, batch_size]text_similarity = text_embeddings @ text_embeddings.T  # tensor形状为[batch_size, batch_size]# 软标签,不配对的位置设置为较小的数,而非0labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1)  loss_T = cross_entropy(logits, labels)  # 计算文本损失loss_I = cross_entropy(logits.T, labels.T)  # 计算图像损失total_loss = (loss_T + loss_I) / 2  # 对比学习平均损失return total_loss, logits### 训练函数 ###
def train(model, optimizer, scheduler, train_loader, device):model.train()  # 模型设置为训练模式train_loss = 0train_loader = tqdm(train_loader, total=len(train_loader))  # 显示训练进度条cnt = 0for batch in train_loader:# print(batch.keys())cnt += 1batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}  # 将dataloader中一个batch的数据转换为字典形式loss, _ = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step(metrics=loss.item())  # 根据上次训练的损失更新学习率train_loss += loss.item()# 训练100个batch显示一次lossif cnt % 100 == 0:print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')return train_loss / len(train_loader)  # 平均训练损失### 测试函数 ###
def eval(model, val_loader, device):model.eval()  # 模型设置为测试模式val_loss = 0val_loader = tqdm(val_loader, total=len(val_loader))with torch.no_grad():for batch in val_loader:batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}loss, _ = model(batch)val_loss += loss.item()return val_loss / len(val_loader)  # 平均测试损失if __name__ == '__main__':config = CLIPConfig()  # 实例化配置信息model = CLIP(config)  # 实例化CLIP模型device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 查看模型的总参数量total_params = sum(p.numel() for p in model.parameters())print(f"Total parameters: {total_params / 1e6} M")optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)dataset = CLIPDataset(config, args.image_dir, args.caption_dir)  # 读取并预处理数据train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])  # 80%为训练数据,20%为测试数据dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)# 开始训练best_loss = float("inf")for epoch in range(config.epochs):print(f"Epoch: {epoch + 1}")train_loss_avg = train(model, optimizer, scheduler, train_loader, device)val_loss_avg = eval(model, val_loader, device)if val_loss_avg < best_loss:best_loss = val_loss_avgtorch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')print("Best model saved!")# 图像文本检索推理并可视化# dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")# tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)# model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))# model.eval()# # image_embeddings = []# with torch.no_grad():#     for batch in tqdm(dataloader):#         image_features = model.image_encoder(batch["image"].to(device))  # 获取图像特征#         cur_image_embeddings = model.image_proj(image_features)  # [batch_size, proj_dim]  # 图像特征投影#         image_embeddings.append(cur_image_embeddings)  # 将一个batch的图像特征保存# # image_embeddings = torch.cat(image_embeddings, dim=0)  # [image_number, proj_dim]# input_query = "two dogs sitting on the grass"  # 输入文本# image_filenames = dataframe["image"].values  # 待检索的图片# # encoded_query = tokenizer([input_query])  # 对输入文本进行分词# batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}# # with torch.no_grad():#     text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"])  # 获取文本特征#     text_embeddings = model.text_proj(text_features)  # 文本特征投影,与图像特征维度一致# # image_embeddings_n = F.normalize(image_embeddings, dim=-1)  # [image_number, proj_dim]# text_embeddings_n = F.normalize(text_embeddings, dim=-1)  # [1, proj_dim]# dot_similarity = text_embeddings_n @ image_embeddings_n.T  # 输入文本的特征和数据集中每张图像特征之间的相似度# # values, indices = torch.topk(dot_similarity.squeeze(0), k=45)  # 获取前45个相似度最高的图像# matches = [image_filenames[idx] for idx in indices[::5]]  # 获取对应的图像文件名(9张图像)# # f, axes = plt.subplots(3, 3, figsize=(10, 10))# f.suptitle(f"Retrieving text: {input_query}")  # 设置主标题# for match, ax in zip(matches, axes.flatten()):  # 显示检索出的图像#     image = cv2.imread(f"{args.image_dir}/{match}")#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)#     ax.imshow(image)#     ax.axis("off")# # plt.show()

理想结果:

在这里插入图片描述

参考链接

https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词