Preface: in my personal view, when the source code of a neural network lands in front of you, first look at the dataset's inputs and outputs — the dataset code necessarily covers data augmentation and dataset construction. Second, look at the model's inputs and outputs (for the model internals you can read the paper; it usually just adds a few components). Then look at the loss function chosen to match those outputs. As for the learning rate schedule and optimizer, they are almost always cosine annealing and AdamW.
1. Dataset
Straight into practice: first read the project's README. It usually describes the annotation file format (typically a file path plus the corresponding label index), which you are expected to build yourself.
The dataset's input is usually this annotation file, and its output is usually a tuple or a dict.
Data augmentation is usually folded into the dataset construction.
ActionCLIP
Data augmentation (spatial cropping)
Data augmentation source code:
```python
import torchvision  # needed for torchvision.transforms.Compose below

from datasets.transforms_ss import *
from RandAugment import RandAugment


class GroupTransform(object):
    def __init__(self, transform):
        self.worker = transform

    def __call__(self, img_group):
        # Apply the same transform to every frame in the group.
        return [self.worker(img) for img in img_group]


def get_augmentation(training, config):
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]
    scale_size = config.data.input_size * 256 // 224
    if training:
        unique = torchvision.transforms.Compose([
            GroupMultiScaleCrop(config.data.input_size, [1, .875, .75, .66]),
            GroupRandomHorizontalFlip(is_sth='some' in config.data.dataset),
            GroupRandomColorJitter(p=0.8, brightness=0.4, contrast=0.4,
                                   saturation=0.2, hue=0.1),
            GroupRandomGrayscale(p=0.2),
            GroupGaussianBlur(p=0.0),
            GroupSolarization(p=0.0)
        ])
    else:
        unique = torchvision.transforms.Compose([
            GroupScale(scale_size),
            GroupCenterCrop(config.data.input_size)
        ])
    common = torchvision.transforms.Compose([
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(input_mean, input_std)
    ])
    return torchvision.transforms.Compose([unique, common])


def randAugment(transform_train, config):
    print('Using RandAugment!')
    transform_train.transforms.insert(
        0, GroupTransform(RandAugment(config.data.randaug.N, config.data.randaug.M)))
    return transform_train
```
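Roughly how these two helpers get wired together (my own sketch; gating RandAugment on `config.data.randaug` is an assumption, the repo may switch it differently):

```python
# Sketch: build train/val transforms from the helpers above (a config object is assumed to exist).
transform_train = get_augmentation(training=True, config=config)
if getattr(config.data, 'randaug', None):  # assumed switch for enabling RandAugment
    transform_train = randAugment(transform_train, config)

transform_val = get_augmentation(training=False, config=config)
```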
You can refer directly to () for this data augmentation.
It is usually embedded directly in the dataset class:
```python
def __init__(self, list_file, labels_file,
             num_segments=1, new_length=1,
             image_tmpl='img_{:05d}.jpg', transform=None,
             random_shift=True, test_mode=False, index_bias=1):
    ...  # (body omitted)

def get(self, record, indices):
    images = list()
    for i, seg_ind in enumerate(indices):
        p = int(seg_ind)
        try:
            seg_imgs = self._load_image(record.path, p)
        except OSError:
            print('ERROR: Could not read image "{}"'.format(record.path))
            print('invalid indices: {}'.format(indices))
            raise
        images.extend(seg_imgs)
    # Apply the (group) transform to the sampled frames.
    process_data = self.transform(images)
    return process_data, record.label
```
- Spatial cropping simply comes down to how many crops are performed; it is worth getting familiar with the RandAugment function.
Dataset construction (temporal cropping and frame count)
- Input
ActionCLIP's annotation file looks like this:
```
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/HfI4vN2vbHU_000000_000010 289 31
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/B8FXlmO5zk4_000079_000089 240 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/XsEw1vd32l8_000052_000062 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/r61D2lDCHsM_000268_000278 240 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/4sCQ-EX6cIg_000021_000031 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/N9mQC7MeZCk_000008_000018 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/fzVhIrMnY-E_000322_000332 250 1
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/6dLNI2BPTY0_000057_000067 250 23
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/othYtMhFdOU_000020_000030 250 29
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/JVSxlojnBYk_000047_000057 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/8jO9DeYLruU_000003_000013 300 1
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/pU12_c-XvU_000045_000055 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/x6rP9b1V7sQ_000060_000070 250 18
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/jqC2SnFAvoM_000092_000102 300 23
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri6AwOp59yA_000009_000019 250 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/wRaacvxMoc8_000014_000024 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/7kbO0v4hag_000107_000117 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/GjtR9KZbV3Y_000494_000504 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/hwUQqFadvE_000048_000058 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/vXmgE41UnBk_000844_000854 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/dglCzcubsw_000246_000256 159 1
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri1H0ygN3Us_000768_000778 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/n24zV9OtorU_000257_000267 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/nKoqxSJcZn8_000071_000081 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/pT2byS0qiZM_000001_000011 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/CMo6AJhtZo_000075_000085 250 29
```
That is: directory of extracted video frames, total number of frames in the video, and the corresponding label index.
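Parsing this format is trivial; a minimal sketch (this `FrameRecord` class is my own illustration, not the repository's record class):

```python
class FrameRecord:
    """One annotation line: <frame directory> <total frame count> <label index>."""
    def __init__(self, row):
        self.path = row[0]
        self.num_frames = int(row[1])
        self.label = int(row[2])


def load_annotations(list_file):
    # Each non-empty line becomes one record.
    with open(list_file) as f:
        return [FrameRecord(line.strip().split(' ')) for line in f if line.strip()]
```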
- Output
Usually look at `__getitem__`:
```python
def __getitem__(self, index):
    record = self.video_list[index]
    segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
    return self.get(record, segment_indices)

def get(self, record, indices):
    images = list()
    for i, seg_ind in enumerate(indices):
        p = int(seg_ind)
        try:
            seg_imgs = self._load_image(record.path, p)
        except OSError:
            print('ERROR: Could not read image "{}"'.format(record.path))
            print('invalid indices: {}'.format(indices))
            raise
        images.extend(seg_imgs)
    process_data = self.transform(images)
    return process_data, record.label
```
It returns a tuple `(images, label)`.
- Frame count: this is generally determined by `num_segments`. Why? Because in the top-venue papers I have seen, basically one frame is sampled per segment, so the total frame count follows from the number of segments.
- Temporal cropping
Temporal cropping means selecting specific frames along the video's time dimension (the code below is for the validation set):
```python
def _get_val_indices(self, record):
    if self.num_segments == 1:
        return np.array([record.num_frames // 2], dtype=np.int) + self.index_bias
    if record.num_frames <= self.total_length:
        if self.loop:
            return np.mod(np.arange(self.total_length), record.num_frames) + self.index_bias
        return np.array([i * record.num_frames // self.total_length
                         for i in range(self.total_length)], dtype=np.int) + self.index_bias
    offset = (record.num_frames / self.num_segments - self.seg_length) / 2.0
    return np.array([i * record.num_frames / self.num_segments + offset + j
                     for i in range(self.num_segments)
                     for j in range(self.seg_length)], dtype=np.int) + self.index_bias
```
When there are not enough frames:
When `self.loop` is True, `np.mod(np.arange(self.total_length), record.num_frames)` selects frames cyclically so that the number of selected frames reaches `self.total_length`; this is a form of temporal cropping that reuses existing frames to satisfy the required count.
When `self.loop` is False, `i * record.num_frames // self.total_length` selects `self.total_length` frames uniformly across the video, which likewise crops along the time dimension.
When there are enough frames, the video is first split into `self.num_segments` segments and `self.seg_length` consecutive frames are selected within each segment; the `offset` keeps the selected frames roughly centered inside each segment, so the temporal cropping happens per segment. A small worked example of this arithmetic follows.
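Here is that per-segment index arithmetic evaluated on made-up numbers (my own check, not repository code):

```python
import numpy as np

# Hypothetical values, only to illustrate the arithmetic in _get_val_indices.
num_frames, num_segments, seg_length, index_bias = 40, 8, 1, 1

offset = (num_frames / num_segments - seg_length) / 2.0  # (40/8 - 1) / 2 = 2.0
indices = np.array([i * num_frames / num_segments + offset + j
                    for i in range(num_segments)
                    for j in range(seg_length)], dtype=int) + index_bias
print(indices)  # [ 3  8 13 18 23 28 33 38] -> one roughly centered frame per segment
```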
X-CLIP
Dataset
1. For the dataset's inputs and outputs, refer to the section above.
2. A word on temporal cropping:
```python
val_pipeline = [
    dict(type='DecordInit'),
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, scale_resize)),
    dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
if config.TEST.NUM_CROP == 3:
    val_pipeline[3] = dict(type='Resize', scale=(-1, config.DATA.INPUT_SIZE))
    val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
if config.TEST.NUM_CLIP > 1:
    val_pipeline[1] = dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, multiview=config.TEST.NUM_CLIP)
```
The `multiview=config.TEST.NUM_CLIP` argument clearly controls the number of temporal crops (clips).
3. Spatial cropping
val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
This one is even more intuitive: the frames are cropped three times, hence the 3. (See the quick view-count check below.)
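To make the test-time view bookkeeping concrete (my own illustration with made-up settings, assuming each temporal view samples `NUM_FRAMES` frames): every validation video yields `NUM_CLIP × NUM_CROP` views.

```python
# Hypothetical test-time settings, only to illustrate the multi-view counting.
NUM_FRAMES = 8   # frames per temporal view (config.DATA.NUM_FRAMES)
NUM_CLIP = 4     # temporal views (config.TEST.NUM_CLIP -> 'multiview')
NUM_CROP = 3     # spatial views (config.TEST.NUM_CROP -> 'ThreeCrop')

views_per_video = NUM_CLIP * NUM_CROP              # 12 clip/crop combinations
frames_per_video = views_per_video * NUM_FRAMES    # 96 frames fed to the model per video
print(views_per_video, frames_per_video)
```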
2. Model
ActionCLIP
In terms of inputs:
- Text
classes, num_text_aug, text_dict = text_prompt(train_data)
`classes` has shape (num_text_aug × num_class, context).
`text_dict[j]` has shape (num_class, context).
`num_text_aug` is the number of prompt templates used for the text augmentation.
```python
text_id = numpy.random.randint(num_text_aug, size=len(list_id))
texts = torch.stack([text_dict[j][i, :] for i, j in zip(list_id, text_id)])
```
This turns the text side into a (B, context) tensor: one randomly chosen prompt per sample (see the sketch below).
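A minimal sketch of what a `text_prompt`-style helper does, to explain where those shapes come from (the templates and the function body are illustrative assumptions, not the repository's exact code):

```python
import torch
import clip  # OpenAI CLIP; clip.tokenize pads to a 77-token context


def text_prompt_sketch(class_names):
    """Illustrative only: build tokenized prompts for every class and template."""
    templates = ["a photo of action {}", "a video of a person {}"]  # num_text_aug = 2 here
    num_text_aug = len(templates)
    # One (num_class, context) tensor per template.
    text_dict = {j: torch.cat([clip.tokenize(t.format(c)) for c in class_names])
                 for j, t in enumerate(templates)}
    # All templates stacked: (num_text_aug * num_class, context).
    classes = torch.cat([text_dict[j] for j in range(num_text_aug)])
    return classes, num_text_aug, text_dict
```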
- Images
```python
images = images.view((-1, config.data.num_segments, 3) + images.size()[-2:])
b, t, c, h, w = images.size()
images = images.to(device).view(-1, c, h, w)
```
Strictly speaking, this paper reuses CLIP's image encoder, so the batch and time dimensions are flattened together before encoding.
The outputs are simple as well:
- Text
text_embedding = model_text(texts), with shape (b, d)
- Images
```python
image_embedding = model_image(images)
image_embedding = image_embedding.view(b, t, -1)
image_embedding = fusion_model(image_embedding)
```
The fusion model's output is `x.mean(dim=1, keepdim=False)`, which collapses the t dimension, so x becomes (b, d).
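A minimal sketch of this mean-pooling style of fusion (ActionCLIP also ships other fusion heads; this only illustrates the averaging described above):

```python
import torch
import torch.nn as nn


class MeanPoolFusion(nn.Module):
    """Illustrative fusion head: average per-frame embeddings over the time axis."""
    def forward(self, x):                      # x: (b, t, d) frame embeddings
        return x.mean(dim=1, keepdim=False)    # -> (b, d) video embedding


video_emb = MeanPoolFusion()(torch.randn(2, 8, 512))
print(video_emb.shape)  # torch.Size([2, 512])
```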
X-CLIP
- Text
text_labels = generate_text(train_data)
`text_labels` has shape (num_class (k), 77).
(Same idea as above, but it is not expanded to the number of samples.)
- Images
images = images.view((-1, config.DATA.NUM_FRAMES, 3) + images.size()[-2:])
The model implements a video encoder internally:
```python
def encode_video(self, image):
    b, t, c, h, w = image.size()
    image = image.reshape(-1, c, h, w)
    cls_features, img_features = self.encode_image(image)
    img_features = self.prompts_visual_ln(img_features)
    img_features = img_features @ self.prompts_visual_proj
    cls_features = cls_features.view(b, t, -1)
    img_features = img_features.view(b, t, -1, cls_features.shape[-1])
    video_features = self.mit(cls_features)
    return video_features, img_features
```
`image = image.reshape(-1, c, h, w)`: the flattening of frames into the batch dimension is handled inside the encoder here.
Output:
```python
logit_scale = self.logit_scale.exp()
logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)
return logits
```
This returns a (b, k) matrix of similarity scores: one score per video against each of the k classes.
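A quick standalone check of that einsum's shape behaviour (all dimensions are made up for illustration):

```python
import torch

b, k, d = 2, 400, 512                    # batch, number of classes, embedding dim (illustrative)
video_features = torch.randn(b, d)
text_features = torch.randn(b, k, d)     # per-sample text embeddings, one per class
logits = torch.einsum("bd,bkd->bk", video_features, text_features)
print(logits.shape)  # torch.Size([2, 400]) -> one similarity score per class for each video
```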