欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 文化 > YOLOv6-4.0部分代码阅读笔记-engine.py

YOLOv6-4.0部分代码阅读笔记-engine.py

2025/1/19 5:31:42 来源:https://blog.csdn.net/m0_58169876/article/details/143533945  浏览:    关键词:YOLOv6-4.0部分代码阅读笔记-engine.py

engine.py

yolov6\core\engine.py

目录

engine.py

1.所需的库和模块

2.class Trainer: 


1.所需的库和模块

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# 主要负责模型的训练过程。
from ast import Pass
import os
import time
from copy import deepcopy
import os.path as ospfrom tqdm import tqdmimport cv2
import numpy as np
import math
import torch
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriterimport tools.eval as eval
from yolov6.data.data_load import create_dataloader
from yolov6.models.yolo import build_model
from yolov6.models.yolo_lite import build_model as build_lite_modelfrom yolov6.models.losses.loss import ComputeLoss as ComputeLoss
from yolov6.models.losses.loss_fuseab import ComputeLoss as ComputeLoss_ab
from yolov6.models.losses.loss_distill import ComputeLoss as ComputeLoss_distill
from yolov6.models.losses.loss_distill_ns import ComputeLoss as ComputeLoss_distill_nsfrom yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog, write_tbimg
from yolov6.utils.ema import ModelEMA, de_parallel
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
from yolov6.solver.build import build_optimizer, build_lr_scheduler
from yolov6.utils.RepOptimizer import extract_scales, RepVGGOptimizer
from yolov6.utils.nms import xywh2xyxy
from yolov6.utils.general import download_ckpt

2.class Trainer: 

class Trainer:# 1.args :这是一个包含命令行参数的对象,可能包括训练配置、数据路径、模型参数等。# 2.cfg :这是一个配置对象,通常从配置文件加载,包含模型结构、训练策略、优化器设置等详细信息。# 3.device :这是指定的训练设备,可以是CPU或GPU,用于指示模型和数据应该在哪个设备上进行训练。def __init__(self, args, cfg, device):self.args = argsself.cfg = cfgself.device = device# 检查 args 中是否有 resume 属性,这个属性通常用于指示是否从之前训练的检查点(checkpoint)恢复训练。if args.resume:# torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)# torch.load 是 PyTorch 中用于加载保存的 PyTorch 对象的函数。这个函数可以加载之前使用 torch.save 保存的文件,这些文件可以包含模型参数、优化器状态、张量等。# 参数说明 :# 1. f :要加载的文件名或文件对象。可以是一个字符串路径,也可以是一个文件对象。# 2. map_location :一个字符串或一个函数,指定如何映射存储位置。默认为  None  ,意味着使用默认存储位置。如果设置为 'cpu' ,那么所有张量都会映射到CPU上。# 如果设置为 'cuda:device_id' ,那么张量会被映射到指定的GPU上。也可以是一个函数,该函数接受一个存储位置字符串并返回一个新的存储位置字符串。# 3. pickle_module :用于序列化和反序列化的模块。默认为 Python 标准库中的 pickle 。# 4. **pickle_load_args :传递给 pickle.load 函数的其他参数。# 返回值 :# 加载的 PyTorch 对象,这可以是张量、模型、优化器状态等。# 如果 args.resume 为真,使用PyTorch的 torch.load 函数加载检查点文件。 map_location='cpu' 参数确保检查点文件被加载到CPU上,无论它原来保存在哪个设备上。self.ckpt = torch.load(args.resume, map_location='cpu')# 获取 args 中的 rank 属性,并赋值给 self.rank 。 rank 通常用于分布式训练中标识每个进程的编号。self.rank = args.rank# 获取 args 中的 local_rank 属性,并赋值给 self.local_rank 。 local_rank 标识本地进程的编号,用于GPU上的分布式训练。self.local_rank = args.local_rank# 获取 args 中的 world_size 属性,并赋值给 self.world_size 。 world_size 表示分布式训练中总的进程数。self.world_size = args.world_size# 设置 self.main_process 为布尔值,表示当前进程是否为主进程。在分布式训练中,通常只有一个主进程负责保存模型和输出日志等任务。self.main_process = self.rank in [-1, 0]# 获取 args 中的 save_dir 属性,并赋值给 self.save_dir 。 save_dir 指定了模型检查点和日志文件的保存目录。self.save_dir = args.save_dir# get data loader    获取数据加载器。# def load_yaml(file_path): -> 从 yaml 文件加载数据。函数返回从 YAML 文件中加载的数据,通常是一个字典或列表,具体取决于 YAML 文件的内容。 -> return data_dict# 这行代码调用了一个名为 load_yaml 的函数,用于加载YAML格式的数据集配置文件。 args.data_path 提供了配置文件的路径。加载后的数据被存储在 self.data_dict 中,这是一个字典,包含了数据集的所有配置信息。self.data_dict = load_yaml(args.data_path)# 从 self.data_dict 字典中获取 'nc' 键对应的值,这个值代表数据集中的类别数量( num_classes ),并将其赋值给类的实例变量 self.num_classes 。self.num_classes = self.data_dict['nc']# self.get_data_loader 方法来创建训练和验证数据的加载器。这个方法接受三个参数: args (命令行参数), cfg (配置信息),和 self.data_dict (数据集配置字典)。# 方法的返回值是两个数据加载器,分别用于训练和验证,它们被赋值给 self.train_loader 和 self.val_loader 。# def get_data_loader(args, cfg, data_dict): -> 它用于创建并返回训练和验证数据加载器( DataLoader )。 -> return train_loader, val_loaderself.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)# get model and optimizer    获取模型和优化器。# 这行代码检查是否启用了蒸馏( distill )以及模型类型是否为 YOLOv6n 或 YOLOv6s 。如果是,则设置 self.distill_ns 为 True ,表示将使用蒸馏训练小型网络。self.distill_ns = True if self.args.distill and self.cfg.model.type in ['YOLOv6n','YOLOv6s'] else False# 调用 self.get_model 方法来创建模型实例。这个方法接受命令行参数、配置、类别数量和设备作为输入,并返回一个模型实例。# def get_model(self, args, cfg, nc, device): -> 用于根据提供的参数构建和配置YOLOv6模型。返回构建和配置好的模型实例。 -> return modelmodel = self.get_model(args, cfg, self.num_classes, device)# 如果启用了蒸馏,执行以下操作。if self.args.distill:# 检查是否启用了 fuse_ab 。如果启用了蒸馏并且同时启用了 fuse_ab ,则记录错误并退出程序,因为蒸馏模型不应该启用 fuse_ab 。if self.args.fuse_ab:LOGGER.error('ERROR in: Distill models should turn off the fuse_ab.\n')    # 错误:蒸馏模型应该关闭fuse_ab。exit()# 如果启用了蒸馏,调用 self.get_teacher_model 方法来创建教师模型实例。# def get_teacher_model(self, args, cfg, nc, device): # -> 它用于创建和配置一个教师模型,这在知识蒸馏(Knowledge Distillation)的场景中是常见的。返回构建和配置好的教师模型实例。-> return modelself.teacher_model = self.get_teacher_model(args, cfg, self.num_classes, device)# 如果启用了量化,调用 self.quant_setup 方法来设置模型的量化。if self.args.quant:self.quant_setup(model, cfg, device)# 如果训练模式是 repopt (Re-Optimization),则执行以下操作。if cfg.training_mode == 'repopt':# 从预训练模型加载量化缩放因子。scales = self.load_scale_from_pretrained_models(cfg, device)# 如果配置中指定了预训练模型,则不重新初始化模型参数。reinit = False if cfg.model.pretrained is not None else True# 创建一个 RepVGGOptimizer 优化器实例,用于 repopt 训练模式。self.optimizer = RepVGGOptimizer(model, scales, args, cfg, reinit=reinit)# 如果不是 repopt 训练模式,执行以下操作。else:# 调用 self.get_optimizer 方法来创建一个优化器实例。self.optimizer = self.get_optimizer(args, cfg, model)# 调用 self.get_lr_scheduler 方法来创建学习率调度器和学习率函数。self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)# 如果当前进程是主进程,则创建一个模型指数移动平均(EMA)实例。EMA用于平滑模型参数,提高训练稳定性。self.ema = ModelEMA(model) if self.main_process else None# tensorboard# 这行代码检查当前进程是否为主进程( self.main_process )。如果是主进程,则创建一个TensorBoard的 SummaryWriter 实例,用于记录训练过程中的日志信息,这些信息可以用于TensorBoard可视化。# SummaryWriter 是TensorBoard的一个API,它允许用户将训练过程中的数据(如损失、准确率等)写入日志文件,以便后续使用TensorBoard进行可视化分析。# self.save_dir 是保存TensorBoard日志文件的目录。# 如果当前进程不是主进程,则 self.tblogger 被设置为 None ,这意味着非主进程不会记录TensorBoard日志。self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None# 这行代码设置训练的起始周期(epoch)为0。在深度学习训练中,一个周期指的是对整个数据集进行一次前向和反向传播的过程。这个变量通常用于记录从哪个周期开始训练,特别是在从检查点(checkpoint)恢复训练时。self.start_epoch = 0#resume# 这段代码处理从检查点(checkpoint)恢复模型、优化器、调度器和指数移动平均(EMA)的状态,并更新训练的起始周期。# 检查当前实例是否有 ckpt 属性。 ckpt 通常是一个从文件加载的检查点字典,包含了模型参数、优化器状态、训练周期等信息。if hasattr(self, "ckpt"):# 从检查点中提取模型状态,并确保所有参数都是32位浮点数(FP32),然后获取模型的状态字典。resume_state_dict = self.ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32    检查点 state_dict 作为 FP32。# 将检查点中的状态字典加载到模型中, strict=True 确保状态字典中的所有键都必须与模型中的参数名称匹配。model.load_state_dict(resume_state_dict, strict=True)  # load# 从检查点中提取训练周期,并将其设置为恢复训练的起始周期,即上一周期的下一个周期。self.start_epoch = self.ckpt['epoch'] + 1# 从检查点中加载优化器的状态。self.optimizer.load_state_dict(self.ckpt['optimizer'])# 从检查点中加载调度器的状态。self.scheduler.load_state_dict(self.ckpt['scheduler'])# 检查当前是否为主进程。if self.main_process:# 为主进程加载EMA的状态字典。self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict())# 更新EMA的更新次数。self.ema.updates = self.ckpt['updates']# 从数据字典中提取类别数量和类别名称,并更新到模型的属性中。# def parallel_model(args, model, device): -> 用于根据提供的参数配置模型以在不同的并行计算模式下运行。返回包装后的模型,无论是DP还是DDP模式。 -> return modelself.model = self.parallel_model(args, model, device)# 从数据字典中提取类别数量和类别名称,并更新到模型的属性中。self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']# 将命令行参数中的训练周期数( epochs )赋值给 self.max_epoch ,表示训练将要进行的最大周期数。self.max_epoch = args.epochs# 计算训练数据加载器( train_loader )的长度,即总的批次数,并赋值给 self.max_stepnum ,这通常用于确定训练的总步数。self.max_stepnum = len(self.train_loader)# 将命令行参数中的批量大小( batch_size )赋值给 self.batch_size 。self.batch_size = args.batch_size# 将命令行参数中的图像尺寸( img_size )赋值给 self.img_size 。self.img_size = args.img_size# 初始化一个空列表,用于存储用于可视化的图像。self.vis_imgs_list = []# 将命令行参数中的 write_trainbatch_tb 赋值给 self.write_trainbatch_tb ,表示是否在TensorBoard中记录训练批次的信息。self.write_trainbatch_tb = args.write_trainbatch_tb# 设置类名的颜色。# set color for classnames# 为每个类别随机生成一个颜色,并存储在 self.color 列表中。这里使用 np.random.choice 从0到255中随机选择三个数字,代表RGB颜色值, self.model.nc 是类别的数量。self.color = [tuple(np.random.choice(range(256), size=3)) for _ in range(self.model.nc)]# 初始化损失数量为3,对应于 'Epoch' 、 'iou_loss' 和 'dfl_loss' 。self.loss_num = 3# 初始化损失信息列表,包含周期、IoU损失、dfl损失和分类损失。self.loss_info = ['Epoch', 'iou_loss', 'dfl_loss', 'cls_loss']# 如果启用了蒸馏(  self.args.distill  )。if self.args.distill:# # 蒸馏会额外增加一个损失,因此损失数量加1。self.loss_num += 1# 在损失信息列表中添加 'cwd_loss' ,表示蒸馏过程中的类权重蒸馏损失。self.loss_info += ['cwd_loss']# 训练过程。# Training Processdef train(self):# 开始一个异常处理块,用于捕获训练过程中可能发生的任何异常。try:# 在训练循环开始之前执行的操作,可能是一些初始化工作,如准备数据加载器、初始化优化器等。# def train_before_loop(self): -> 它在训练循环开始之前执行一系列初始化操作。self.train_before_loop()# 从 self.start_epoch 开始,遍历到 self.max_epoch 结束, self.epoch 会在每个周期中被更新。for self.epoch in range(self.start_epoch, self.max_epoch):# 对每个周期执行的训练操作,这个方法可能包含每个周期的前向传播、损失计算、反向传播和优化器步骤。self.train_in_loop(self.epoch)# 在所有周期完成后,执行模型压缩或者简化操作,以去除不必要的参数,减少模型大小。self.strip_model()# 如果训练循环或模型保存过程中出现任何异常,捕获该异常并记录错误信息。except Exception as _:# 使用日志记录器记录错误信息。LOGGER.error('ERROR in training loop or eval/save model.')# 重新抛出捕获的异常,以便外部可以进一步处理。raise# 无论是否发生异常,都会执行的代码块。finally:# 在训练循环结束后执行的操作,可能是一些清理工作,如释放资源、保存最终模型等。self.train_after_loop()# Training loop for each epoch    每个时期的训练循环。# 它负责在一个训练周期内执行训练步骤,并在周期结束后进行评估和模型保存。# 1.epoch_num :当前训练周期的编号。def train_in_loop(self, epoch_num):try:# 在开始训练步骤之前执行的准备工作,可能包括数据加载器的初始化、设备设置等。# def prepare_for_steps(self): -> 它在每个训练周期开始之前执行一系列准备工作。self.prepare_for_steps()# 遍历进度条 self.pbar ,它是一个迭代器,提供了当前步数 self.step 和批次数据 self.batch_data 。for self.step, self.batch_data in self.pbar:# 对每个批次执行训练步骤,包括前向传播、损失计算、反向传播和优化器更新。# def train_in_steps(self, epoch_num, step_num): -> 它负责执行单个训练步骤中的操作。self.train_in_steps(epoch_num, self.step)# 打印当前步骤的训练细节,如损失值、精度等。# def print_details(self): -> 打印每个步骤后的损失。更新进度条的描述信息,显示当前周期和平均损失值。self.print_details()# 如果训练步骤中发生任何异常,记录错误信息并重新抛出异常。except Exception as _:LOGGER.error('ERROR in training steps.')    # 训练步骤中出现问题。raise# 开始另一个异常处理块,用于捕获评估和保存模型过程中可能发生的任何异常。try:# 在周期结束后执行模型评估,并根据评估结果决定是否保存模型。self.eval_and_save()except Exception as _:# 如果评估和保存模型过程中发生任何异常,记录错误信息并重新抛出异常。LOGGER.error('ERROR in evaluate and save model.')    # 评估和保存模型时出现错误。raise# 批量数据的训练循环。# Training loop for batchdata# 它负责执行单个训练步骤中的操作。# 1.epoch_num :当前训练周期的编号。# 2.step_num :当前步骤的编号。def train_in_steps(self, epoch_num, step_num):# 对批次数据进行预处理,将数据移动到指定的设备(如GPU),并将其转换为模型训练所需的格式。 self.prepro_data 方法可能包括数据增强、归一化等操作。# def prepro_data(batch_data, device): -> 它用于将输入的批次数据( batch_data )进行预处理,并将其移动到指定的设备(如GPU)。返回预处理后的图像和目标数据。 -> return images, targetsimages, targets = self.prepro_data(self.batch_data, self.device)# plot train_batch and save to tensorboard once an epoch    绘制 train_batch 并每 epoch 保存一次到 tensorboard。# 检查是否需要将训练批次写入TensorBoard,当前进程是否为主进程,以及当前步骤是否为周期的第一步。if self.write_trainbatch_tb and self.main_process and self.step == 0:# 如果条件满足,调用 self.plot_train_batch 方法绘制当前批次的图像和目标,以便可视化。# def plot_train_batch(self, images, targets, max_size=1920, max_subplots=16): -> 它用于将训练批次中的图像和对应的目标(标签)绘制成一幅马赛克图(mosaic),并保存用于可视化。self.plot_train_batch(images, targets)# 使用 write_tbimg 函数将绘制的训练批次图像写入TensorBoard。# 1.self.tblogger :是TensorBoard的 SummaryWriter 实例。# 2.self.vis_train_batch :是要写入的图像。# 3.self.step + self.max_stepnum * self.epoch :计算当前周期和步骤的全局步数。# 4.type='train' :指定图像类型为训练。# def write_tbimg(tblogger, imgs, step, type='train'): -> 将图像数据展示在 TensorBoard 中,用于可视化训练批次和验证过程中的预测结果。write_tbimg(self.tblogger, self.vis_train_batch, self.step + self.max_stepnum * self.epoch, type='train')# 模型的前向传播。# forward# 使用PyTorch的自动混合精度(AMP)上下文管理器来优化模型的前向传播。如果设备不是CPU(即在GPU上训练),则启用AMP以提高性能和减少内存使用。with amp.autocast(enabled=self.device != 'cpu'):# 将图像数据传递给模型进行前向传播,获取模型的预测结果 preds 和特征图 s_featmaps 。preds, s_featmaps = self.model(images)# 如果启用了蒸馏( self.args.distill )。if self.args.distill:# 使用PyTorch的 torch.no_grad() 上下文管理器来禁用梯度计算,这对于教师模型的前向传播是必要的,以避免不必要的计算和内存使用。with torch.no_grad():# 将图像数据传递给教师模型进行前向传播,获取教师模型的预测结果 t_preds 和特征图 t_featmaps 。t_preds, t_featmaps = self.teacher_model(images)# 从命令行参数中获取温度参数,用于调整蒸馏损失的尺度。temperature = self.args.temperature# 计算蒸馏损失,包括学生模型和教师模型的预测结果和特征图,以及目标数据。total_loss, loss_items = self.compute_loss_distill(preds, t_preds, s_featmaps, t_featmaps, targets, \epoch_num, self.max_epoch, temperature, step_num)# 如果启用了AB融合( self.args.fuse_ab )elif self.args.fuse_ab:      # 计算YOLOv6-af部分的损失。 total_loss, loss_items = self.compute_loss((preds[0],preds[3],preds[4]), targets, epoch_num, step_num) # YOLOv6_af# 计算YOLOv6-ab部分的损失。total_loss_ab, loss_items_ab = self.compute_loss_ab(preds[:3], targets, epoch_num, step_num) # YOLOv6_ab# 将两部分的损失相加,得到总损失。total_loss += total_loss_ab# 将两部分的损失项合并。loss_items += loss_items_abelse:# 否则,如果既没有启用蒸馏也没有启用AB融合:total_loss, loss_items = self.compute_loss(preds, targets, epoch_num, step_num) # YOLOv6_af# 如果不是单GPU训练( self.rank != -1 )。if self.rank != -1:# 将总损失乘以world_size(即分布式训练中的进程数),这通常用于在分布式训练中平均损失值。total_loss *= self.world_size# 负责执行反向传播和更新优化器。# backward# 使用自动混合精度(AMP)的梯度缩放器 self.scaler 来缩放总损失 total_loss ,然后调用 backward() 方法执行反向传播。这一步计算损失相对于模型参数的梯度。self.scaler.scale(total_loss).backward()# 将计算得到的损失项 loss_items 保存到实例变量 self.loss_items 中。这些损失项可以用于后续的日志记录、可视化或其他分析。self.loss_items = loss_items# 调用 update_optimizer 方法来更新优化器的状态。这通常包括调用优化器的 step() 方法来更新模型参数,并在需要时执行梯度裁剪或其他自定义的优化步骤。# def update_optimizer(self): -> 负责在训练过程中更新优化器的状态,包括学习率的调整、执行优化步骤和梯度的清零。self.update_optimizer()# 负责决定是否进行模型评估,并在适当的时候保存模型。def eval_and_save(self):# 计算从当前周期到最大周期之前还剩下多少个训练周期。这里 -1 是因为 self.epoch 是从0开始的。remaining_epochs = self.max_epoch - 1 - self.epoch # self.epoch is start from 0# 根据剩余周期数和配置参数 heavy_eval_range 确定评估间隔。如果在 heavy_eval_range 范围内,则使用配置的评估间隔 eval_interval ,否则使用一个较小的间隔(这里是3)。eval_interval = self.args.eval_interval if remaining_epochs >= self.args.heavy_eval_range else 3# 判断当前周期是否应该进行评估。如果只剩下最后一个周期,或者不是只评估最终周期且当前周期是评估间隔的倍数,则需要进行评估。is_val_epoch = (remaining_epochs == 0) or ((not self.args.eval_final_only) and ((self.epoch + 1) % eval_interval == 0))# 只有主进程执行评估和保存模型的操作,以避免在分布式训练中重复执行这些操作。if self.main_process:# 更新指数移动平均(EMA)模型的属性,确保EMA模型的这些属性与主模型保持一致。self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model# 如果当前周期是评估周期( is_val_epoch 为 True )。if is_val_epoch:# 调用 eval_model 方法对模型进行评估。self.eval_model()# 从评估结果中获取当前的平均精度(AP)值,并更新实例变量 self.ap 。self.ap = self.evaluate_results[1]# 比较当前的AP值和之前保存的最佳AP值,如果当前AP值更高,则更新最佳AP值。self.best_ap = max(self.ap, self.best_ap)# 保存检查点。# save ckpt# ckpt = {...} :创建一个字典 ckpt ,用于保存当前训练状态。ckpt = {# 保存模型的状态字典, deepcopy 用于创建模型参数的深拷贝, de_parallel 用于去除模型的并行包装(如 DataParallel 或 DistributedDataParallel ), .half() 将模型参数转换为半精度(FP16)。'model': deepcopy(de_parallel(self.model)).half(),# 保存EMA模型的状态字典,并转换为半精度。'ema': deepcopy(self.ema.ema).half(),# 保存EMA的更新次数。'updates': self.ema.updates,# 保存优化器的状态字典。'optimizer': self.optimizer.state_dict(),# 保存学习率调度器的状态字典。'scheduler': self.scheduler.state_dict(),# 保存当前的周期数。'epoch': self.epoch,# 保存评估结果。'results': self.evaluate_results,}# 使用 osp.join ( os.path.join 的简写)函数将保存根目录 self.save_dir 与子目录 'weights' 连接,形成检查点文件的保存目录。save_ckpt_dir = osp.join(self.save_dir, 'weights')# 调用 save_checkpoint 函数保存检查点。# ckpt :要保存的检查点字典,包含模型参数、优化器状态等。# (is_val_epoch) and (self.ap == self.best_ap) :一个布尔值,指示当前周期是否是评估周期且当前的平均精度(AP)是否等于最佳AP,如果是,则只保存这个检查点。# save_ckpt_dir :检查点的保存目录。# model_name='last_ckpt' :检查点的文件名前缀。# def save_checkpoint(ckpt, is_best, save_dir, model_name=""): -> 将检查点保存到磁盘。使用 shutil.copyfile 函数将当前检查点文件复制为 best_ckpt.pt ,这样可以保留最佳模型的副本。 -> shutil.copyfile(filename, best_filename)save_checkpoint(ckpt, (is_val_epoch) and (self.ap == self.best_ap), save_ckpt_dir, model_name='last_ckpt')# 检查当前周期是否在训练的最后 self.args.save_ckpt_on_last_n_epoch 个周期内。if self.epoch >= self.max_epoch - self.args.save_ckpt_on_last_n_epoch:# 如果是,则调用 save_checkpoint 函数保存检查点,使用当前周期数作为文件名前缀。save_checkpoint(ckpt, False, save_ckpt_dir, model_name=f'{self.epoch}_ckpt')# 默认保存最佳 ap ckpt 以停止强 aug 时期#default save best ap ckpt in stop strong aug epochs# 检查当前周期是否在训练的最后 self.args.stop_aug_last_n_epoch 个周期内。 self.max_epoch 是总的训练周期数, self.args.stop_aug_last_n_epoch 是在训练结束前需要停止增强的周期数。if self.epoch >= self.max_epoch - self.args.stop_aug_last_n_epoch:# 如果当前周期的AP值 self.ap 大于之前记录的最佳AP值 self.best_stop_strong_aug_ap ,则执行以下操作。if self.best_stop_strong_aug_ap < self.ap:# 更新 self.best_stop_strong_aug_ap 为当前周期的AP值和之前记录的最佳AP值中的较大者。self.best_stop_strong_aug_ap = max(self.ap, self.best_stop_strong_aug_ap)# 调用 save_checkpoint 函数保存当前周期的模型检查点。# ckpt :要保存的检查点字典,包含模型参数、优化器状态等。# False :这个参数通常用于指示是否仅在达到更好的评估结果时才保存检查点。在这里,我们忽略这个条件,始终保存检查点。# save_ckpt_dir :检查点的保存目录。# model_name='best_stop_aug_ckpt' :检查点的文件名前缀,标识这是一个在停止增强后达到最佳AP值的检查点。save_checkpoint(ckpt, False, save_ckpt_dir, model_name='best_stop_aug_ckpt')# 删除 ckpt 变量以释放内存。这通常在检查点已经保存到磁盘后执行,以避免不必要的内存占用。del ckpt# 记录学习率。# log for learning rate# 从优化器的参数组中提取当前的学习率,并将其存储在列表 lr 中。lr = [x['lr'] for x in self.optimizer.param_groups]# 将学习率列表 lr 追加到评估结果列表 self.evaluate_results 中。这样,评估结果就包含了模型性能指标和学习率信息。self.evaluate_results = list(self.evaluate_results) + lr# tensorboard 日志。# log for tensorboard# 调用 write_tblog 函数将当前周期 self.epoch 、评估结果 self.evaluate_results 和平均损失 self.mean_loss 写入TensorBoard。这有助于在训练过程中可视化性能指标和损失变化。write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)# 将验证预测保存到 tensorboard。# save validation predictions to tensorboard# 调用 write_tbimg 函数将验证阶段的预测图像 self.vis_imgs_list 保存到TensorBoard。这些图像通常展示了模型在验证集上的预测结果,有助于直观地评估模型性能。write_tbimg(self.tblogger, self.vis_imgs_list, self.epoch, type='val')# 它负责执行模型的评估过程。def eval_model(self):# 检查配置对象 self.cfg 是否包含 eval_params 属性。如果没有,说明评估参数未被特别指定,将使用默认的评估流程。if not hasattr(self.cfg, "eval_params"):# 调用 eval.run 函数执行评估。# 1.elf.data_dict :包含数据集信息的字典。# 2.batch_size=self.batch_size // self.world_size * 2 :设置评估时的批量大小,通常是训练时批量大小的两倍,但考虑到分布式训练中的 world_size 。# 3.img_size=self.img_size :设置图像的尺寸。# 4.model=self.ema.ema if self.args.calib is False else self.model :选择使用EMA模型或原始模型进行评估。如果不在校准( calib )阶段,则使用EMA模型;否则使用原始模型。# 5.conf_thres=0.03 :设置置信度阈值,用于过滤低置信度的检测结果。# 6.dataloader=self.val_loader :使用验证数据加载器。# 7.save_dir=self.save_dir :设置保存评估结果的目录。# 8.task='train' :设置任务类型,这里为 'train' ,可能用于区分训练和测试阶段的评估。# 函数返回评估结果 results ,可视化输出 vis_outputs 和可视化路径 vis_paths 。results, vis_outputs, vis_paths = eval.run(self.data_dict,batch_size=self.batch_size // self.world_size * 2,img_size=self.img_size,model=self.ema.ema if self.args.calib is False else self.model,conf_thres=0.03,dataloader=self.val_loader,save_dir=self.save_dir,task='train')else:# 它用于从配置字典 cfg_dict 中获取指定键 value_str 对应的值。如果该键存在,则返回其值;如果不存在,则返回默认值 default_value 。# 1.cfg_dict :配置信息的字典。# 2.value_str :要获取值的键的字符串表示。# 3.default_value :如果键不存在或键对应的值为 None ,则返回的默认值。def get_cfg_value(cfg_dict, value_str, default_value):# 检查 value_str 是否作为键存在于 cfg_dict 中。if value_str in cfg_dict:# 如果键存在,进一步检查该键对应的值的类型:if isinstance(cfg_dict[value_str], list):# 如果值是列表类型,进一步检查列表中的第一个元素是否为 None 。# 如果列表的第一个元素不是 None ,则返回该元素;如果是 None ,则返回 default_value 。return cfg_dict[value_str][0] if cfg_dict[value_str][0] is not None else default_value# 否则,如果值不是列表类型。else:# 如果值不是 None ,则直接返回该值;如果是 None ,则返回 default_value 。return cfg_dict[value_str] if cfg_dict[value_str] is not None else default_valueelse:# 如果键不存在于字典中,则返回 default_value  。return default_value# 使用 get_cfg_value 函数从评估参数中获取图像尺寸 img_size ,如果未指定,则使用实例变量 self.img_size 作为默认值。eval_img_size = get_cfg_value(self.cfg.eval_params, "img_size", self.img_size)# 调用 eval.run 函数执行模型评估,传递以下参数 。# 1.self.data_dict :包含数据集信息的字典。# 2.batch_size :评估时的批量大小,从评估参数中获取,如果未指定,则使用 self.batch_size // self.world_size * 2 作为默认值。# 3.img_size :评估时的图像尺寸,使用 eval_img_size 。# 4.model :选择使用EMA模型或原始模型进行评估。# 5.conf_thres :置信度阈值,从评估参数中获取,如果未指定,则使用0.03作为默认值。# 6.dataloader :验证数据加载器。# 7.save_dir :保存评估结果的目录。# 8.task :任务类型,这里为 'train' 。# 9.shrink_size :图像缩小尺寸,从评估参数中获取,如果未指定,则使用 eval_img_size 作为默认值。# 10.infer_on_rect :是否在矩形区域进行推理,从评估参数中获取,如果未指定,则使用 False 作为默认值。# 11.verbose :是否输出详细信息,从评估参数中获取,如果未指定,则使用 False 作为默认值。# 12.do_coco_metric :是否计算COCO评估指标,从评估参数中获取,如果未指定,则使用 True 作为默认值。# 13.do_pr_metric :是否计算PR曲线评估指标,从评估参数中获取,如果未指定,则使用 False 作为默认值。# 14.plot_curve :是否绘制评估曲线,从评估参数中获取,如果未指定,则使用 False 作为默认值。# 15.plot_confusion_matrix :是否绘制混淆矩阵,从评估参数中获取,如果未指定,则使用 False 作为默认值。results, vis_outputs, vis_paths = eval.run(self.data_dict,batch_size=get_cfg_value(self.cfg.eval_params, "batch_size", self.batch_size // self.world_size * 2),img_size=eval_img_size,model=self.ema.ema if self.args.calib is False else self.model,conf_thres=get_cfg_value(self.cfg.eval_params, "conf_thres", 0.03),dataloader=self.val_loader,save_dir=self.save_dir,task='train',shrink_size=get_cfg_value(self.cfg.eval_params, "shrink_size", eval_img_size),infer_on_rect=get_cfg_value(self.cfg.eval_params, "infer_on_rect", False),verbose=get_cfg_value(self.cfg.eval_params, "verbose", False),do_coco_metric=get_cfg_value(self.cfg.eval_params, "do_coco_metric", True),do_pr_metric=get_cfg_value(self.cfg.eval_params, "do_pr_metric", False),plot_curve=get_cfg_value(self.cfg.eval_params, "plot_curve", False),plot_confusion_matrix=get_cfg_value(self.cfg.eval_params, "plot_confusion_matrix", False),)# 使用日志记录器记录当前周期和评估结果,包括 mAP@0.5 和 mAP@0.50:0.95。LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")# 更新实例变量 self.evaluate_results ,存储前两个评估结果。self.evaluate_results = results[:2]# plot validation predictions# 调用 plot_val_pred 方法绘制验证预测结果。# def plot_val_pred(self, vis_outputs, vis_paths, vis_conf=0.3, vis_max_box_num=5): -> 它用于绘制验证阶段的预测结果,并将绘制好的图像保存到一个列表中。self.plot_val_pred(vis_outputs, vis_paths)# 它在训练循环开始之前执行一系列初始化操作。def train_before_loop(self):# 训练开始...。LOGGER.info('Training start...')# 记录训练开始的时间戳。self.start_time = time.time()# 根据配置中的预热周期数和最大步数计算预热步数。如果启用了量化,则预热步数设置为0。self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000) if self.args.quant is False else 0# 设置学习率调度器的初始周期,通常用于从检查点恢复时保持学习率的正确状态。self.scheduler.last_epoch = self.start_epoch - 1# 初始化最后一个优化步骤的计数器。self.last_opt_step = -1# 初始化自动混合精度(AMP)的梯度缩放器,用于FP16训练。self.scaler = amp.GradScaler(enabled=self.device != 'cpu')# 初始化最佳平均精度(AP)和当前AP值。self.best_ap, self.ap = 0.0, 0.0# 初始化最佳停止增强平均精度。self.best_stop_strong_aug_ap = 0.0# 初始化评估结果,通常包含AP50和AP50_95。self.evaluate_results = (0, 0) # AP50, AP50_95# resume results    恢复结果。#  如果存在检查点( if hasattr(self, "ckpt") ),则从检查点中恢复评估结果和最佳AP值。# 这行代码检查当前类的实例( self )是否有一个名为 ckpt 的属性。 ckpt 通常是一个字典,包含了之前训练过程中保存的状态信息,比如模型参数、优化器状态、评估结果等。if hasattr(self, "ckpt"):# 如果存在检查点,从检查点中获取 results 键对应的值,并将其赋值给 self.evaluate_results 。这里的 results 可能是一个包含多个评估指标的元组或列表。self.evaluate_results = self.ckpt['results']# 从评估结果中取出第二个元素(索引为1),并将其赋值给 self.best_ap 。这个值通常代表最佳平均精度(AP)。self.best_ap = self.evaluate_results[1]# 同样地,将第二个元素的值也赋值给 self.best_stop_strong_aug_ap 。这个值可能代表在停止训练时的最佳增强平均精度。self.best_stop_strong_aug_ap = self.evaluate_results[1]# 根据配置和数据字典初始化损失计算模块。self.compute_loss = ComputeLoss(num_classes=self.data_dict['nc'],ori_img_size=self.img_size,warmup_epoch=self.cfg.model.head.atss_warmup_epoch,use_dfl=self.cfg.model.head.use_dfl,reg_max=self.cfg.model.head.reg_max,iou_type=self.cfg.model.head.iou_type,fpn_strides=self.cfg.model.head.strides)if self.args.fuse_ab:# 如果启用了AB融合,则初始化AB融合损失计算模块。self.compute_loss_ab = ComputeLoss_ab(num_classes=self.data_dict['nc'],ori_img_size=self.img_size,warmup_epoch=0,use_dfl=False,reg_max=0,iou_type=self.cfg.model.head.iou_type,fpn_strides=self.cfg.model.head.strides)if self.args.distill :# 如果启用了蒸馏,则根据模型类型初始化蒸馏损失计算模块。if self.cfg.model.type in ['YOLOv6n','YOLOv6s']:Loss_distill_func = ComputeLoss_distill_nselse:Loss_distill_func = ComputeLoss_distillself.compute_loss_distill = Loss_distill_func(num_classes=self.data_dict['nc'],ori_img_size=self.img_size,fpn_strides=self.cfg.model.head.strides,warmup_epoch=self.cfg.model.head.atss_warmup_epoch,use_dfl=self.cfg.model.head.use_dfl,reg_max=self.cfg.model.head.reg_max,iou_type=self.cfg.model.head.iou_type,distill_weight = self.cfg.model.head.distill_weight,distill_feat = self.args.distill_feat,)# 它在每个训练周期开始之前执行一系列准备工作。def prepare_for_steps(self):# 如果当前周期 self.epoch 大于起始周期 self.start_epoch ,则调用 self.scheduler.step() 更新学习率。if self.epoch > self.start_epoch:self.scheduler.step()# 如果当前周期等于起始周期且存在检查点( self.ckpt ),则为优化器的每个参数组设置学习率,这是从检查点恢复时必要的步骤。elif  hasattr(self, "ckpt") and self.epoch == self.start_epoch: # resume first epoch, load lr    恢复第一个 epoch,加载 lr。for k, param in enumerate(self.optimizer.param_groups):param['lr'] = self.scheduler.get_lr()[k]#stop strong aug like mosaic and mixup from last n epoch by recreate dataloader    通过重新创建数据加载器来停止上个 n 个时期的马赛克和混合等强增强。# 如果当前周期是最后一个周期减去配置中指定的在最后几个周期停止增强的周期数 self.args.stop_aug_last_n_epoch ,则将数据增强的概率设置为0,并重新创建数据加载器以应用这些更改。if self.epoch == self.max_epoch - self.args.stop_aug_last_n_epoch:self.cfg.data_aug.mosaic = 0.0self.cfg.data_aug.mixup = 0.0# def get_data_loader(args, cfg, data_dict): -> 它用于创建并返回训练和验证数据加载器( DataLoader )。这个方法不依赖于类的实例,因此被定义为静态方法。# -> 这个方法返回训练和验证数据加载器,但在这段代码中,验证数据加载器只在主进程中创建,因此可能返回 None 。 -> return train_loader, val_loaderself.train_loader, self.val_loader = self.get_data_loader(self.args, self.cfg, self.data_dict)# 将模型设置为训练模式,这对于某些特定层(如Dropout和BatchNorm)的行为是必要的。self.model.train()# 如果不是单GPU训练( self.rank != -1 ),则为训练数据加载器的采样器设置当前周期,这对于确保数据在多个GPU间正确分配是必要的。if self.rank != -1:self.train_loader.sampler.set_epoch(self.epoch)# 初始化一个张量来存储平均损失值。self.mean_loss = torch.zeros(self.loss_num, device=self.device)# 清空(重置)优化器的梯度,这是每次训练步骤之前必须的操作。self.optimizer.zero_grad()# 使用日志记录器记录损失信息的标题。LOGGER.info(('\n' + '%10s' * (self.loss_num + 1)) % (*self.loss_info,))# 初始化一个进度条,用于遍历训练数据加载器。self.pbar = enumerate(self.train_loader)if self.main_process:# 如果是主进程( self.main_process ),则使用 tqdm 库来创建一个可视化的进度条,并设置总步数、列宽和进度条格式。self.pbar = tqdm(self.pbar, total=self.max_stepnum, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')# 打印每个步骤后的损失。# Print loss after each stepsdef print_details(self):# 这个条件检查当前进程是否为主进程。在分布式训练环境中,只有主进程负责打印日志信息,以避免重复的日志输出。if self.main_process:# 更新平均损失值。# 这里使用了一个在线更新平均值的公式,其中 self.mean_loss 是之前的平均损失值, self.step 是当前步骤数, self.loss_items 是当前步骤的损失值。# 这个公式可以有效地计算新的平均损失,而不需要存储所有历史损失值。self.mean_loss = (self.mean_loss * self.step + self.loss_items) / (self.step + 1)# 更新进度条的描述信息,显示当前周期和平均损失值。# %10s 和 %10.4g 是格式化字符串,分别用于格式化文本和浮点数, %10s 表示左对齐的字符串, %10.4g 表示保留4位小数的浮点数。# self.loss_num 是损失值的数量,用于确定需要多少个 %10.4g 格式化占位符。# f'{self.epoch}/{self.max_epoch - 1}' 表示当前周期和总周期数(不包括当前周期)。# *(self.mean_loss) 是一个解包操作,将 self.mean_loss 数组中的每个损失值作为参数传递给格式化字符串。self.pbar.set_description(('%10s' + '%10.4g' * self.loss_num) % (f'{self.epoch}/{self.max_epoch - 1}', \*(self.mean_loss)))def strip_model(self):if self.main_process:LOGGER.info(f'\nTraining completed in {(time.time() - self.start_time) / 3600:.3f} hours.')save_ckpt_dir = osp.join(self.save_dir, 'weights')strip_optimizer(save_ckpt_dir, self.epoch)  # strip optimizers for saved pt model# Empty cache if training finisheddef train_after_loop(self):if self.device != 'cpu':torch.cuda.empty_cache()# 负责在训练过程中更新优化器的状态,包括学习率的调整、执行优化步骤和梯度的清零。def update_optimizer(self):# 计算全局步骤数,这是当前周期数乘以每个周期的步数加上当前步数。curr_step = self.step + self.max_stepnum * self.epoch# 设置梯度积累的步数,这是基于批量大小来确定的。积累步数用于在小批量训练时模拟大批量训练的效果。self.accumulate = max(1, round(64 / self.batch_size))# 如果当前步骤在预热步数内( curr_step <= self.warmup_stepnum )if curr_step <= self.warmup_stepnum:# 使用 NumPy 的 np.interp 函数在预热期间线性插值计算积累步数。这个值从1开始,逐渐增加到基于批量大小计算的目标积累步数( 64 / self.batch_size ),并向上取整到最接近的整数。self.accumulate = max(1, np.interp(curr_step, [0, self.warmup_stepnum], [1, 64 / self.batch_size]).round())# 遍历优化器的每个参数组。for k, param in enumerate(self.optimizer.param_groups):# 如果当前是第三个参数组(索引为2),则使用配置中的预热偏置学习率 warmup_bias_lr ;否则,使用0.0。warmup_bias_lr = self.cfg.solver.warmup_bias_lr if k == 2 else 0.0# 使用 np.interp 函数在预热期间线性插值计算学习率。这个值从 warmup_bias_lr 开始,逐渐增加到初始学习率乘以当前周期的学习率因子 self.lf(self.epoch)  。param['lr'] = np.interp(curr_step, [0, self.warmup_stepnum], [warmup_bias_lr, param['initial_lr'] * self.lf(self.epoch)])# 如果参数组中包含动量参数。if 'momentum' in param:# 使用 np.interp 函数在预热期间线性插值计算动量值。这个值从配置中的预热动量 warmup_momentum 开始,逐渐增加到正常动量 momentum 。param['momentum'] = np.interp(curr_step, [0, self.warmup_stepnum], [self.cfg.solver.warmup_momentum, self.cfg.solver.momentum])# 检查自上次优化器更新以来是否已经累积了足够的步骤。 curr_step 是当前的全局步骤数, self.last_opt_step 是上次优化器更新时的步骤数, self.accumulate 是梯度积累的步数。if curr_step - self.last_opt_step >= self.accumulate:# 如果条件满足,使用自动混合精度(AMP)的梯度缩放器 self.scaler 来执行优化器的 step 方法,这会应用累积的梯度来更新模型的参数。self.scaler.step(self.optimizer)# 更新梯度缩放器的状态,准备下一次梯度缩放。self.scaler.update()# 清空优化器的梯度,为下一次梯度积累做准备。self.optimizer.zero_grad()# 如果启用了EMA。if self.ema:# 更新模型参数的EMA。EMA是一种技术,用于平滑模型参数,通常可以提高模型的泛化能力。self.ema.update(self.model)# 更新最后优化步骤的计数器,记录当前步骤数。self.last_opt_step = curr_step# 它用于创建并返回训练和验证数据加载器( DataLoader )。这个方法不依赖于类的实例,因此被定义为静态方法。@staticmethod# 1.args :包含命令行参数的对象,可能包含图像尺寸、批量大小、工作进程数等。# 2.cfg :配置信息,可能包含数据增强设置等。# 3.data_dict :包含数据集信息的字典,如训练路径、验证路径、类别数量等。def get_data_loader(args, cfg, data_dict):# 从 data_dict 中提取训练和验证数据的路径。train_path, val_path = data_dict['train'], data_dict['val']# check data    检查数据。# 提取类别数量并转换为整数。nc = int(data_dict['nc'])# 提取类别名称列表。class_names = data_dict['names']# 确保类别名称的数量与定义的类别数量相匹配。assert len(class_names) == nc, f'the length of class names does not match the number of classes defined'    # 类名的长度与定义的类数不匹配。# 计算网格尺寸,这是基于模型头部的步长和32的较大值。grid_size = max(int(max(cfg.model.head.strides)), 32)# create train dataloader    创建训练数据加载器。# create_dataloader 函数用于创建数据加载器(Dataloader),它负责在训练或验证过程中提供图像和标签数据。# 1. path : 图像数据的加载路径,通常是训练或验证数据集的根目录。# 2. img_size : 图像的输入尺寸,即图像在送入模型之前被缩放的尺寸。# 3. batch_size : 每个批次中的图像数量。# 4. stride : 模型下采样的步长,用于确定特征图的尺寸。# 5. hyp (默认值为 None) : 超参数字典,包含了训练过程中的各种参数,如学习率、迭代次数等。# 6. augment (默认值为 False) : 是否启用数据增强。# 7. check_images (默认值为 False) : 是否检查图像文件的完整性。# 8. check_labels (默认值为 False) : 是否检查标签文件的完整性和格式。# 9. pad (默认值为 0.0) : 在进行矩形训练时,用于调整图像尺寸的填充比例。# 10. rect (默认值为 False) : 是否使用矩形训练,即图像的宽高比可以不同。# 11. rank (默认值为 -1) : 在分布式训练中,指定当前进程的编号。如果是 -1,则不使用分布式训练。# 12. workers (默认值为 8) : 加载数据时使用的子进程数量。# 13. shuffle (默认值为 False) : 是否在每个epoch开始时打乱数据。# 14. data_dict (默认值为 None) : 包含数据集信息的字典,如类别名称、类别数量等。# 15. task (默认值为 "Train") : 指定数据集用于训练("train")还是验证("val")。# 调用 create_dataloader 函数创建训练数据加载器。# 返回值是一个包含数据加载器和其他可能信息的元组,这里只取第一个元素(即数据加载器)。train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,workers=args.workers, shuffle=True, check_images=args.check_images,check_labels=args.check_labels, data_dict=data_dict, task='train')[0]# create val dataloader    创建验证数据加载器# 初始化验证数据加载器为 None 。val_loader = None# 检查当前进程是否为主进程(在分布式训练中,只有主进程创建验证数据加载器)。if args.rank in [-1, 0]:# 如果当前进程为主进程,调用 create_dataloader 函数创建验证数据加载器。参数与训练数据加载器类似,但批量大小可能不同,且不进行数据增强。val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,workers=args.workers, check_images=args.check_images,check_labels=args.check_labels, data_dict=data_dict, task='val')[0]# 这个方法返回训练和验证数据加载器,但在这段代码中,验证数据加载器只在主进程中创建,因此可能返回 None 。return train_loader, val_loader# 它用于将输入的批次数据( batch_data )进行预处理,并将其移动到指定的设备(如GPU)。@staticmethod# 1.batch_data :包含图像和目标(标签)的批次数据。通常, batch_data[0] 是图像张量, batch_data[1] 是对应的目标张量。# 2.device :数据需要被移动到的设备,例如 torch.device("cuda:0") 或 torch.device("cpu") 。def prepro_data(batch_data, device):# 将图像数据移动到指定的设备上,并转换为浮点类型。这里使用 non_blocking=True 参数来异步执行数据传输,这可以提高数据加载的效率。然后,将图像数据除以255,将其值从[0, 255]范围归一化到[0, 1]范围。images = batch_data[0].to(device, non_blocking=True).float() / 255# 将目标数据(标签)移动到指定的设备上。targets = batch_data[1].to(device)# 返回预处理后的图像和目标数据。return images, targets# 用于根据提供的参数构建和配置YOLOv6模型。# 1.self :类的实例。# 2.args :包含命令行参数的对象。# 3.cfg :配置信息,可能包含模型类型、预训练权重等。# 4.nc :类别数量( num_classes )。# 5.device :模型运行的设备,比如CPU或GPU。def get_model(self, args, cfg, nc, device):# 如果配置中的模型类型包含 'YOLOv6-lite' ,则执行以下操作:if 'YOLOv6-lite' in cfg.model.type:# 断言不启用 fuse_ab 和 distill 模式,因为 YOLOv6-lite 模型不支持这些模式。assert not self.args.fuse_ab, 'ERROR in: YOLOv6-lite models not support fuse_ab mode.'assert not self.args.distill, 'ERROR in: YOLOv6-lite models not support distill mode.'# 调用 build_lite_model 函数来构建轻量级模型。model = build_lite_model(cfg, nc, device)else:# 如果不是 YOLOv6-lite 模型,则调用 build_model 函数来构建常规模型,并传递 fuse_ab 和 distill_ns 参数。model = build_model(cfg, nc, device, fuse_ab=self.args.fuse_ab, distill_ns=self.distill_ns)# 检查 cfg.model.pretrained 是否设置了预训练权重路径。weights = cfg.model.pretrained# 如果设置了预训练权重路径:if weights:  # finetune if pretrained model is set    如果设置了预训练模型则进行微调。# 检查权重文件是否存在,如果不存在则下载权重文件。if not os.path.exists(weights):# def download_ckpt(path): -> 下载预训练模型的检查点。download_ckpt(weights)# 从 {weights} 加载 state_dict 以进行微调......LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')# def load_state_dict(weights, model, map_location=None): -> # 从检查点文件加载权重,仅为那些层的名称和形状匹配的分配权重。 -> return modelmodel = load_state_dict(weights, model, map_location=device)#  使用 LOGGER.info 记录模型信息。LOGGER.info('Model: {}'.format(model))# 返回构建和配置好的模型实例。return model# 它用于创建和配置一个教师模型,这在知识蒸馏(Knowledge Distillation)的场景中是常见的。# 1.self :类的实例。# 2.args :包含命令行参数的对象。# 3.cfg :配置信息,可能包含模型类型、预训练权重等。# 4.nc :类别数量( num_classes )。# 5.device :模型运行的设备,比如CPU或GPU。def get_teacher_model(self, args, cfg, nc, device):# 这行代码检查配置中模型头部的层数是否为3,如果是,则设置 teacher_fuse_ab 为 True ,否则为 False 。teacher_fuse_ab = False if cfg.model.head.num_layers != 3 else True# 调用 build_model 函数来构建教师模型,并传递融合参数 teacher_fuse_ab 。model = build_model(cfg, nc, device, fuse_ab=teacher_fuse_ab)# 从命令行参数中获取教师模型的预训练权重路径。weights = args.teacher_model_path# 如果 weights 非空(即指定了预训练权重路径)。if weights:  # finetune if pretrained model is set    如果设置了预训练模型,则进行微调。# 使用 LOGGER.info 记录正在从指定路径加载状态字典( state_dict )作为教师模型的信息。LOGGER.info(f'Loading state_dict from {weights} for teacher')    # 从 {weights} 为教师加载 state_dict。# 调用 load_state_dict 函数加载预训练权重到教师模型中,并指定映射位置为   device  。model = load_state_dict(weights, model, map_location=device)# 使用 LOGGER.info 记录教师模型的信息。LOGGER.info('Model: {}'.format(model))# Do not update running means and running vars    不更新运行方式和运行变量。# 遍历模型中的所有模块( module )。# 如果模块是 torch.nn.BatchNorm2d 类型(即二维批量归一化层),则将 track_running_stats 设置为 False 。这样做可以防止在推理或蒸馏过程中更新批量归一化的运行均值和方差。for module in model.modules():if isinstance(module, torch.nn.BatchNorm2d):# module.track_running_stats# 在PyTorch中, module.track_running_stats 是 torch.nn.modules.batchnorm._BatchNorm 类(包括 torch.nn.BatchNorm1d , torch.nn.BatchNorm2d , 和 torch.nn.BatchNorm3d )的一个属性。# 这个属性控制着批量归一化(Batch Normalization)层是否更新其运行均值(running mean)和运行方差(running variance)。# 具体来说 :# track_running_stats=True (默认值) :当设置为 True 时,批量归一化层会在训练过程中更新其运行均值和运行方差。这些运行统计数据是基于在该层上传递的所有数据批次计算得出的,并用于推理(即模型评估)时的归一化。这是训练模式下的典型行为。# track_running_stats=False :当设置为 False 时,批量归一化层不会更新其运行均值和运行方差。这意味着在推理时,该层将只使用在训练过程中计算并保存的均值和方差,而不会根据新的输入数据进一步调整这些统计数据。# 这在某些情况下很有用,比如在知识蒸馏或者模型微调时,我们希望保持批量归一化层的统计数据不变,以确保教师模型和学生模型的行为一致。# 在推理或模型部署时,通常将 track_running_stats 设置为 False ,因为此时不需要进一步更新统计数据,而且可以减少计算量。# 此外,在某些特定的模型微调或蒸馏场景中,保持批量归一化层的统计数据不变可以确保模型的行为与训练时一致,这对于模型性能的一致性是有益的。module.track_running_stats = False# 返回构建和配置好的教师模型实例。return model# 它用于从一个预训练模型中加载量化缩放因子(scales)。@staticmethod# 1.cfg :配置信息,可能包含预训练模型的路径和量化相关的设置。# 2.device :模型运行的设备,比如CPU或GPU。def load_scale_from_pretrained_models(cfg, device):# 从配置中提取预训练模型的路径,该路径可能包含量化缩放因子。weights = cfg.model.scalesscales = None# 如果 weights 为空,则使用 LOGGER.error 记录错误信息,指出没有提供初始化 RepOptimizer 所需的缩放因子。if not weights:LOGGER.error("ERROR: No scales provided to init RepOptimizer!")# 如果 weights 不为空,则使用 torch.load 加载预训练模型的状态字典(state dictionary),并指定映射位置为 device 。else:# 加载预训练模型的状态字典。ckpt = torch.load(weights, map_location=device)# 调用 extract_scales 函数从加载的预训练模型中提取量化缩放因子。# def extract_scales(model): -> 它用于从给定的模型中提取特定模块( LinearAddBlock )的缩放因子(scales)。返回提取到的量化缩放因子 scales 。 -> return scalesscales = extract_scales(ckpt)# 返回提取到的量化缩放因子 scales 。return scales# 用于根据提供的参数配置模型以在不同的并行计算模式下运行。@staticmethod# 1.args :包含命令行参数的对象。# 2.model :要进行并行处理的模型。# 3.device :模型运行的设备,比如CPU或GPU。def parallel_model(args, model, device):# If DP mode# 如果设备不是CPU且 args.rank 为-1,则认为是DP模式。 args.rank 通常用于分布式训练中标识每个进程的编号,-1表示非分布式训练。dp_mode = device.type != 'cpu' and args.rank == -1# 如果是DP模式且GPU数量大于1,则记录警告日志,推荐使用DistributedDataParallel(DDP)代替DP,因为DP在多GPU训练时效率较低。if dp_mode and torch.cuda.device_count() > 1:LOGGER.warning('WARNING: DP not recommended, use DDP instead.\n')    # 警告:不推荐使用 DP,请改用 DDP。# 使用PyTorch的 DataParallel 包装器包装模型,使其能够在多个GPU上并行运行。model = torch.nn.DataParallel(model)# If DDP mode# 如果设备不是CPU且 args.rank 不为-1,则认为是DDP模式。ddp_mode = device.type != 'cpu' and args.rank != -1if ddp_mode:# 创建DDP实例,指定模型运行的GPU编号和输出设备。model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)# 返回包装后的模型,无论是DP还是DDP模式。return model# 责根据提供的参数和模型构建并返回一个优化器实例。# 1.args :包含命令行参数的对象,可能包含批量大小、是否使用分布式训练等信息。# 2.cfg :配置信息,包含优化器的配置,如学习率、权重衰减等。# 3.model :需要优化的模型。def get_optimizer(self, args, cfg, model):# 计算基于批量大小的积累步数。这个值用于梯度积累,以模拟较大的批量大小,通常与批量大小成反比。accumulate = max(1, round(64 / args.batch_size))# 根据批量大小和积累步数调整权重衰减系数。这种调整有助于在不同批量大小下保持相似的正则化效果。cfg.solver.weight_decay *= args.batch_size * accumulate / 64# 根据批量大小、世界大小(分布式训练中的进程数)和每个GPU的批量大小调整初始学习率。这种调整有助于在不同训练设置下保持相似的学习动态。cfg.solver.lr0 *= args.batch_size / (self.world_size * args.bs_per_gpu) # rescale lr0 related to batchsize    重新调整与批量大小相关的 lr0。# 调用 build_optimizer 函数,根据配置和模型构建优化器实例。# def build_optimizer(cfg, model): -> 从 cfg 文件构建优化器。 -> return optimizeroptimizer = build_optimizer(cfg, model)# 返回构建好的优化器实例。return optimizer# 它负责根据提供的参数、配置和优化器构建并返回一个学习率调度器实例以及一个学习率函数。# 1.args :包含命令行参数的对象,可能包含训练周期数等信息。# 2.cfg :配置信息,包含学习率调度器的配置。# 3.optimizer :用于训练模型的优化器实例。@staticmethoddef get_lr_scheduler(args, cfg, optimizer):# 从命令行参数中获取总的训练周期数。epochs = args.epochs# 调用 build_lr_scheduler 函数,根据配置、优化器和训练周期数构建学习率调度器实例 lr_scheduler 和学习率函数 lf 。学习率函数 lf 通常用于根据当前周期动态调整学习率。# def build_lr_scheduler(cfg, optimizer, epochs): -> 从 cfg 文件构建学习率调度程序。 -> return scheduler, lflr_scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)return lr_scheduler, lf# 它用于将训练批次中的图像和对应的目标(标签)绘制成一幅马赛克图(mosaic),并保存用于可视化。# 1.images :训练批次中的图像张量。# 2.targets :训练批次中的目标(标签)张量。# 3.max_size :绘制图像的最大尺寸,默认为1920。# 4.max_subplots :绘制图像的最大子图数量,默认为16。def plot_train_batch(self, images, targets, max_size=1920, max_subplots=16):# Plot train_batch with labels    绘制带标签的 train_batch。# 检查 images 是否为 PyTorch 张量。if isinstance(images, torch.Tensor):# 如果是张量,首先将其移动到 CPU( .cpu() ),然后转换为浮点类型( .float() ),最后转换为 NumPy 数组( .numpy() )。这是因为后续的绘图操作通常在 NumPy 数组上进行。images = images.cpu().float().numpy()# 检查 targets 是否为 PyTorch 张量。if isinstance(targets, torch.Tensor):# 如果是张量,将其移动到 CPU 并转换为 NumPy 数组。targets = targets.cpu().numpy()# 检查图像数据的最大值是否小于或等于1,这通常意味着图像数据是归一化到 [0, 1] 范围内的。if np.max(images[0]) <= 1:# 如果是归一化的,将图像数据乘以255进行反归一化,将其值转换回 [0, 255] 范围。这一步是可选的,取决于图像数据是否需要以原始像素值进行显示。images *= 255  # de-normalise (optional)    反归一化(可选)。# 从 NumPy 数组中获取图像的批量大小( bs )、通道数( _ ,通常为3)、高度( h )和宽度( w )。bs, _, h, w = images.shape  # batch size, _, height, width# 将批量大小限制为 max_subplots ,以确保不会尝试绘制超过指定数量的图像。bs = min(bs, max_subplots)  # limit plot images    限制绘图图像。# 计算需要多少行和列来容纳 bs 个图像,使得每行每列尽可能均匀。这里使用平方根来确定子图的行数和列数, np.ceil 用于向上取整,确保有足够的空间容纳所有图像。ns = np.ceil(bs ** 0.5)  # number of subplots (square)    子图数量(平方)。# 从 self.batch_data 中获取图像路径,这通常用于在绘制的图像上标注每个图像的文件名。paths = self.batch_data[2]  # image paths    图像路径。# Build Image    构建图像。# 创建一个填充为255(白色)的NumPy数组,其形状为 (ns * h, ns * w, 3) ,其中 ns 是子图数量的平方根(向上取整), h 和 w 分别是单个图像的高度和宽度。这个数组将用来存放所有图像的马赛克。mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init    初始化。# 遍历 images 数组, i 是索引, im 是当前处理的图像。for i, im in enumerate(images):# 如果已经处理了 max_subplots 个图像,则退出循环。这确保即使最后一批图像的数量少于预期,代码也不会尝试访问不存在的图像。if i == max_subplots:  # if last batch has fewer images than we expect    如果最后一批图像少于我们预期。break# 计算当前图像在马赛克中的位置。 i // ns 计算当前图像所在的行, i % ns 计算当前图像所在的列。 x 和 y 分别是图像左上角在马赛克中的坐标。x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin    块起源。# 将图像数据从 (C, H, W) 格式(通道、高度、宽度)转换为 (H, W, C) 格式,以适配后续的图像操作。im = im.transpose(1, 2, 0)# 将当前图像 im 放置到马赛克 mosaic 中的对应位置。这里 y:y + h 和 x:x + w 定义了马赛克中的一个区域,该区域的大小与单个图像的大小相同。mosaic[y:y + h, x:x + w, :] = im# Resize (optional)    调整大小(可选)。# 计算用于调整马赛克图像大小的缩放比例。 max_size 是马赛克图像的最大期望尺寸, ns 是子图数量的平方根(向上取整), h 和 w 分别是单个图像的高度和宽度。 max(h, w) 确保在计算缩放比例时考虑到图像的较大维度。scale = max_size / ns / max(h, w)# 如果计算出的缩放比例小于1,即马赛克图像的当前尺寸超过了最大尺寸限制,则需要进行缩放。if scale < 1:# 将每个子图的高度按比例缩小,使用 math.ceil 向上取整以确保尺寸为整数。h = math.ceil(scale * h)w = math.ceil(scale * w)# 使用 OpenCV 的 cv2.resize 函数调整马赛克图像的大小。新的大小为 (ns * w, ns * h) ,即每个子图的缩放尺寸乘以子图的数量。这里使用 int(x * ns) 确保目标尺寸为整数。mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))# 遍历批量中的图像数量 bs 。for i in range(bs):# 计算每个图像块在马赛克中的起始坐标。x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin    块起源。# 在马赛克图像上绘制白色边界框,以区分不同的图像块。cv2.rectangle(mosaic, (x, y), (x + w, y + h), (255, 255, 255), thickness=2)  # borders    边界。# 在每个图像块的左上角绘制文件名,文件名截取前40个字符。cv2.putText(mosaic, f"{os.path.basename(paths[i])[:40]}", (x + 5, y + 15),cv2.FONT_HERSHEY_COMPLEX, 0.5, color=(220, 220, 220), thickness=1)  # filename    文件名。# 如果目标数据不为空,则处理目标数据。if len(targets) > 0:# 从目标数据中提取当前图像的目标。ti = targets[targets[:, 0] == i]  # image targets    图像目标。# 将目标数据中的边界框坐标从 xywh 格式转换为 xyxy 格式,并转置矩阵以便于处理。boxes = xywh2xyxy(ti[:, 2:6]).T# 提取目标数据中的类别。classes = ti[:, 1].astype('int')# 判断目标数据中是否包含置信度列,如果没有,则只显示标签。labels = ti.shape[1] == 6  # labels if no conf column    如果没有 conf 列则显示标签。if boxes.shape[1]:# 如果边界框坐标是归一化的,则将其转换为像素坐标。if boxes.max() <= 1.01:  # if normalized with tolerance 0.01    如果以 0.01 的公差进行归一化。# 将边界框的坐标缩放到像素值。boxes[[0, 2]] *= w  # scale to pixels    缩放至像素。boxes[[1, 3]] *= h# 如果图像缩放,则将边界框坐标按比例缩放。elif scale < 1:  # absolute coords need scale if image scales    如果图像缩放,绝对坐标需要缩放。# 按比例缩放边界框坐标。boxes *= scale# 将边界框坐标偏移到马赛克中的对应位置。boxes[[0, 2]] += xboxes[[1, 3]] += y# 遍历转置后的边界框列表 boxes , j 是索引, box 是当前处理的边界框。for j, box in enumerate(boxes.T.tolist()):# 将边界框的坐标从浮点数转换为整数,因为像素坐标必须是整数。box = [int(k) for k in box]# 获取当前边界框对应的类别索引。cls = classes[j]# 根据类别索引从 self.color 中获取对应的颜色,并将其转换为整数元组。color = tuple([int(x) for x in self.color[cls]])# 如果 self.data_dict 中包含类别名称列表,则使用类别索引从该列表中获取类别名称;如果没有,则使用类别索引作为类别名称。cls = self.data_dict['names'][cls] if self.data_dict['names'] else cls# 如果 labels 为 True ,则表示目标数据中没有置信度列,只显示类别标签。if labels:label = f'{cls}'# 在马赛克图像 mosaic 上绘制边界框, box[0] 和 box[1] 是左上角坐标, box[2] 和 box[3] 是右下角坐标, color 是边界框颜色。cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=1)# 在边界框上方绘制类别标签, label 是类别名称, (box[0], box[1] - 5) 是文本的左上角坐标, cv2.FONT_HERSHEY_COMPLEX 是字体类型, 0.5 是字体缩放比例, color 是文本颜色。cv2.putText(mosaic, label, (box[0], box[1] - 5), cv2.FONT_HERSHEY_COMPLEX, 0.5, color, thickness=1)self.vis_train_batch = mosaic.copy()# 它用于绘制验证阶段的预测结果,并将绘制好的图像保存到一个列表中。# vis_outputs :预测结果的输出列表,其中包含每个图像的边界框信息。# vis_paths :对应于 vis_outputs 中预测结果的图像路径列表。# vis_conf :可视化时使用的置信度阈值,默认为0.3。# vis_max_box_num :每张图像可视化的最大边界框数量,默认为5。def plot_val_pred(self, vis_outputs, vis_paths, vis_conf=0.3, vis_max_box_num=5):# 绘制验证预测。# plot validation predictions# 初始化一个空列表,用于存储绘制好的图像。self.vis_imgs_list = []# 遍历预测结果和对应的图像路径。for (vis_output, vis_path) in zip(vis_outputs, vis_paths):# 将预测结果从GPU(如果需要)转移到CPU,并转换为NumPy数组,以便进行图像绘制操作。数组格式为 xyxy ,即每个边界框的坐标为 (x1, y1, x2, y2, score, cls_id) 。vis_output_array = vis_output.cpu().numpy()     # xyxy# 使用OpenCV读取对应的原始图像。ori_img = cv2.imread(vis_path)# 遍历每个边界框。for bbox_idx, vis_bbox in enumerate(vis_output_array):# x_tl, y_tl, x_br, y_br, box_score, cls_id :从 vis_bbox 中提取边界框的坐标、得分和类别ID。x_tl = int(vis_bbox[0])y_tl = int(vis_bbox[1])x_br = int(vis_bbox[2])y_br = int(vis_bbox[3])box_score = vis_bbox[4]cls_id = int(vis_bbox[5])# 绘制前 n 个 预测框。# draw top n bbox# 如果边界框的得分低于置信度阈值 vis_conf 或边界框索引超过最大数量 vis_max_box_num ,则停止绘制。if box_score < vis_conf or bbox_idx > vis_max_box_num:break# 使用OpenCV的 cv2.rectangle 在原始图像上绘制边界框。cv2.rectangle(ori_img, (x_tl, y_tl), (x_br, y_br), tuple([int(x) for x in self.color[cls_id]]), thickness=1)# 使用OpenCV的 cv2.putText 在边界框上方绘制类别名称和得分。cv2.putText(ori_img, f"{self.data_dict['names'][cls_id]}: {box_score:.2f}", (x_tl, y_tl - 10), cv2.FONT_HERSHEY_COMPLEX, 0.5, tuple([int(x) for x in self.color[cls_id]]), thickness=1)# 将绘制好的图像从BGR格式转换为RGB格式,并转换为PyTorch张量,然后添加到 self.vis_imgs_list 列表中。self.vis_imgs_list.append(torch.from_numpy(ori_img[:, :, ::-1].copy()))# PTQ# 它用于执行模型的量化校准过程,并保存校准后的模型。def calibrate(self, cfg):# 它负责将校准后的模型保存到指定的路径。# 1.model :经过校准的模型实例。# 2.cfg :包含配置信息的对象,其中包含模型保存路径和校准方法等信息。def save_calib_model(model, cfg):# 保存校准检查点。# Save calibrated checkpoint# 构建保存校准模型的完整路径。# cfg.ptq.calib_output_path :配置中指定的输出路径。# os.path.splitext(os.path.basename(cfg.model.pretrained))[0] :从预训练模型路径中提取文件名,并去除扩展名。# cfg.ptq.calib_method :使用的校准方法。output_model_path = os.path.join(cfg.ptq.calib_output_path, '{}_calib_{}.pt'.format(os.path.splitext(os.path.basename(cfg.model.pretrained))[0], cfg.ptq.calib_method))# 如果配置中指定跳过敏感层的量化。if cfg.ptq.sensitive_layers_skip is True:# 在输出文件名中添加 _partial 后缀,以标识该模型包含部分量化的层。output_model_path = output_model_path.replace('.pt', '_partial.pt')# 使用日志记录器记录保存模型的路径。LOGGER.info('Saving calibrated model to {}... '.format(output_model_path))# 检查配置中指定的输出路径是否存在。if not os.path.exists(cfg.ptq.calib_output_path):# 如果路径不存在,则创建该路径。os.mkdir(cfg.ptq.calib_output_path)# 使用 PyTorch 的 torch.save 函数保存模型。# deepcopy(de_parallel(model)) :去除模型的并行包装(如 DataParallel 或 DistributedDataParallel ),并创建模型参数的深拷贝。# .half() :将模型参数转换为半精度(FP16),以减少模型文件的大小。torch.save({'model': deepcopy(de_parallel(model)).half()}, output_model_path)# 断言确保 self.args.quant 和 self.args.calib 都为 True  。这个断言用来检查是否同时启用了 量化 和 校准 。assert self.args.quant is True and self.args.calib is True# 只有主进程执行校准流程,以避免在分布式训练中重复执行。if self.main_process:# 从 tools.qat.qat_utils 模块导入 ptq_calibrate 函数,该函数用于执行量化校准。from tools.qat.qat_utils import ptq_calibrate# 调用 ptq_calibrate 函数对模型进行量化校准。这个函数需要模型、训练数据加载器和配置信息作为输入。ptq_calibrate(self.model, self.train_loader, cfg)# 在校准完成后,将训练周期重置为0,以便重新开始训练流程。self.epoch = 0# 调用 eval_model 方法评估校准后的模型性能。self.eval_model()# 调用 save_calib_model 函数保存校准后的模型。这个函数负责将校准后的模型保存到磁盘。save_calib_model(self.model, cfg)# 它用于设置模型的量化,特别是在量化感知训练(Quantization Aware Training, QAT)的上下文中。# QAT# self :类的实例。# model :要进行量化的模型。# cfg :配置信息,可能包含量化相关的设置。# device :模型运行的设备,比如CPU或GPU。def quant_setup(self, model, cfg, device):# 如果 self.args.quant 为 True ,则执行量化设置。if self.args.quant:# qat_init_model_manu(model, cfg, self.args) -> 手动初始化模型量化的函数。# kip_sensitive_layers(model, cfg.qat.sensitive_layers_list) -> 跳过敏感层量化的函数。from tools.qat.qat_utils import qat_init_model_manu, skip_sensitive_layers# 调用 qat_init_model_manu 函数来初始化模型的量化,这个函数可能会对模型的权重和激活进行量化。qat_init_model_manu(model, cfg, self.args)# workaround# 调用 model.neck.upsample_enable_quant 方法来启用或禁用上采样层的量化,根据配置中的位数和校准方法。model.neck.upsample_enable_quant(cfg.ptq.num_bits, cfg.ptq.calib_method)# if self.main_process:#     print(model)# QAT# 如果 self.args.calib 为 False (即不在校准阶段),并且配置中指定了要跳过的敏感层,则调用 skip_sensitive_layers 函数来跳过这些层的量化。if self.args.calib is False:if cfg.qat.sensitive_layers_skip:skip_sensitive_layers(model, cfg.qat.sensitive_layers_list)# QAT flow load calibrated model# 断言配置中提供了校准后的模型路径( cfg.qat.calib_pt ),然后使用 torch.load 加载校准后的模型状态字典,并将其加载到当前模型中。assert cfg.qat.calib_pt is not None, 'Please provide calibrated model'# load_state_dict(state_dict, strict=True)# .load_state_dict() 是 PyTorch 中模型( nn.Module )的一个方法,用于将一个状态字典(state dictionary)加载到模型中。状态字典包含了模型的参数(权重和偏置),可以用来初始化或更新模型的参数。# 参数说明 :# 1. state_dict :要加载的状态字典。状态字典是一个从层名称映射到参数张量的字典对象。# 2. strict (可选):一个布尔值,默认为 True 。如果设置为 True ,则会检查状态字典中的每个键是否与模型中的参数名称匹配。如果设置为 False ,则会忽略不匹配的键,并且只更新状态字典中存在的参数。# 返回值 :# 无返回值,因为这个方法会直接修改模型的状态。model.load_state_dict(torch.load(cfg.qat.calib_pt)['model'].float().state_dict())# 使用 model.to(device) 将模型移动到指定的设备(如CPU或GPU)。model.to(device)

 

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com