
Fine-tuning Tongyi Qianwen (Qwen) with the SWIFT Framework

2025/3/28 7:51:17 Source: https://blog.csdn.net/W_extend/article/details/146375749

1. Creating the Environment

Server CUDA Version: 12.2

conda create -n lora_qwen python=3.10 -y 
conda activate lora_qwen 
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y

1.1 Environment Setup

This article uses SWIFT for fine-tuning, so first clone ms-swift and install the necessary packages:

git clone https://github.com/modelscope/ms-swift.git
cd ms-swift && pip install -e .   # install swift itself from the cloned repo
pip install transformers==4.49.0
pip install pyav qwen_vl_utils
pip install numpy==1.22.4
pip install modelscope

1.2 Model Download

Use modelscope to download the specified model, where:

--model is the model name, which can be found on the ModelScope website

--local_dir is the directory the model is downloaded to

Run the command below; the model will be downloaded to the ./Qwen/Qwen2.5-VL-7B-Instruct directory

modelscope download --model Qwen/Qwen2.5-VL-7B-Instruct --local_dir ./Qwen/Qwen2.5-VL-7B-Instruct

The following command starts a chat session with the model, which is a quick way to verify the model is usable:

CUDA_VISIBLE_DEVICES=1 swift infer --model_type qwen2_5_vl --ckpt_dir ./Qwen/Qwen2.5-VL-7B-Instruct

1.3 Dataset Preparation

The dataset format is shown below, saved as a .jsonl file (JSON Lines: one JSON object per line, not a JSON array):

[{"query": "OCR一下<image>","response": "朵拉童衣","images": ["datasets/lora_qwen/train/billboard_00001_010_朵拉童衣.jpg"]},{"query": "OCR一下<image>","response": "童衣雜貨舖","images": ["datasets/lora_qwen/train/billboard_00002_010_童衣雜貨舖.jpg"]},...
]
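A file in this format can be generated with a few lines of Python. The sketch below uses the two placeholder records from the example above and then parses the file back to confirm it is valid JSON Lines:

```python
import json

# Example records in the query/response/images schema used above.
records = [
    {"query": "OCR一下<image>",
     "response": "朵拉童衣",
     "images": ["datasets/lora_qwen/train/billboard_00001_010_朵拉童衣.jpg"]},
    {"query": "OCR一下<image>",
     "response": "童衣雜貨舖",
     "images": ["datasets/lora_qwen/train/billboard_00002_010_童衣雜貨舖.jpg"]},
]

# JSON Lines: write one JSON object per line, never a single JSON array.
with open("train.jsonl", "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

# Read the file back line by line to confirm it parses.
with open("train.jsonl", encoding="utf-8") as f:
    parsed = [json.loads(line) for line in f]
print(len(parsed))  # 2
```

`ensure_ascii=False` keeps the Chinese labels readable in the file instead of escaping them to \uXXXX sequences.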

2. Fine-tuning

2.1 Fine-tuning with LoRA

Modify ms-swift-main/examples/train/multimodal/ocr.sh from the repository cloned earlier:

# ~20GB per GPU
CUDA_VISIBLE_DEVICES=0,1 \
MAX_PIXELS=1003520 \
swift sft \
    --model ./Qwen/Qwen2.5-VL-7B-Instruct \
    --model_type qwen2_5_vl \
    --dataset ./datasets/train.jsonl \
    --val_dataset ./datasets/val.jsonl \
    --train_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 100 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 64 \
    --lora_alpha 64 \
    --target_modules all-linear \
    --freeze_vit true \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 10 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4
  • Common parameter explanations:

--model: path to the original model weights

--dataset: path to the training data

--val_dataset: path to the validation data

--train_type: full-parameter training (full) or LoRA fine-tuning (lora)

--num_train_epochs: total number of training epochs

--per_device_train_batch_size: training batch size per device; set according to available GPU memory

--per_device_eval_batch_size: evaluation batch size per device; set according to available GPU memory

--learning_rate: learning rate, usually 1e-4 or 1e-5

--target_modules: modules to apply LoRA to; all-linear means all linear layers, i.e. the Attention and FeedForward layers

--freeze_vit: usually true, freezing the vision encoder and fine-tuning only the LLM part
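With the settings above, the effective global batch size is the per-device batch size times the number of GPUs times the gradient accumulation steps. The quick check below also shows where the MAX_PIXELS value comes from (assuming, as the command implies, that both visible GPUs are used for data parallelism):

```python
# Values taken from the swift sft command above.
per_device_train_batch_size = 1
num_gpus = 2                      # CUDA_VISIBLE_DEVICES=0,1
gradient_accumulation_steps = 16

# Gradients from 16 micro-batches on each of 2 GPUs are accumulated
# before every optimizer step, giving an effective batch of 32 samples.
effective_batch_size = (per_device_train_batch_size
                        * num_gpus
                        * gradient_accumulation_steps)
print(effective_batch_size)  # 32

# MAX_PIXELS=1003520 corresponds to capping each image at
# 1280 visual patches of 28x28 pixels.
max_pixels = 1280 * 28 * 28
print(max_pixels)  # 1003520
```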

2.2 Inference with Transformers

import os
import re
import torch
from PIL import Image
from datasets import Dataset
from modelscope import AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Trainer, TrainingArguments,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq,
)
from qwen_vl_utils import process_vision_info

# Keep a handle to the built-in print, then shadow it so every message
# is echoed to the console and also appended to a log file.
rewrite_print = print

def print(save_txt, *arg, **kwargs):
    rewrite_print(*arg, **kwargs)
    rewrite_print(*arg, **kwargs, file=open(save_txt, "a+", encoding="utf-8"))

def process_func(model, img_path, input_content):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": input_content},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)  # move tensors onto the model's device
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    print(save_txt_path, img_path)
    print(save_txt_path, output_text[0])
    print(save_txt_path, '\n')

def get_lora_model(model_path, lora_model_path):
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch.bfloat16
    )
    model.enable_input_require_grads()
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=r"model\..*layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
        inference_mode=True,
        r=64,
        lora_alpha=64,
        lora_dropout=0.05,
        bias="none",
    )
    peft_model = PeftModel.from_pretrained(
        model, model_id=lora_model_path, config=config
    )
    return peft_model

if __name__ == '__main__':
    save_txt_path = 'log.txt'
    model_path = "./Qwen2.5-VL-7B-Instruct"
    lora_model_path = "./output/v2-20250228-202446/checkpoint-900"
    lora_model = get_lora_model(model_path, lora_model_path)
    processor = AutoProcessor.from_pretrained(model_path)
    img_path = "图片路径"  # path to a test image
    prompt = "OCR一下"
    process_func(lora_model, img_path, prompt)

3. Resource Usage

GPU memory for fine-tuning: around 30GB (depends mainly on the dataset: the larger the images and the longer the prompts and answers, the more memory is used);

GPU memory for inference after fine-tuning: around 20GB.
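The ~30GB figure is roughly consistent with a back-of-envelope estimate. The sketch below is illustrative, not a measurement: the 2% adapter fraction is an assumed figure, and activation memory varies with image size and sequence length.

```python
# Rough memory estimate for LoRA fine-tuning of a 7B model in bfloat16.
base_params = 7e9
bytes_per_param_bf16 = 2
base_weights_gb = base_params * bytes_per_param_bf16 / 1024**3  # frozen base weights

# LoRA adapters (rank 64 over all linear layers) add only a small fraction
# of trainable parameters; assume ~2% of the base as an illustrative figure.
lora_params = 0.02 * base_params
# Per trainable parameter: bf16 weight (2B) + fp32 master copy (4B)
# + two AdamW moment buffers (4B + 4B).
lora_train_gb = lora_params * (2 + 4 + 4 + 4) / 1024**3

print(round(base_weights_gb, 1), round(lora_train_gb, 1))  # 13.0 1.8
```

Under these assumptions the frozen weights plus adapter training state account for roughly 15GB; the remaining headroom up to ~30GB goes to activations and the vision tower's image features, which grow with MAX_PIXELS and max_length.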

