基于 P-Tuning v2 进行 ChatGLM2-6B 微调实践
微调类型简介
1. SFT监督微调:适用于在源任务中具有较高性能的模型进行微调,学习率较小。常见任务包括中文实体识别、语言模型训练、UIE模型微调。优点是可以快速适应目标任务,但缺点是可能需要较长的训练时间和大量数据。
2. LoRA微调:通过高阶矩阵秩的分解减少微调参数量,不改变预训练模型参数,新增参数。优点是减少了微调的参数量和成本,同时能达到与全模型微调相近的效果。
3. P-tuning v2微调:引入了prefix-tuning的思想,每一层都加入了prefix,并采用了多任务学习。解决了P-tuning v1中序列标注任务效果不佳和普遍性差的问题。其参数对象是各层的prefix。优点是适用于多任务学习,但在自然语言理解任务上表现可能不佳。
4. Freeze微调:主要用于大语言模型的微调,后几层网络提取语义特征,前几层提取文本表层特征。优点是参数高效,适用于提取特定层次的特征。
综上所述,各种微调方法适用于不同的场景和任务。SFT监督微调适用于快速适应目标任务,LoRA适用于减少参数量和成本,P-tuning v2适用于多任务学习,而Freeze适用于提取特定层次的特征。
1.下载glm2训练脚本
git clone https://github.com/THUDM/ChatGLM2-6B.git
2.然后使用 pip 安装依赖
pip install -r requirements.txt -i https://pypi.douban.com/simple/
运行行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/
3.下载样例数据或者自己构建样例
{"content": "类型#裙_材质#网纱_颜色#粉红色_图案#线条_图案#刺绣_裙腰型#高腰_裙长#连衣裙_裙袖长#短袖_裙领型#圆领", "summary": "这款连衣裙,由上到下都透出女性魅力,经典圆领型,开口度恰好,露出修长的脖颈线条,很是优雅气质,短袖设计,这款对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}
可以根据以上格式,构建自己的训练样本,我们可以用一些行业生产数据,如会话记录对模型进行训练,
官方示例数据下载:
https%3A//cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/%3Fdl%3D1
4.根据自己的环境修改训练脚本中对应的文件地址
PRE_SEQ_LEN=128 #序列的预设长度为128
LR=2e-2 #学习率为0.02
NUM_GPUS=4 #用几颗GPU进行训练
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \
--do_train \
--train_file /export/data/train.json \ #设置训练数据文件的目录
--validation_file /export/data/validation.json \ #设置验证文件的目录
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \ #模型目录
--output_dir /export/models/trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \ #训练后的模型目录
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 128 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
5.开始训练吧
sh train.sh
训练中
快要训练完成
6.训练完成
Training completed. Do not forget to share your model on huggingface.co/models =)
{'train_runtime': 4598.3849, 'train_samples_per_second': 41.754, 'train_steps_per_second': 0.652, 'train_loss': 0.1287700497706731, 'epoch': 2400.0}
100%|██████████| 3000/3000 [1:16:37<00:00, 1.53s/it]
***** train metrics *****
epoch = 2400.0
train_loss = 0.1288
train_runtime = 1:16:38.38
train_samples = 24
train_samples_per_second = 41.754
train_steps_per_second = 0.652
7.部署训练后的模型
在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重
model_path = "/opt/tritonserver/python_backend/models/chatglm2-6b"
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join('/opt/train/trained-chatglm2-6b-pt-128-1e-4/checkpoint-3000', "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
8.过程中遇到的问题
8.1 微调后无法应答
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS main.py \
--do_train \
--train_file train.json \
--validation_file dev.json \
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /opt/tritonserver/python_backend/models/chatglm2-6b \
--output_dir trained-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
使用官方脚本中的学习率设置 LR=2e-2 (0.02)
模型出现无法应答,灾难性遗忘,基本上原有的知识都遗忘了,无法应答普通提问 , 比如"你好.."
于是尝试使用 LR=1e-4 (0.0001) 进行训练
"1e-4" 表示 1 乘以 10 的 -4 次方,即等于 0.0001,"2e-2" 表示 2 乘以 10 的 -2 次方,即等于 0.02。
模型最终可以应答.
镜像问题:
https://github.com/THUDM/ChatGLM-6B/issues/1148
8.2 关于学习率:
我理解是,学习率大小像看书看的粗细,看的太粗就学的快(收敛快)但啥也学不到,
学习率是影响模型训练效果的重要参数。过大的学习率可能导致模型不稳定,过小的学习率则可能导致训练速度变慢。因此,需要反复试验,找到合适的学习率。
学习率(lr)表示每次更新权重参数的尺度(步长),ΔΘ=Θ0−(lr)(loss′)。
学习率与batch_size在权重更新中的关系
学习率(lr)直观可以看出lr越大,权重更新的跨度越大,模型参数调整变化越快。
batch_size对模型的影响,在于模型每次更新时,计算梯度是计算整个Batch的平均梯度,
即权重更新公式中的loss′=1batchsize(lossbatch)′, 整合就是 ΔΘ=Θ0−(lr)1batchsize(lossbatch)′ 。即lr与batch_size共同影响模型更新。
作者:京东科技 杨建
来源:京东云开发者社区 转发请注明来源
基于 P-Tuning v2 进行 ChatGLM2-6B 微调实践的更多相关文章
- 【快报】基于K2 BPM的新一代协同办公门户实践交流会
2014年2月28日,“基于BPM的新一代协同办公门户”用户实践交流活动在深圳金茂JW万豪酒店3楼Meet Room IV举办.本次会议由K2携手微软共同举办,邀请到的参会企业都是K2 的BPM老客户 ...
- 基于Sql Server 2008的分布式数据库的实践(五)
原文 基于Sql Server 2008的分布式数据库的实践(五) 程序设计 ------------------------------------------------------------- ...
- 基于Sql Server 2008的分布式数据库的实践(四)
原文 基于Sql Server 2008的分布式数据库的实践(四) 数据库设计 1.E-R图 2.数据库创建 Win 7 1 create database V3 Win 2003 1 create ...
- 基于Sql Server 2008的分布式数据库的实践(三)
原文 基于Sql Server 2008的分布式数据库的实践(三) 配置PHP 1.打开PHP配置文件,找到extension=php_mssql.dll,将前面的注释符号去掉 2.找到mssql.s ...
- 基于Sql Server 2008的分布式数据库的实践(二)
原文 基于Sql Server 2008的分布式数据库的实践(二) 从Win7连接Win2003的Sql Server 2008 1.新建链接服务器链接到Win2003的Sql Server 2008 ...
- 基于Sql Server 2008的分布式数据库的实践(一)
原文 基于Sql Server 2008的分布式数据库的实践(一) 配置Sql Server 2008(Win7) 1.打开SQL server2012,使用windows身份登录 2.登录后,右键选 ...
- 【公开课】【阿里在线技术峰会】魏鹏:基于Java容器的多应用部署技术实践
对于公开课,可能目前用不上这些,但是往往能在以后想解决方案的时候帮助到我.以下是阿里对公开课的整理 摘要: 在首届阿里巴巴在线峰会上,阿里巴巴中间件技术部专家魏鹏为大家带来了题为<基于Java容 ...
- 滴滴出行基于RocketMQ构建企业级消息队列服务的实践
小结: 1. https://mp.weixin.qq.com/s/v6NM3UgX-qTI7yO1QPCJrw 滴滴出行基于RocketMQ构建企业级消息队列服务的实践 原创: 江海挺 阿里巴巴中间 ...
- Python 基于Python从mysql表读取千万数据实践
基于Python 从mysql表读取千万数据实践 by:授客 QQ:1033553122 场景: 有以下两个表,两者都有一个表字段,名为waybill_no,我们需要从tl_waybill_b ...
- 苏宁基于Spark Streaming的实时日志分析系统实践 Spark Streaming 在数据平台日志解析功能的应用
https://mp.weixin.qq.com/s/KPTM02-ICt72_7ZdRZIHBA 苏宁基于Spark Streaming的实时日志分析系统实践 原创: AI+落地实践 AI前线 20 ...
随机推荐
- 屏蔽CSDN百度广告
最近在查询一些技术问题访问到CSDN时一直弹一些令人作恶的广告,说个特别的广告,脱发广告,特别有针对性程序员同胞们的共性问题,不过还是特别恶心,百度了一下,大家也特别反感,CSDN你真这么缺钱?废话不 ...
- LLaMA模型指令微调 字节跳动多模态视频大模型 Valley 论文详解
Valley: Video Assistant with Large Language model Enhanced abilitY 大家好,我是卷了又没卷,薛定谔的卷的AI算法工程师「陈城南」~ 担 ...
- 【Python】Locust持续优化:InfluxDB与Grafana实现数据持久化与可视化分析
前言 在进行性能测试时,我们需要对测试结果进行监控和分析,以便于及时发现问题并进行优化. Locust在内存中维护了一个时间序列数据结构,用于存储每个事件的统计信息. 这个数据结构允许我们在Chart ...
- CF1794C Scoring Subsequences题解
文中 \(a\) 为题目中给的 \(a\). 如果我们要求 \(a_1, a_2, a_3, \dots, a_m\) 的结果, 那么我们可以把 \(a\) 数组从后往前依次除以 \(i\),\(i\ ...
- .Net8的AOT引导程序BootStrap
前言 .Net8的本地预编机器码AOT,它几乎进行了100%的自举.微软为了摆脱C++的钳制,做了很多努力.也就是代码几乎是用C#重写,包括了虚拟机,GC,内存模型等等.而需要C++做的,也就仅仅是引 ...
- List子集合__小记
List集合的子实现类的特点: ArrayList: 底层数据结构是数组的形式,满足数组结构的特点:查询快,增删慢 从线程安全问题来看:线程不安全的,不同步,执行效率高 Vector: 底层数据结构是 ...
- linux内核编译体验篇(一)
文章目录 一. 准备环境 二. 获取内核源码 三. 交叉编译工具链的配置 1. 博友们常用安装方法链接 2. 公司常用的交叉工具链使用方法 四. 内核解压以及如何打补丁 五. 内核基本配置 1. 编译 ...
- http方式内网搭建CDH6.3.2与部分组件优化
Cloudera_Manager_6.3.2安装配置文档 1. 配置准备 Cloudera Manager (简称CM)用于管理CDH6集群,可进行节点安装.配置.服务配置等,提供Web窗口界面提高了 ...
- FPGA移位加三法
介绍 BCD码 BCD码的英文全称是Binary-Coded Decimal,简称BCD,按字面解释是二进制十进制代码,是一种二进制的数字编码形式. 常见的BCD码有8421BCD码,2421BCD ...
- python excel 07版本转换为03版本
需要安装pywin32模块 pip install pywin32 主程序: import win32com.client as win32 import os.path import glob cl ...