https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct
Continuing to study SmolLM
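Model Summary
SmolLM is a series of small language models available in three sizes: 135M, 360M, and 1.7B parameters.
These models are trained on SmolLM-Corpus, a curated collection of high-quality educational and synthetic data designed for training LLMs. For further details, please refer to our blog post.
To build SmolLM-Instruct, we finetuned the base models on publicly available datasets.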
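Changelog
Release | Description |
---|---|
v0.1 | Initial release of SmolLM-Instruct. We finetune on the permissive subset of the WebInstructSub dataset, combined with StarCoder2-Self-OSS-Instruct. Then we perform DPO (Direct Preference Optimization) for one epoch on HelpSteer for the 135M and 1.7B models, and on argilla/dpo-mix-7k for the 360M model. |
v0.2 | We changed the finetuning mix to datasets more suitable for small models. We train on a new dataset of 2k simple everyday conversations generated with llama3.1-70B (everyday-conversations-llama3.1-2k), Magpie-Pro-300K-Filtered, StarCoder2-Self-OSS-Instruct, and a small subset of OpenHermes-2.5. The v0.2 models are better at staying on topic and at responding appropriately to standard prompts, such as greetings and questions about their role as an AI assistant. On AlpacaEval, SmolLM-360M-Instruct (v0.2) has a 63.3% win rate over SmolLM-360M-Instruct (v0.1). You can find the details here. |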
You can load the v0.1 models by specifying revision="v0.1" in the transformers code:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", revision="v0.1")
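Usage
Local applications
⚡ For local applications, you can find optimized implementations of the model in MLC, GGUF and Transformers.js formats, in addition to fast in-browser demos, in this collection: https://huggingface.co/collections/HuggingFaceTB/local-smollms-66c0f3b2a15b4eed7fb198d0
We noticed that 4-bit quantization degrades the quality of the 135M and 360M models, so we use q016 quantization for MLC and ONNX/Transformers.js checkpoints for the WebGPU demo. We also suggest sampling with temperature 0.2 and top-p 0.9.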
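The recommended settings map to the `temperature` and `top_p` arguments that `model.generate()` receives in the Transformers example. To show what they do, here is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering over a toy vocabulary. It is not the transformers implementation, and the function names `nucleus_candidates` and `sample_next_token` are hypothetical.

```python
import math
import random


def nucleus_candidates(logits, temperature=0.2, top_p=0.9):
    """Return [(token_id, prob)] kept by temperature scaling + top-p filtering."""
    # Temperature: divide the logits before the softmax; values < 1 sharpen it.
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p: keep the smallest set of most likely tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Renormalize the surviving probabilities so they sum to 1.
    kept_mass = sum(probs[i] for i in kept)
    return [(i, probs[i] / kept_mass) for i in kept]


def sample_next_token(logits, temperature=0.2, top_p=0.9, rng=None):
    """Draw one token id from the filtered distribution."""
    rng = rng or random.Random()
    candidates = nucleus_candidates(logits, temperature, top_p)
    ids = [i for i, _ in candidates]
    weights = [p for _, p in candidates]
    return rng.choices(ids, weights=weights, k=1)[0]


# Toy 4-token vocabulary: at temperature 0.2 only the top token survives top-p 0.9,
# while at temperature 1.0 three tokens are needed to reach 90% of the mass.
logits = [2.0, 1.0, 0.0, -1.0]
print(nucleus_candidates(logits, temperature=0.2))
print(nucleus_candidates(logits, temperature=1.0))
```

The low temperature concentrates almost all probability on the best token, so for short factual answers sampling behaves close to greedy decoding. Top-p still guards against rare low-probability tokens when the distribution is flatter.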
Transformers
Install transformers:
pip install transformers
# pip install transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "HuggingFaceTB/SmolLM-135M-Instruct"
device = "cuda"  # "cuda" for GPU usage, "cpu" for CPU usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# for multiple GPUs, install accelerate and use `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
messages = [{"role": "user", "content": "What is the capital of France."}]
input_text=tokenizer.apply_chat_template(messages, tokenize=False)
print(input_text)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=50, temperature=0.2, top_p=0.9, do_sample=True)
print(tokenizer.decode(outputs[0]))
Chat in TRL
You can also use the TRL CLI to chat with the model from the terminal:
pip install trl
trl chat --model_name_or_path HuggingFaceTB/SmolLM-135M-Instruct --device cpu
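Limitations
The generated content may not always be factually accurate, logically consistent, or free from the biases present in the training data. We invite users to use the models as assistive tools rather than as definitive sources of information. We find that they can handle general knowledge questions, creative writing and basic Python programming. However, they are English-only and may have difficulty with arithmetic, editing tasks and complex reasoning. For more details about the models' capabilities, please refer to our blog post.
Training parameters
We train the models with the alignment-handbook on the datasets mentioned in the changelog. For v0.2 we use the following parameters, most of them taken from the Zephyr Gemma recipe:
- 1 epoch
- learning rate 1e-3
- cosine learning-rate schedule
- warmup ratio 0.1
- global batch size of 262k tokens
You can find the training recipes here: https://github.com/huggingface/alignment-handbook/tree/smollm/recipes/smollm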
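The warmup-plus-cosine schedule listed above can be sketched as follows. This is a minimal illustration that assumes a linear warmup from 0 followed by cosine decay to 0, which is the same shape as transformers' `get_cosine_schedule_with_warmup`. The step counts are made-up example values, not the ones used to train SmolLM-Instruct.

```python
import math


def lr_at_step(step, total_steps, peak_lr=1e-3, warmup_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay towards 0."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup phase: ramp linearly from 0 to peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Decay phase: progress goes from 0 to 1 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical run length of 1,000 optimizer steps (the real number depends on
# the dataset size and the 262k-token global batch).
for step in (0, 50, 100, 550, 1000):
    print(step, lr_at_step(step, total_steps=1000))
```

With a warmup ratio of 0.1, the first 10% of steps ramp up to the peak of 1e-3. The rate then falls to half the peak at the midpoint of the decay phase and reaches 0 at the final step.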
Citation
@misc{allal2024SmolLM,
      title={SmolLM - blazingly fast and remarkably powerful},
      author={Loubna Ben Allal and Anton Lozhkov and Elie Bakouch and Leandro von Werra and Thomas Wolf},
      year={2024},
}
Code additions and changes
config.json
{
  "_name_or_path": "HuggingFaceTB/SmolLM-135M",
  "architectures": ["LlamaForCausalLM"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pad_token_id": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.3",
  "use_cache": true,
  "vocab_size": 49152
}
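As a sanity check on the config above, the shapes it describes can be turned into a parameter count. This minimal sketch assumes the standard Llama layout:

- separate q/k/v/o and gate/up/down projections;
- no biases (`attention_bias` and `mlp_bias` are false);
- RMSNorm layers with only a weight vector;
- an `lm_head` tied to the input embedding (`tie_word_embeddings` is true).

The helper name `count_llama_params` is hypothetical.

```python
import json

# Subset of the config.json shown above (only the fields that determine the shapes).
CONFIG = json.loads("""
{
  "hidden_size": 576,
  "intermediate_size": 1536,
  "num_hidden_layers": 30,
  "num_attention_heads": 9,
  "num_key_value_heads": 3,
  "vocab_size": 49152,
  "tie_word_embeddings": true
}
""")


def count_llama_params(cfg):
    """Parameter count of a bias-free Llama-style decoder with RMSNorm."""
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]       # 576 // 9 = 64
    kv_dim = cfg["num_key_value_heads"] * head_dim   # 3 KV heads * 64 = 192
    attn = 2 * h * h + 2 * h * kv_dim                # q_proj, o_proj + k_proj, v_proj
    mlp = 3 * h * cfg["intermediate_size"]           # gate_proj, up_proj, down_proj
    norms = 2 * h                                    # input / post-attention RMSNorm weights
    per_layer = attn + mlp + norms
    embed = cfg["vocab_size"] * h
    final_norm = h
    lm_head = 0 if cfg["tie_word_embeddings"] else cfg["vocab_size"] * h
    return cfg["num_hidden_layers"] * per_layer + embed + final_norm + lm_head


total = count_llama_params(CONFIG)
print(f"{total:,}")  # 134,515,008
```

The result, about 134.5M, is consistent with the "135M" in the model name. The config also implies grouped-query attention: 9 query heads share 3 KV heads, i.e. 3 query heads per KV head. This grouping is what the GQA branch (`has_gqa`, `repeat_gqa_scales`) in the quantization code has to account for.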
Adding the model code
import torch
from llmc.utils.registry_factory import MODEL_REGISTRY
from .base_model import BaseModel
from transformers import AutoConfig, SmolVLMForConditionalGeneration
from loguru import logger
from accelerate import Accelerator, DistributedType
from typing import Optional, Union
from transformers.models.llama.modeling_llama import LlamaRMSNorm
# from .smolvlm_model import SmolVLMAutoModelForCausalLM
from llmc.compression.quantization.module_utils import (
    _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)


@MODEL_REGISTRY
class SmolVLM2(BaseModel):
    def __init__(self, config, device_map=None, use_cache=False):
        self.vision_prefix = "model.vision_model"
        self.text_prefix = "model.text_model"
        # Set up the modality containers *before* BaseModel.__init__: it calls
        # find_blocks(), and resetting self.blocks afterwards would discard the
        # text blocks collected there.
        self._init_modality_specific_params()
        # Compatibility attributes
        self.linear_blocks = []  # for legacy index-based access
        self.block_modality_map = {}  # records which modality each block belongs to
        super().__init__(config, device_map, use_cache)

    def _init_modality_specific_params(self):
        """Initialize multimodal-specific parameters."""
        self.blocks = {"vision": [], "text": []}
        self.vision_embeds = []
        self.text_embeds = []
        self.block_name_prefix = {}
        self.pairs = {}

    def build_model(self):
        self.model_config = AutoConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,  # must be enabled
            model_type="smolvlm",  # explicitly set the model type
            torch_dtype=torch.bfloat16,  # force the config dtype
        )
        # Load with a custom loader
        self.model = SmolVLMForConditionalGeneration.from_pretrained(
            self.model_path,
            config=self.model_config,
            device_map=self.device_map,
            trust_remote_code=True,  # key argument
            torch_dtype=torch.bfloat16,  # load everything in one dtype
            low_cpu_mem_usage=True,
        )
        # smol_VLMForConditionalGeneration=self.model
        # self.model=self.model.model

        # Fix the lm_head dtype
        if self.model.lm_head.weight.dtype != torch.bfloat16:
            self.model.lm_head = self.model.lm_head.to(torch.bfloat16)
        logger.info(f"lm_head dtype: {self.model.lm_head.weight.dtype}")

        # Keep references to the sub-modules
        self.vision_model = self.model.model.vision_model
        self.text_model = self.model.model.text_model
        self.connector = self.model.model.connector

        # Check dtype consistency
        text_emb = self.text_model.embed_tokens
        assert text_emb.weight.dtype == torch.bfloat16, "text embedding dtype mismatch"
        assert self.model.lm_head.weight.dtype == torch.bfloat16, "lm_head dtype mismatch"

        # Unified device initialization
        # self._sync_device()

    def find_blocks(self):
        # The text model's blocks (LlamaDecoderLayer) are the main blocks to process
        self.blocks = self.text_model.layers
        # Vision blocks are stored separately (optional, as needed)
        self.vision_blocks = self.vision_model.encoder.layers

    def find_embed_layers(self):
        # Vision embeddings: patch embedding (Conv2d) and position embedding (Embedding)
        self.vision_patch_embed = self.vision_model.embeddings.patch_embedding
        self.vision_pos_embed = self.vision_model.embeddings.position_embedding
        # Text embedding layer
        self.text_embed_tokens = self.text_model.embed_tokens

    def get_embed_layers(self):
        # Return all embedding layers (vision and text)
        return [self.vision_patch_embed, self.vision_pos_embed, self.text_embed_tokens]

    def get_head_layers(self):
        # Generation head
        return [self.model.lm_head]

    def get_pre_head_layernorm_layers(self):
        # Final norm of the text model
        return [self.text_model.norm]

    def get_layers_except_blocks(self):
        # Non-block layers: vision embeddings, vision post-layernorm,
        # text embeddings, final text norm, and the generation head
        return [
            self.vision_patch_embed,
            self.vision_pos_embed,
            self.vision_model.post_layernorm,
            self.text_embed_tokens,
            self.text_model.norm,
            self.model.lm_head,
        ]

    def skip_layer_name(self):
        # Skip lm_head (same as the original LLaMA logic)
        return ['lm_head']

    def has_bias(self):
        # Vision linears have biases (q_proj/k_proj/v_proj/out_proj all use bias=True);
        # text linears do not
        return True

    def get_layernorms_in_block(self, block):
        # Layer norms of a text block (same as LLaMA)
        return {
            'input_layernorm': block.input_layernorm,
            'post_attention_layernorm': block.post_attention_layernorm,
        }

    def get_subsets_in_block(self, block):
        # Subset structure of a text block (same as LLaMA)
        return [
            {
                'layers': {
                    'self_attn.q_proj': block.self_attn.q_proj,
                    'self_attn.k_proj': block.self_attn.k_proj,
                    'self_attn.v_proj': block.self_attn.v_proj,
                },
                'prev_op': [block.input_layernorm],
                'input': ['self_attn.q_proj'],
                'inspect': block.self_attn,
                'has_kwargs': True,
            },
            {
                'layers': {'self_attn.o_proj': block.self_attn.o_proj},
                'prev_op': [block.self_attn.v_proj],
                'input': ['self_attn.o_proj'],
                'inspect': block.self_attn.o_proj,
                'has_kwargs': False,
            },
            {
                'layers': {
                    'mlp.gate_proj': block.mlp.gate_proj,
                    'mlp.up_proj': block.mlp.up_proj,
                },
                'prev_op': [block.post_attention_layernorm],
                'input': ['mlp.gate_proj'],
                'inspect': block.mlp,
                'has_kwargs': False,
                'is_mlp': True,
            },
            {
                'layers': {'mlp.down_proj': block.mlp.down_proj},
                'prev_op': [block.mlp.up_proj],
                'input': ['mlp.down_proj'],
                'inspect': block.mlp.down_proj,
                'has_kwargs': False,
                'is_mlp': True,
            },
        ]

    # Optional extensions follow (extra methods can be added to handle the vision
    # blocks, but BaseModel does not require them)
    def find_block_name(self):
        # Naming prefix of the text blocks (same as LLaMA)
        self.block_name_prefix = 'text_model.layers'
        self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

    # Other methods kept compatible with the BaseModel interface
    # (add vision-specific logic here if needed)
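Each `prev_op` entry in `get_subsets_in_block` roughly marks where AWQ can fold a per-channel scale. For q/k/v the preceding op is `input_layernorm`. There, `scale_ln_fcs` in the modified `base_blockwise_quantization.py` divides `ln.weight` by the scales and multiplies the input columns of each following linear by them. The pure-Python sketch below uses toy sizes and hypothetical helper names (`rms_norm`, `linear`, `scale_ln_fc`). It checks why this leaves the output unchanged for an RMSNorm without bias: in every term of the matrix-vector product, the factor s_j cancels.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Llama-style RMSNorm: x / rms(x), then an elementwise weight (no bias)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def linear(x, weight):
    """y = W x with W stored as [out_features][in_features], like nn.Linear."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def scale_ln_fc(ln_weight, fc_weight, scales):
    """Divide the norm weight by s and multiply each fc input column by s."""
    new_ln = [w / s for w, s in zip(ln_weight, scales)]
    new_fc = [[w * s for w, s in zip(row, scales)] for row in fc_weight]
    return new_ln, new_fc


x = [0.5, -1.0, 2.0]                         # one token's hidden state (toy size 3)
ln_w = [1.0, 0.8, 1.2]                       # input_layernorm weight
fc_w = [[0.1, 0.2, 0.3], [-0.4, 0.5, -0.6]]  # e.g. a q_proj with 2 outputs
s = [2.0, 0.5, 4.0]                          # per-input-channel scales

before = linear(rms_norm(x, ln_w), fc_w)
new_ln, new_fc = scale_ln_fc(ln_w, fc_w, s)
after = linear(rms_norm(x, new_ln), new_fc)
print(before, after)  # equal up to floating-point rounding
```

The `o_proj` and `down_proj` subsets rely on the same cancellation. Their previous op is a linear layer (`v_proj` or `up_proj`), so the code uses `scale_fc_fc` instead and divides the scale out of that layer's output rows.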
Registering the new SmolVLM2 (add its import to the models package `__init__.py`):
from .bloom import Bloom
from .chatglm import ChatGLM
from .deepseekv2 import DeepseekV2
from .deepseekv3 import DeepseekV3
from .falcon import Falcon
from .gemma2 import Gemma2
from .glm4v import GLM4V
from .internlm2 import InternLM2
from .internomni import InternOmni
from .internvl2 import InternVL2
from .llama import Llama
from .llava import Llava
from .minicpm import MiniCPM
from .minicpmv import MiniCPMV
from .mistral import Mistral
from .mixtral import Mixtral
from .mllama import Mllama
from .opt import Opt
from .phi import Phi
from .phi3 import Phi3
from .qwen import Qwen
from .qwen2 import Qwen2
from .qwen2audio import Qwen2Audio
from .qwen2moe import Qwen2Moe
from .qwen2vl import Qwen2VL
from .smollm import SmolLM
from .smolvlm2 import SmolVLM2
from .stablelm import StableLm
from .starcoder import Starcoder
from .vila import Vila
from .vit import Vit
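The `from .smolvlm2 import SmolVLM2` line matters because a class decorated with `@MODEL_REGISTRY` is registered only when its module is imported. A model module that is never imported therefore cannot be found by the `type: SmolVLM2` string in the YAML config. The sketch below is a hypothetical minimal registry that illustrates the pattern; llmc's actual `MODEL_REGISTRY` implementation may differ in details such as the lookup key.

```python
class Registry(dict):
    """Minimal name -> class registry (a hypothetical stand-in for MODEL_REGISTRY)."""

    def __call__(self, cls):
        # Used as a bare decorator (@REGISTRY): register under the class name.
        self[cls.__name__] = cls
        return cls


MODEL_REGISTRY = Registry()


@MODEL_REGISTRY  # runs at import time, when this module is loaded
class SmolVLM2:
    def __init__(self, path):
        self.path = path


# The config's `model.type` string is looked up at run time.
model_cls = MODEL_REGISTRY["SmolVLM2"]
model = model_cls("/path/to/SmolVLM2-2.2B-Instruct")
print(sorted(MODEL_REGISTRY))
```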
Quantization config file
base:
    seed: &seed 42
model:
    type: SmolVLM2 # options: SmolLM, SmolVLM2
    path: /mnt/share/toky/LLMs/SmolVLM2-2.2B-Instruct/ # options: /mnt/share/toky/LLMs/SmolVLM2-2.2B-Instruct/, /mnt/share/toky/LLMs/SmolLM-135M-Instruct/
    tokenizer_mode: slow
    torch_dtype: auto
calib:
    name: pileval
    download: False
    path: /mnt/share/toky/Datasets/LLMC/pileval/
    n_samples: 128
    bs: -1
    seq_len: 512
    preproc: pileval_awq
    seed: *seed
eval:
    eval_pos: [pretrain, transformed, fake_quant]
    name: wikitext2
    download: False
    path: /mnt/share/toky/Datasets/LLMC/wikitext2/
    seq_len: 2048
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 1
    inference_per_block: False
quant:
    vision:
        method: Awq
        weight:
            bit: 4
            symmetric: True
            granularity: per_group
            group_size: 16
        special:
            trans: True
            # The options for "trans_version" include "v1" and "v2".
            # But their results don't differ significantly.
            trans_version: v2
            weight_clip: True
            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
            clip_sym: True
    language:
        method: Awq
        weight:
            bit: 4
            symmetric: True
            granularity: per_group
            group_size: 128
        special:
            trans: True
            # The options for "trans_version" include "v1" and "v2".
            # But their results don't differ significantly.
            trans_version: v2
            weight_clip: True
            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
            clip_sym: True
save:
    save_trans: False
    save_fake: False
    save_vllm: False
    save_path: /mnt/share/toky/Projects/LLMC_Test/llmc_quantized/SmolVLM2
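The vision and language sections both request 4-bit symmetric, per-group weight quantization and differ only in `group_size` (16 vs 128). As a rough illustration of these fields, here is a pure-Python quantize-dequantize ("fake quant") sketch with one scale per group. It assumes qmax = 2^(bit-1) - 1 and scale = max|w| / qmax. llmc's `IntegerQuantizer` may differ in details, and the sketch models neither AWQ's scaling nor the `weight_clip` search. The function name is hypothetical.

```python
def fake_quant_per_group_sym(weights, bit=4, group_size=128):
    """Quantize then dequantize a flat weight list, one symmetric scale per group."""
    assert len(weights) % group_size == 0, "length must be a multiple of group_size"
    qmax = 2 ** (bit - 1) - 1      # 7 for 4-bit
    qmin = -(2 ** (bit - 1))       # -8 for 4-bit
    out = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(w) for w in group) / qmax
        if scale == 0:             # all-zero group: nothing to quantize
            out.extend(0.0 for _ in group)
            continue
        for w in group:
            q = max(qmin, min(qmax, round(w / scale)))  # integer code
            out.append(q * scale)                        # dequantized value
    return out


# One group with a large outlier (7.0) and one group of small weights.
w = [7.0, -3.2, 0.4, 1.6, 0.875, -0.375, 0.125, 0.0]
print(fake_quant_per_group_sym(w, bit=4, group_size=4))
print(fake_quant_per_group_sym(w, bit=4, group_size=8))
```

With `group_size=4`, the small weights get their own scale and survive exactly. With `group_size=8`, they share the outlier's scale and mostly round to zero. This is why smaller groups track local weight ranges more closely, at the cost of storing more scales.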
Modified base_blockwise_quantization.py
import copy
import functools
import gc
import json
import os
import re
from collections import defaultdict
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from loguru import logger

from llmc.utils.registry_factory import KV_REGISTRY, TOKEN_REDUCTION_REGISTRY

from ..blockwise_optimization import BlockwiseOpt
from .attn_utils import _LLMC_ATTN_MAP_
from .auto_clip import AutoClipper
from .utils import is_fp8_supported_gpu

if is_fp8_supported_gpu():
    from .kernel import weight_cast_to_bf16, weight_cast_to_fp8
    logger.info('import kernel successful.')
else:
    from .quant import weight_cast_to_bf16, weight_cast_to_fp8
    logger.info('import quant successful.')

from .hadamard_utils import apply_exact_had_to_linear, get_hadK
from .module_utils import (_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_,
                           _REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_,
                           _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
                           FakeQuantLinear, LlmcActFn, OriginFloatLinear,
                           RotateLinear)
from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer
from .utils import check_do_quant, check_w_only, get_aquantizer, get_wquantizer


class BaseBlockwiseQuantization(BlockwiseOpt):
    def __init__(self, model, quant_config, input, padding_mask, config):
        super().__init__(model, quant_config, input, padding_mask, config)
        self.set_quant_config()

    def w_qdq(self, module, wquantizer):
        args = {'lowbound_factor': None, 'upbound_factor': None}
        if hasattr(module, 'buf_lowbound_factor'):
            args['lowbound_factor'] = module.buf_lowbound_factor
        if hasattr(module, 'buf_upbound_factor'):
            args['upbound_factor'] = module.buf_upbound_factor

        if module.weight.data.dtype == torch.float8_e4m3fn:
            tmp_weight = weight_cast_to_bf16(module.weight,
                                             module.weight_scale_inv).to(torch.bfloat16)
        else:
            tmp_weight = module.weight

        tmp_weight = wquantizer.fake_quant_weight_dynamic(tmp_weight, args)

        if module.weight.data.dtype == torch.float8_e4m3fn:
            tmp_weight, module.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight)

        return tmp_weight

    def w_q(self, module, wquantizer):
        return wquantizer.real_quant_weight_dynamic(module.weight.data)

    def a_qdq(self, act, module, aquantizer, input_index=0):
        if self.act_static:
            args = {
                'scales': (getattr(module, f'buf_act_scales_{input_index}', None)),
                'zeros': (getattr(module, f'buf_act_zeros_{input_index}', None)),
                'qmax': (getattr(module, f'buf_act_qmax_{input_index}', None)),
                'qmin': (getattr(module, f'buf_act_qmin_{input_index}', None)),
            }
            return aquantizer.fake_quant_act_static(act, args)
        else:
            return aquantizer.fake_quant_act_dynamic(act)

    def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
        params_dict = {}
        if mode in ['fake_quant', 'fake_quant_wo_kv']:
            if not self.mix_bits:
                params_dict['a_qdq'] = (
                    partial(self.a_qdq, aquantizer=self.aquantizer)
                    if not w_only
                    else None
                )
                params_dict['w_qdq'] = partial(self.w_qdq, wquantizer=self.wquantizer)
            else:
                params_dict['mix_bits'] = True
                params_dict['a_qdq'] = self.a_qdq
                params_dict['w_qdq'] = self.w_qdq
                params_dict['mix_bits_map'] = self.mix_bits_map
                params_dict['quantizer_mix_bits'] = self.quantizer_mix_bits
                params_dict['wquantizer_default'] = self.wquantizer
                params_dict['aquantizer_default'] = self.aquantizer
                params_dict['w_only_default'] = w_only

        elif mode in _REALQUANT_LINEAR_MAP_.keys():
            params_dict['w_q'] = partial(self.w_q, wquantizer=self.wquantizer)
            params_dict['quant_config'] = self.quant_config

        elif mode == 'online_rotate':
            had_K, K = get_hadK(
                self.intermediate_size if 'down_proj' in name else self.num_heads
            )
            params_dict = {
                'had_K': had_K,
                'K': K,
                'online_full_had': 'down_proj' in name,
                'online_partial_had': 'o_proj' in name,
                'had_dim': (
                    None if 'down_proj' in name else self.hidden_size // self.num_heads
                ),
                'fp32_had': self.fp32_had,
            }

        elif mode == 'quant_attn':
            params_dict = {
                'matmul_a1_qdq': partial(
                    self.a_qdq, aquantizer=self.aquantizer, input_index=0
                ),
                'matmul_a2_qdq': partial(
                    self.a_qdq, aquantizer=self.aquantizer, input_index=1
                ),
                'softmax_a_qdq': (
                    partial(self.a_qdq, aquantizer=self.aquantizer)
                    if self.quant_softmax
                    else None
                ),
            }

        elif mode == 'quant_act_fn':
            params_dict = {'a_qdq': partial(self.a_qdq, aquantizer=self.aquantizer)}

        return params_dict

    def alloc_bits(self, mix_bits_settings):
        for i in range(len(mix_bits_settings)):
            mix_bits_setting = mix_bits_settings[f'setting_{i}']
            if mix_bits_setting['do_quant']:
                wquantizer_mix_bits = self.quant_module(**mix_bits_setting['weight'])
                if 'act' in mix_bits_setting:
                    w_only_mix_bits = False
                    aquantizer_mix_bits = self.quant_module(**mix_bits_setting['act'])
                else:
                    w_only_mix_bits = True
                self.quantizer_mix_bits.append(
                    {
                        'layer_name': mix_bits_setting['layer_name'],
                        'do_quant': mix_bits_setting['do_quant'],
                        'w_only_mix_bits': w_only_mix_bits,
                        'wquantizer': wquantizer_mix_bits,
                        'aquantizer': (
                            aquantizer_mix_bits if not w_only_mix_bits else None
                        ),
                    }
                )
            else:
                self.quantizer_mix_bits.append(
                    {
                        'layer_name': mix_bits_setting['layer_name'],
                        'do_quant': mix_bits_setting['do_quant'],
                    }
                )

        for i in range(len(self.quantizer_mix_bits)):
            logger.info(f'quantizer_mix_bits {i} : {self.quantizer_mix_bits[i]}')
            layer_name = self.quantizer_mix_bits[i]['layer_name']
            for name in layer_name:
                n_layeridx = name.split('#')
                assert (
                    len(n_layeridx) == 1 or len(n_layeridx) == 2
                ), 'layer_name in mix_bits must be name#1-3-4 or name.'
                if len(n_layeridx) == 2:
                    n = n_layeridx[0]
                    layeridx = n_layeridx[1].split('-')
                    layeridx = [int(idx) for idx in layeridx]
                else:
                    n = n_layeridx[0]
                    layeridx = 'all'
                if layeridx == 'all':
                    for k in range(self.num_blocks):
                        self.mix_bits_map[k][n] = i
                else:
                    for k in layeridx:
                        self.mix_bits_map[k][n] = i

    def set_quant_config(self):
        self.mix_bits = 'mix_bits' in self.quant_config
        self.mix_bits_map = [{} for _ in range(self.num_blocks)]
        self.quantizer_mix_bits = []

        if 'ignored_layers' in self.config:
            self.mixed_precision = True
            self.ignored_block_ids = self.config.ignored_layers.get('block_ids', [])
            self.ignored_layer_names = self.config.ignored_layers.get('layer_names', [])
            self.ignored_speical_names = self.config.ignored_layers.get('speical_names', [])
        else:
            self.mixed_precision = False

        self.quant_out = self.quant_config.get('quant_out', False)
        self.tp = self.quant_config.get('tp', 1)
        self.quant_config['weight']['tp'] = self.tp

        # select quantizer
        # weight
        quant_type = self.quant_config['weight'].get('quant_type', 'int-quant')
        if quant_type == 'int-quant':
            if self.quant_config['weight']['bit'] == 48:
                self.weight_quant_module = Weight48IntegerQuantizer
            else:
                self.weight_quant_module = IntegerQuantizer
        elif quant_type == 'float-quant':
            self.weight_quant_module = FloatQuantizer
        logger.info(f'The used Weight Quant Module is {self.weight_quant_module}')
        self.wquantizer = self.weight_quant_module(**self.quant_config['weight'])

        # act
        if 'act' in self.quant_config:
            if self.quant_config['weight']['granularity'] == 'per_block':
                assert self.quant_config['act']['granularity'] == 'per_group'
                assert self.quant_config['act']['group_size'] \
                    == self.quant_config['weight']['block_size']
            self.w_only = False
            quant_type = self.quant_config['act'].get('quant_type', 'int-quant')
            if quant_type == 'int-quant':
                if self.quant_config['act']['bit'] == 48:
                    self.act_quant_module = Weight48IntegerQuantizer
                else:
                    self.act_quant_module = IntegerQuantizer
            elif quant_type == 'float-quant':
                self.act_quant_module = FloatQuantizer
            self.quant_config['act']['tp'] = self.tp
            self.aquantizer = self.act_quant_module(**self.quant_config['act'])
            self.act_static = self.quant_config['act'].get('static', False)
            if self.act_static:
                assert (
                    self.quant_config['act']['granularity'] == 'per_tensor'
                ), 'Only support per_tensor static quant'
            self.quant_attn = self.quant_config['act'].get('quant_attn', False)
            if self.quant_attn:
                assert self.config['model']['type'] in ['Vit', 'DeepseekV2']
                self.quant_softmax = self.quant_config['act'].get('quant_softmax', False)
            self.quant_act_fn = self.quant_config['act'].get('quant_act_fn', False)
        else:
            self.w_only = True
            self.aquantizer = None
            self.act_static = False
            self.quant_attn = False
            self.quant_softmax = False
            self.quant_act_fn = False

        # set mix-bits quant config
        if self.mix_bits:
            mix_bits_settings = self.quant_config['mix_bits']
            logger.info(f'mix_bits_settings number: {len(mix_bits_settings)}')
            logger.info(
                f'mix_bits_settings:\n'
                f'{json.dumps(mix_bits_settings, ensure_ascii=False, indent=4)}'
            )
            self.alloc_bits(mix_bits_settings)
            logger.info(
                f'self.mix_bits_map:\n'
                f'{json.dumps(self.mix_bits_map, ensure_ascii=False, indent=4)}'
            )

        # set kv cache quant config
        if 'kvcache' in self.quant_config:
            self.quant_config['kvcache']['static'] = self.act_static
            kv_special_cfg = self.quant_config['kvcache'].get('special', {})
            act_static_cfg = {}
            if self.act_static:
                act_static_cfg.update(self.config.calib.n_sample)
                act_static_cfg.update(self.config.calib.bs)
            kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
                kv_quant_type, self.quant_config['kvcache'],
                self.model.model_config.text_config.num_hidden_layers,
                **kv_special_cfg, **act_static_cfg
            )
            self.quant_kvcache = True
            self.model.kvcache_buffer.append(self.kv_module)
        else:
            self.quant_kvcache = False

        # set special quant config
        special_config = self.quant_config.get('special', {})
        self.true_sequential = special_config.get('true_sequential', False)

        # set weight clip config
        self.weight_clip = special_config.get('weight_clip', False)
        if self.weight_clip or special_config.get('search_clip_init', False):
            self.save_clip = special_config.get('save_clip', False)
            if self.save_clip:
                self.clip_path = special_config['clip_path']
            self.clip_version = special_config.get('clip_version', 'v1')
            if self.clip_version == 'v2':
                assert self.wquantizer.calib_algo == 'learnable'
            clip_sym = special_config.get('clip_sym', self.wquantizer.sym)
            self.auto_clipper = AutoClipper(
                w_only=self.w_only,
                mix_bits_map=self.mix_bits_map,
                quantizer_mix_bits=self.quantizer_mix_bits,
                wquantizer=self.wquantizer,
                aquantizer=self.aquantizer,
                clip_version=self.clip_version,
                clip_sym=clip_sym,
                save_clip=self.save_clip,
                padding_mask=self.padding_mask,
            )

        # set transformation config
        self.save_scale = special_config.get('save_scale', False)
        if self.save_scale:
            self.scale_path = special_config['scale_path']
            self.act_scales = {}

        # set online-rotation config
        self.online_rotate = special_config.get('online_rotate', False)
        if self.online_rotate:
            assert (
                self.config['model']['type'] in ['Opt', 'Llama']
            ), 'Please set online_rotate=False'
            self.fp32_had = special_config.get('fp32_had', False)
        self.hidden_size = self.model.model_config.text_config.hidden_size
        self.set_model_config()
        self.modality = self.quant_config.modality
        logger.info(f'self.quant_objects : {self.quant_config.modality}')

        # set token reduction config
        if 'token_reduction' in self.quant_config:
            token_reduction_cfg = self.quant_config['token_reduction']
            TOKEN_REDUCTION_REGISTRY[self.quant_config['token_reduction']['method']](
                token_reduction_cfg, self.model, self.blocks
            )

        self.do_gqa_trans = special_config.get('do_gqa_trans', False)
        logger.info(f'self.do_gqa_trans : {self.do_gqa_trans}')

    def set_model_config(self):
        self.hidden_size = self.model.model_config.text_config.hidden_size
        self.num_heads = self.model.model_config.text_config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if hasattr(self.model.model_config.text_config, 'intermediate_size'):
            self.intermediate_size = self.model.model_config.text_config.intermediate_size
        if hasattr(self.model.model_config.text_config, 'num_key_value_heads'):
            self.num_key_value_heads = self.model.model_config.text_config.num_key_value_heads
            self.num_key_value_groups = self.num_heads // self.num_key_value_heads
            if self.num_key_value_groups > 1:
                self.has_gqa = True
            else:
                self.has_gqa = False
        else:
            self.has_gqa = False

    def replace_rotate_linears(self, block):
        for n, m in block.named_modules():
            if isinstance(m, nn.Linear) and (
                'down_proj' in n or 'o_proj' in n or 'fc2' in n or 'out_proj' in n
            ):
                subset = {'layers': {n: m}}
                self.model.replace_module_subset(
                    RotateLinear,
                    block,
                    subset,
                    None,
                    self.get_replacement_params(
                        mode='online_rotate', w_only=self.w_only, name=n
                    ),
                )

    def replace_act_fn(self, block, extra_modules):
        act_fn_dict = self.model.get_act_fn_in_block(block)
        layers_dict = {'layers': act_fn_dict}
        self.model.replace_module_subset(
            LlmcActFn,
            block,
            layers_dict,
            self.block_idx,
            self.get_replacement_params(
                mode='quant_act_fn', w_only=self.w_only, name=None
            ),
        )
        extra_modules.update(act_fn_dict)

    def replace_attention(self, block, extra_modules):
        attn_layers_dict = self.model.get_attn_in_block(block)
        layers_dict = {'layers': attn_layers_dict}
        attn_module = _LLMC_ATTN_MAP_[self.config['model']['type']]
        self.model.replace_module_subset(
            attn_module,
            block,
            layers_dict,
            self.block_idx,
            self.get_replacement_params(
                mode='quant_attn', w_only=self.w_only, name=None
            ),
        )

        matmul_modules = self.model.get_matmul_in_block(block)
        softmax_modules = (
            self.model.get_softmax_in_block(block) if self.quant_softmax else {}
        )
        extra_modules.update(matmul_modules)
        extra_modules.update(softmax_modules)

    @torch.no_grad()
    def collect_block_qparams(self, block):
        named_linears = self.model.get_block_linears(block)
        for n, m in named_linears.items():
            args = {}
            if hasattr(m, 'buf_lowbound_factor'):
                args['lowbound_factor'] = m.buf_lowbound_factor
            if hasattr(m, 'buf_upbound_factor'):
                args['upbound_factor'] = m.buf_upbound_factor
            if m.weight.data.dtype == torch.float8_e4m3fn:
                tmp_weight_data = weight_cast_to_bf16(m.weight.data,
                                                      m.weight_scale_inv.data).to(torch.bfloat16)
            else:
                tmp_weight_data = m.weight.data
            (
                tensor,
                scales,
                zeros,
                max_int,
                min_int,
            ) = self.wquantizer.get_tensor_qparams(tmp_weight_data, args=args)
            m.register_buffer('buf_scales', scales.detach())
            m.register_buffer('buf_zeros', zeros.detach())
            m.register_buffer('buf_qmax', torch.tensor(max_int).to(self.dev))
            m.register_buffer('buf_qmin', torch.tensor(min_int).to(self.dev))

    def block_forward(self, block, input_data=None):
        output = []

        if input_data is None:
            input_data = self.input['data']

        for i in range(len(input_data)):
            input_data[i] = input_data[i].to(device=next(block.parameters()).device)
            for k in self.input['kwargs'][i]:
                if torch.is_tensor(self.input['kwargs'][i][k]):
                    self.input['kwargs'][i][k] = self.input['kwargs'][i][k].to(
                        device=next(block.parameters()).device
                    )  # noqa
                if isinstance(self.input['kwargs'][i][k], tuple):
                    self.input['kwargs'][i][k] = tuple(
                        tmp.to(device=next(block.parameters()).device)
                        for tmp in self.input['kwargs'][i][k]
                    )  # noqa
            with torch.no_grad():
                out = block(input_data[i], **self.input['kwargs'][i])
                if isinstance(out, tuple):
                    out = out[0]
                output.append(out)
        return output

    def block_opt(self, block):
        if self.quant_kvcache:
            self.register_kv_cache(block)
        block = block.cuda()
        named_linears = self.model.get_block_linears(block)
        extra_modules = self.model.get_extra_modules(block)

        if self.quant_attn:
            self.replace_attention(block, extra_modules)
        if self.quant_act_fn:
            self.replace_act_fn(block, extra_modules)

        input_feat_modules = {
            k: v for d in [named_linears, extra_modules] for k, v in d.items()
        }
        logger.info(f'input_feat_modules: {input_feat_modules}')
        input_feat = defaultdict(list)

        handles = self.register_hooks(input_feat_modules, input_feat)

        self.block_init(block)

        self.run(block, input_feat, handles)

        block = block.cpu()
        del input_feat, block
        gc.collect()
        torch.cuda.empty_cache()

    def register_hooks(self, input_feat_modules, input_feat):
        handles = []
        if not self.data_free:
            for name in input_feat_modules:
                handles.append(
                    input_feat_modules[name].register_forward_hook(
                        functools.partial(
                            self.cache_input_hook, name=name, feat_dict=input_feat
                        )
                    )
                )
        return handles

    def run(self, block, input_feat, handles):
        if not self.data_free:
            if self.quant_out:
                self.block_forward(block)
            else:
                self.input['data'] = self.block_forward(block)
            for h in handles:
                h.remove()
            torch.cuda.empty_cache()

            self.block_transform(block, input_feat, self.input['kwargs'])
        else:
            self.block_transform(block)

        if not self.data_free and self.quant_out:
            self.model.replace_module_block(
                FakeQuantLinear,
                block,
                self.block_idx,
                self.get_replacement_params(
                    mode='fake_quant', w_only=self.w_only, name=None
                ),
            )
            self.set_non_linear_mode('fake_quant', block, False)
            self.input['data'] = self.block_forward(block)
        torch.cuda.empty_cache()

    def block_transform(self, block, input_feat, block_kwargs):
        logger.info(f'Start transform the {self.block_idx}-th block')
        subsets = self.model.get_subsets_in_block(block)

        if self.act_static:
            self.register_non_linear_qparams(block, input_feat)

        self.set_non_linear_mode('fake_quant', block, False)

        for index, subset in enumerate(subsets):
            logger.info(f'subset: {subset}')
            layers_dict = subset['layers']
            input_name = subset['input'][0]
            inspect_has_kwargs = subset['has_kwargs']
            if inspect_has_kwargs:
                if 'sub_keys' in subset:
                    subset_kwargs = [
                        {k: block_kwargs[0][v] for k, v in subset['sub_keys'].items()}
                    ]
                else:
                    subset_kwargs = block_kwargs
            else:
                subset_kwargs = {}
            self.subset_transform(
                subset,
                input_feat,
                subset_kwargs,
            )
            if self.act_static:
                input_tensors = copy.deepcopy(input_feat[input_name])
                self.register_act_qparams(layers_dict, input_tensors)
                del input_tensors

            if self.true_sequential and index != len(subsets) - 1:
                next_subset = subsets[index + 1]
                input_feat_subset = self.rehook_next_subset(block, subset, next_subset)
                input_feat.update(input_feat_subset)

        self.set_non_linear_mode('fake_quant', block, True)
        logger.info(f'End transform the {self.block_idx}-th block')

    def rehook_next_subset(self, block, subset, next_subset):
        self.subset_init(next_subset)
        self.model.replace_module_subset(
            FakeQuantLinear,
            block,
            subset,
            self.block_idx,
            self.get_replacement_params(
                mode='fake_quant', w_only=self.w_only, name=None
            ),
        )

        input_feat_subset = defaultdict(list)
        input_feat_modules = next_subset['layers']
        handles = self.register_hooks(input_feat_modules, input_feat_subset)

        self.block_forward(block)
        for h in handles:
            h.remove()

        return input_feat_subset

    def collect_layers_weights(self, layers, tensor_parallelize_style=None):
        weights = []
        for _m in layers:
            if _m.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = _m.weight_scale_inv
                tmp_weight = weight_cast_to_bf16(_m.weight, fp8_scale).to(torch.bfloat16)
                weights.append(tmp_weight)
            else:
                weights.append(_m.weight)
        return weights

    @torch.no_grad()
    def register_kv_cache(self, block):
        attn_layers_dict = self.model.get_attn_in_block(block)
        attn_layer = attn_layers_dict[list(attn_layers_dict.keys())[0]]
        setattr(attn_layer, 'kvcache', self.kv_module)
        attn_layer.register_forward_pre_hook(
            self.kv_cache_input_hook(attn_layer), with_kwargs=True
        )

    @torch.no_grad()
    def register_non_linear_qparams(self, block, input_feat):
        layer_types = [
            ('quant_attn', self.model.get_matmul_in_block),
            ('quant_softmax', self.model.get_softmax_in_block, 'quant_attn'),
            ('quant_act_fn', self.model.get_act_fn_in_block),
        ]

        for mode, layer_func, *dependency in layer_types:
            if getattr(self, mode, True) and all(
                getattr(self, dep, True) for dep in dependency
            ):
                layers_dict = layer_func(block)
                for name, layer in layers_dict.items():
                    input_tensors = copy.deepcopy(input_feat[name])
                    self.register_act_qparams({name: layer}, input_tensors)
                    del input_tensors

    @torch.no_grad()
    def register_act_qparams(self, layers_dict, act_tensors):
        scales_list, zeros_list, qmin_list, qmax_list = (
            self.aquantizer.get_batch_tensors_qparams(act_tensors)
        )
        world_size = int(os.environ['WORLD_SIZE'])

        for i, (scales, zeros, qmin, qmax) in enumerate(
            zip(scales_list, zeros_list, qmin_list, qmax_list)
        ):
            scales = scales.cuda()
            dist.all_reduce(scales, op=dist.ReduceOp.SUM)
            scales = scales / world_size

            for name, layer in layers_dict.items():
                if not isinstance(
                    layer, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
                ):
                    continue
                layer.register_buffer(f'buf_act_scales_{i}', scales)
                layer.register_buffer(f'buf_act_zeros_{i}', zeros.cuda())
                layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
                layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())

    @torch.no_grad()
    def repeat_gqa_scales(self, scales):
        scales = scales.view(1, self.num_key_value_heads, self.head_dim)
        scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups)
        return scales

    @torch.no_grad()
    def apply_scale(self, scales, prev_op, layers):
        assert (
            len(prev_op) == 1
        ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
        if isinstance(
            prev_op[0], tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
        ):
            assert len(layers) == 1
            logger.info('apply scale between fc and fc')
            self.scale_fc_fc(prev_op[0], layers[0], scales)
        elif isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
            logger.info('apply scale between ln and fc')
            self.scale_ln_fcs(prev_op[0], layers, scales)
        else:
            raise NotImplementedError(f'prev_op {type(prev_op[0])} not supported yet!')

    @torch.no_grad()
    def apply_shift(self, shifts, prev_op, layers):
        if shifts is None:
            return

        assert (
            len(prev_op) == 1
        ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
        if isinstance(
            prev_op[0], tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
        ):
            assert len(layers) == 1
            self.shift_fc_fc(prev_op[0], layers[0], shifts)
        elif isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
            self.shift_ln_fcs(prev_op[0], layers, shifts)
        else:
            raise NotImplementedError(f'prev_op {type(prev_op[0])} not supported yet!')

    @torch.no_grad()
    def scale_fc_fc(self, fc1, fc2, scales):
        scales = scales.to(fc1.weight.device)
        if fc1.out_features == fc2.in_features * 3:
            logger.info('fc1.out_features == fc2.in_features * 3')
            num_heads = self.model.get_num_attention_heads()
            fc1.weight.t_()
            org_shape = fc1.weight.shape
            fc1.weight.data = fc1.weight.data.reshape(org_shape[0] * num_heads, 3, -1)
            value = fc1.weight.data[:, 2, :].reshape(org_shape[0], -1)
            fc1.weight.data[:, 2, :] = value.div(scales.view(-1)).reshape(
                fc1.weight[:, 2, :].shape
            )
            fc1.weight.data = fc1.weight.data.reshape(org_shape).t_()
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data = fc1.bias.data.reshape(num_heads, 3, -1)
                value = fc1.bias.data[:, 2, :].reshape(-1)
                fc1.bias.data[:, 2, :] = value.div(scales.view(-1)).reshape(
                    fc1.bias[:, 2, :].shape
                )
                fc1.bias.data = fc1.bias.data.reshape(-1)
        elif fc1.out_features == fc2.in_features * 2:
            logger.info('fc1.out_features == fc2.in_features * 2')
            fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
        elif fc1.out_features == fc2.in_features:
            logger.info('fc1.out_features == fc2.in_features')
            assert fc1.out_features == fc2.in_features

            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.div_(scales.view(-1))

            if fc1.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = fc1.weight_scale_inv
                tmp_weight_data = weight_cast_to_bf16(fc1.weight.data, fp8_scale).to(torch.bfloat16)
                tmp_weight_data.div_(scales.view(-1, 1))
                fc1.weight.data, fc1.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
            else:
                fc1.weight.div_(scales.view(-1, 1))

        elif self.has_gqa and self.do_gqa_trans:
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.div_(scales.view(-1))
            fc1.weight.div_(scales.view(-1, 1))

            if fc1.out_features != fc2.in_features:
                logger.info('GQA scale this fc-fc.')
                scales = self.repeat_gqa_scales(scales)
        else:
            logger.error(f'fc1.out_features: {fc1.out_features}')
            logger.error(f'fc2.in_features: {fc2.in_features}')
            raise Exception('Can not scale this fc-fc.')

        if fc2.weight.data.dtype == torch.float8_e4m3fn:
            fp8_scale = fc2.weight_scale_inv
            tmp_weight_data = weight_cast_to_bf16(fc2.weight.data, fp8_scale).to(torch.bfloat16)
            tmp_weight_data.mul_(scales.view(1, -1))
            fc2.weight.data, fc2.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
        else:
            fc2.weight.mul_(scales.view(1, -1))

    @torch.no_grad()
    def shift_fc_fc(self, fc1, fc2, shifts):
        if fc1.out_features == fc2.in_features * 3:
            num_heads = self.model.get_model_config().to_dict().get('n_head', None)
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data = fc1.bias.data.reshape(num_heads, 3, -1)
                value = fc1.bias.data[:, 2, :].reshape(-1)
                fc1.bias.data[:, 2, :] = (value - shifts).reshape(
                    fc1.bias[:, 2, :].shape
                )
                fc1.bias.data = fc1.bias.data.reshape(-1)
        else:
            assert fc1.out_features == fc2.in_features

            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.sub_(shifts)

        if hasattr(fc2, 'bias') and fc2.bias is not None:
            fc2.bias.add_(fc2.weight @ shifts)
        else:
            if hasattr(self, 'use_shift') and self.use_shift:
                del fc2.bias
                fc2.register_buffer('bias', fc2.weight @ shifts)

    @torch.no_grad()
    def shift_ln_fcs(self, ln, fcs, shifts):
        if not isinstance(fcs, list):
            fcs = [fcs]

        if self.model.has_bias():
            ln.bias.sub_(shifts)

        for fc in fcs:
            if self.model.has_bias():
                fc.bias.add_(fc.weight @ shifts)
            else:
                if hasattr(self, 'use_shift') and self.use_shift:
                    del fc.bias
                    fc.register_buffer('bias', fc.weight @ shifts)

        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    @torch.no_grad()
    def scale_ln_fcs(self, ln, fcs, scales):
        if not isinstance(fcs, list):
            fcs = [fcs]
        scales = scales.to(ln.weight.device)
        ln.weight.div_(scales)

        if hasattr(ln, 'bias') and ln.bias is not None:
            ln.bias.div_(scales)

        for fc in fcs:
            if fc.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = fc.weight_scale_inv.data
                tmp_weight_data = weight_cast_to_bf16(fc.weight.data, fp8_scale).to(torch.bfloat16)
                tmp_weight_data.mul_(scales.view(1, -1))
                fc.weight.data, fc.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
            else:
                fc.weight.mul_(scales.view(1, -1))

        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    def rotate_pre_layers(self, pre_layers, Q):
        for layer in pre_layers:
            if layer.weight.data.dtype == torch.float8_e4m3fn:
                layer.weight.data = weight_cast_to_bf16(
                    layer.weight.data, layer.weight_scale_inv.data
                ).to(torch.bfloat16)
            dtype = layer.weight.dtype
            layer.weight.data = torch.matmul(layer.weight.data.double(), Q).to(dtype)

            if hasattr(layer, 'weight_scale_inv'):
                layer.weight.data, layer.weight_scale_inv.data \
                    = weight_cast_to_fp8(layer.weight.data)
            torch.cuda.empty_cache()

    def rotate_post_layers(self, post_layers, Q, exact_had=False):
        for layer in post_layers:
            if layer.weight.data.dtype == torch.float8_e4m3fn:
                layer.weight.data = weight_cast_to_bf16(
                    layer.weight.data, layer.weight_scale_inv.data
                ).to(torch.bfloat16)
            dtype = layer.weight.dtype
            layer.weight.data = torch.matmul(Q.T, layer.weight.data.double()).to(dtype)

            if exact_had and self.online_rotate:
                apply_exact_had_to_linear(layer, had_dim=-1, output=False)

            if hasattr(layer, 'bias') and layer.bias is not None:
                b = layer.bias.data.to(torch.float64)
                layer.bias.data = torch.matmul(Q.T, b).to(dtype)

            if hasattr(layer, 'weight_scale_inv'):
                layer.weight.data, layer.weight_scale_inv.data \
                    = weight_cast_to_fp8(layer.weight.data)
            torch.cuda.empty_cache()

    def rotate_embeddings(self, Q):
        embeddings = self.model.get_embed_layers()
        assert
len(embeddings) == 1for layer in embeddings:dtype = layer.weight.data.dtypeW = layer.weight.data.to(device=self.dev, dtype=torch.float64)layer.weight.data = torch.matmul(W, Q).to(device='cpu', dtype=dtype)def rotate_head(self, Q):heads = self.model.get_head_layers()for layer in heads:dtype = layer.weight.data.dtypeW = layer.weight.data.to(device=self.dev, dtype=torch.float64)layer.weight.data = torch.matmul(W, Q).to(device='cpu', dtype=dtype)def fuse_ln_fcs(self, ln, fcs):for fc in fcs:if fc.weight.data.dtype == torch.float8_e4m3fn:fc.weight.data \= weight_cast_to_bf16(fc.weight.data,fc.weight_scale_inv.data).to(torch.bfloat16)fc_dtype = fc.weight.dtypeif hasattr(ln, 'bias') and ln.bias is not None:W = fc.weight.data.double().clone()fc.weight.data = (fc.weight.data.double() * ln.weight.double()).to(fc_dtype)if hasattr(ln, 'bias') and ln.bias is not None:if fc.bias is None:fc.bias = torch.nn.Parameter(torch.zeros(fc.out_features, dtype=torch.float64))fc.bias.data = fc.bias.data.double().to(device=W.device) + torch.matmul(W, ln.bias.double())fc.bias.data = fc.bias.data.to(fc_dtype)if hasattr(fc, 'weight_scale_inv'):fc.weight.data, fc.weight_scale_inv.data = weight_cast_to_fp8(fc.weight.data)torch.cuda.empty_cache()def remove_mean_from_embed(self):embeddings = self.model.get_embed_layers()for layer in embeddings:W = layer.weight.data.double()layer.weight.data = (W - W.mean(dim=-1, keepdim=True)).to(layer.weight.data.dtype)def bake_mean_into_fc(self, fc):fc_dtype = fc.weight.dtypeW_ = fc.weight.data.double()fc.weight.data = W_ - W_.mean(dim=-2, keepdim=True)fc.weight.data = fc.weight.data.to(fc_dtype)if hasattr(fc, 'bias') and fc.bias is not None:b_ = fc.bias.data.double()fc.bias.data = b_ - b_.mean()fc.bias.data = fc.bias.data.to(fc_dtype)@torch.no_grad()def scaling_input(self, x, scales, is_gqa):if is_gqa:scales_tmp = self.repeat_gqa_scales(scales)else:scales_tmp = scalesif hasattr(self, '_bs') and self._bs < x.shape[0]:x_tmp = torch.empty_like(x)for i, batch in 
enumerate(x):batch_scale = scales_tmp.view(1, -1)x_tmp[i] = batch / batch_scaleelse:x_tmp = x / scales_tmp.view(1, -1)return x_tmp@torch.no_grad()def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):for layer_name in layers_dict:for i in range(len(input_feat[layer_name])):inp = input_feat[layer_name][i]scale = scale.to(inp.device)input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)@torch.no_grad()def set_non_linear_mode(self, quant_format, module, mode):assert mode in [True, False]if quant_format != 'fake_quant':returnfor name, m in module.named_modules():if 'kvcache' in name:continueif getattr(m, 'calib', None) is not None:m.calib = modedef set_no_quant_layer(self):if self.ignored_speical_names:assert hasattr(self.model, 'block_name_prefix'), \'block_name_prefix missing in model'ignored_block_ids = []for item in self.ignored_block_ids:match = re.match(r'(\d+)-(\d+)', str(item))if match:start, end = int(match.group(1)), int(match.group(2))ignored_block_ids.extend(range(start, end + 1))else:ignored_block_ids.append(int(item))for idx, block in enumerate(self.blocks):for n, m in block.named_modules():if idx in ignored_block_ids and n in self.ignored_layer_names:m.register_buffer('no_quant', torch.tensor(True))else:layer_name = f'{self.model.block_name_prefix}.{idx}.{n}'if layer_name in self.ignored_speical_names:m.register_buffer('no_quant', torch.tensor(True))@torch.no_grad()def deploy(self, quant_format, keep_device=False):logger.info(f'-- deploy_{quant_format}_model start --')logger.info(f'quant_config : {self.quant_config}')module_mapping = {'origin_float': OriginFloatLinear,'fake_quant': EffcientFakeQuantLinear,'fake_quant_wo_kv': EffcientFakeQuantLinear,}module_mapping.update(_REALQUANT_LINEAR_MAP_)if quant_format not in module_mapping:raise NotImplementedError(f"Quant format '{quant_format}' is not implemented.")if self.mixed_precision and 'quant' in quant_format:self.set_no_quant_layer()module = module_mapping[quant_format]if 
self.modality == 'vision':self.model.replace_vision_module_all(module,self.get_replacement_params(mode=quant_format, w_only=self.w_only),keep_device=keep_device,)if self.modality == 'language':self.model.replace_language_module_all(module,self.get_replacement_params(mode=quant_format, w_only=self.w_only),keep_device=keep_device,)self.set_non_linear_mode(quant_format, self.model.model, False)if self.quant_kvcache:if quant_format == 'origin_float':self.kv_module.use_org_kv = Trueelif quant_format == 'fake_quant_wo_kv':self.kv_module.use_org_kv = Trueelif quant_format == 'fake_quant':self.kv_module.use_org_kv = Falseif self.act_static:self.kv_module.calib = Falseif self.model.mm_model is not None:logger.info(f'Now, the mm_model is: {self.model.mm_model}')logger.info(f'-- deploy_{quant_format}_model done --')@torch.no_grad()def copy_tokenizer(self, path):self.model.tokenizer.save_pretrained(path)logger.info('copy tokenizer done --')@torch.no_grad()def contiguous_params(self):if self.model.mm_model is not None:for name, param in self.model.mm_model.named_parameters():if not param.is_contiguous():param.data = param.data.contiguous()for name, param in self.model.mm_model.named_buffers():if not param.is_contiguous():param.data = param.data.contiguous()else:for name, param in self.model.model.named_parameters():if not param.is_contiguous():param.data = param.data.contiguous()for name, param in self.model.model.named_buffers():if not param.is_contiguous():param.data = param.data.contiguous()@torch.no_grad()def save_model(self, path):if int(os.environ['RANK']) != 0:returnself.contiguous_params()if self.config.model.type in ['Llava', 'InternVL2', 'Mllama', 'Qwen2vl']:self.model.vlm_model.language_model = self.model.get_model()self.model.vlm_model.save_pretrained(path)logger.info('save model done --')self.copy_tokenizer(path)elif self.config.model.type in ['Qwen2Audio']:self.model.alm_model.language_model = 
self.model.get_model()self.model.alm_model.save_pretrained(path)logger.info('save model done --')self.copy_tokenizer(path)elif self.config.model.type in ['InternOmni']:self.model.avlm_model.language_model = self.model.get_model()self.model.avlm_model.save_pretrained(path)logger.info('save model done --')self.copy_tokenizer(path)else:self.model.get_model().save_pretrained(path)logger.info('save model done --')self.copy_tokenizer(path)
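The scale and shift helpers `scale_ln_fcs` and `shift_ln_fcs` rest on two identities. Dividing a norm's affine weight and bias by a per-channel factor s while multiplying the matching input columns of the next linear by s leaves that linear's output unchanged. Subtracting a shift d from the norm bias while adding W @ d to the linear's bias does the same. The sketch below checks both numerically in plain Python so it runs without torch; the helper names (`ln_affine`, `linear`, `migrate_scale`, `migrate_shift`) and all sample values are illustrative assumptions, not llmc functions.

```python
# Illustrative sketch (not llmc code): the scale/shift identities behind
# scale_ln_fcs and shift_ln_fcs, checked with plain Python lists.
# Weights use torch.nn.Linear's (out_features, in_features) layout.


def ln_affine(x_hat, gamma, beta):
    """Affine part of a norm layer applied to an already-normalized vector."""
    return [g * x + b for x, g, b in zip(x_hat, gamma, beta)]


def linear(x, weight, bias):
    """y = W x + b, one weight row per output feature."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def migrate_scale(gamma, beta, weight, scales):
    """Divide the norm's gamma/beta by s; multiply the linear's input columns by s."""
    new_gamma = [g / s for g, s in zip(gamma, scales)]
    new_beta = [b / s for b, s in zip(beta, scales)]
    new_weight = [[w * s for w, s in zip(row, scales)] for row in weight]
    return new_gamma, new_beta, new_weight


def migrate_shift(beta, weight, bias, shifts):
    """Subtract d from the norm's beta; add W @ d to the linear's bias."""
    new_beta = [b - d for b, d in zip(beta, shifts)]
    new_bias = [b + sum(w * d for w, d in zip(row, shifts)) for row, b in zip(weight, bias)]
    return new_beta, new_bias


x_hat = [0.5, -1.0, 2.0]
gamma, beta = [1.0, 2.0, 0.5], [0.1, -0.2, 0.3]
weight = [[1.0, 0.0, -1.0], [0.5, 2.0, 1.5]]
bias = [0.0, 1.0]

reference = linear(ln_affine(x_hat, gamma, beta), weight, bias)

g_s, b_s, w_s = migrate_scale(gamma, beta, weight, scales=[2.0, 0.5, 4.0])
scaled = linear(ln_affine(x_hat, g_s, b_s), w_s, bias)

b_d, bias_d = migrate_shift(beta, weight, bias, shifts=[0.1, -0.2, 0.3])
shifted = linear(ln_affine(x_hat, gamma, b_d), weight, bias_d)

# All three outputs agree up to floating-point rounding.
print(reference, scaled, shifted)
```

Note that the linear's own bias is left untouched by the scale migration, matching `scale_ln_fcs`, because it is added after the scaled columns have already cancelled the division.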
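The rotation helpers (`rotate_embeddings`, `rotate_pre_layers`, `rotate_post_layers`, `rotate_head`) rely on Q being orthogonal. A layer that writes to the hidden stream is rotated as `W <- Q.T @ W` (bias `b <- Q.T @ b`), a layer that reads from it as `W <- W @ Q`, and because `Q @ Q.T` is the identity the reader's output is unchanged. The plain-Python sketch below checks this with a 2x2 rotation matrix standing in for Q; the matrices, values, and helper names are illustrative assumptions. In the real pipeline, a norm between the two layers presumably has its weight folded away first (as `fuse_ln_fcs` does), since a per-channel weight would not commute with the rotation.

```python
# Illustrative sketch (not llmc code): an orthogonal Q applied to the output side
# of a "writer" linear and the input side of a "reader" linear cancels out.
import math


def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


theta = 0.3
Q = [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta), math.cos(theta)]]  # 2x2 rotation, so Q @ Q.T == I

W_post = [[0.5, -1.5, 2.0], [1.0, 0.0, -0.5]]   # writer: (hidden=2, in=3)
b_post = [[0.2], [-0.4]]                        # writer bias as a column
W_pre = [[1.0, 2.0], [-0.5, 0.25], [3.0, -1.0]]  # reader: (out=3, hidden=2)
x = [[1.0], [2.0], [-1.0]]                      # writer input as a column

# Original model: h = W_post x + b_post, then y = W_pre h.
h = add(matmul(W_post, x), b_post)
y = matmul(W_pre, h)

# Rotated model, mirroring rotate_post_layers (Q.T @ W, Q.T @ b)
# and rotate_pre_layers (W @ Q).
Qt = transpose(Q)
h_rot = add(matmul(matmul(Qt, W_post), x), matmul(Qt, b_post))  # equals Q.T @ h
y_rot = matmul(matmul(W_pre, Q), h_rot)

print(h, h_rot)  # hidden state is rotated, so it differs
print(y, y_rot)  # reader output matches up to rounding
```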