欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 焦点 > 基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat

基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat

2025/1/17 15:57:35 来源:https://blog.csdn.net/weixin_44402973/article/details/140823647  浏览:    关键词:基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat

 本文只开放基于LoRA和AdaLoRA微调代码,具体技术可以自行学习。

Qwen1.5-0.5B-Chat权重路径:https://huggingface.co/Qwen/Qwen1.5-0.5B

数据集路径:https://github.com/DB-lost/self-llm/blob/master/dataset/huanhuan.json

1. 知识点

LoRA, AdaLoRA技术

具体技术可以去看论文

Python关键包版本【我使用python版本是:3.10.14】

torch  2.2.2

transformers  4.39.3

peft                 0.9.0

accelerate       0.29.3

2. 项目目录

data 存放训练数据

models/Qwen1.5-0.5B-Chat 存放 Qwen1.5-0.5B-Chat权重

output: 存放训练过程保存的模型权重

inference.py 推理文件

train_adalora.py AdaLoRA 微调代码

train_lora.py LoRA 微调代码

merge.py LoRA权重和Qwen1.5-0.5B-Chat权重合并脚本

其他文件忽略

3. LoRA微调代码

train_adalora.py 具体代码:

# coding:utf-8
"""LoRA Finetune Qwen1.5-0.5B-Chat"""from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, TrainingArguments, Trainer
from torch.utils.data import Dataset
import torch
from peft import LoraConfig, TaskType, get_peft_model
from typing import Dict
import transformers
import json
from transformers.trainer_pt_utils import LabelSmootherIGNORE_TOKEN_ID = LabelSmoother.ignore_indexmax_len = 512
data_json = json.load(open("./data/huanhuan.json", 'r', encoding='utf-8'))
train_json = []
lazy_preprocess = True
gradient_checkpointing = Truedef print_model_allarguments_name_dtype(model):for n, v in model.named_parameters():if v.requires_grad:print(f"trainable model arguments:{n}--{v.dtype}--{v.shape}")else:print(f"not trainable model arguments:{n}--{v.dtype}--{v.shape}")config = AutoConfig.from_pretrained("./models/Qwen1.5-0.5B-Chat",trust_remote_code=True)# kv cache 在推理的时候才用,训练时候不用
config.use_cache = Falsetokenizer = AutoTokenizer.from_pretrained("./models/Qwen1.5-0.5B-Chat",model_max_length=max_len,padding_side="right",use_fast=False
)model = AutoModelForCausalLM.from_pretrained("./models/Qwen1.5-0.5B-Chat",torch_dtype=torch.bfloat16,device_map="auto",config=config,low_cpu_mem_usage=True
)print("Original Model: ")
print_model_allarguments_name_dtype(model)config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],r=64, # Lora 秩lora_alpha=16, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.05, # Dropout 比例bias='none'
)
model = get_peft_model(model, config)
print_model_allarguments_name_dtype(model)
model.print_trainable_parameters()# 不保存激活值
if gradient_checkpointing:model.enable_input_require_grads()def preprocess(sources,tokenizer: transformers.PreTrainedTokenizer,max_len: int,system_message: str = "You are a helpful assistant."
) -> Dict:roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}im_start = tokenizer('<|im_start|>').input_ids[0]im_end = tokenizer('<|im_end|>').input_ids[0]nl_tokens = tokenizer('\n').input_ids_system = tokenizer('system').input_ids + nl_token

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com