概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

可了解其它loss计算的文章:

再聊多轮对话微调训练格式与长序列训练

聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析

聊聊大模型多轮对话的训练及优化

微调

微调格式:

[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
}
]
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
]

微调源码地址:finetune.py

Loss计算代码:

def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_labels = []
# batched_conv 是一个数组
# conv 是数组内的单个 message
for conv in batched_conv:
input_ids = [151331, 151333]
loss_masks = [False, False]
# conv 是数组内的单个 message
# message 是 单个role json对象
for message in conv:
message = process_message(message)
# 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
# 获取 input 文本的数字表示(ids)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
# 计算整句的 mask
new_loss_masks = [loss_mask_val] * len(new_input_ids)
# 拼接message中的每段json
input_ids += new_input_ids
# 拼接message中每段json对应的mask
loss_masks += new_loss_masks
# 追加结尾的 token id
input_ids.append(tokenizer.eos_token_id)
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
# 添加到label,计算loss
labels.append(input_id)
else:
# -100 不处理,即ignore_index
labels.append(-100)
max_length = max_input_length + max_output_length + 1
# 截断
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md

更多大模型相关的文章,请上个人公众号查阅:

聊聊GLM-4-9B开源模型的微调loss计算的更多相关文章

  1. GoogLeNet模型的微调

    我从零开始训练了GoogLeNet模型. 但它没有给我带来希望的结果. 作为替代,我想对我的数据集中的GoogLeNet模型进行微调. 有谁知道我应该遵循什么步骤? 采纳答案: 假设你正在尝试做图像分 ...

  2. R语言广义线性模型(GLM)、全子集回归模型选择、检验分析全国风向气候数据|附代码数据

    全文链接:http://tecdat.cn/?p=30914 最近我们被客户要求撰写关于广义线性模型(GLM)的研究报告,包括一些图形和统计输出. 我们正和一位朋友讨论如何在R软件中用GLM模型处理全 ...

  3. 从开源模型、框架到自研,声网 Web 端虚拟背景算法正式发布

    根据研究发现,在平均 38 分钟的视频会议里面,大概会有 13 分钟左右的时间用于处理和干扰相关的事情.同时研究也表明在参加在线会议的时候,人们更加倾向于语音会议,其中一个关键原因就是大家不希望个人隐 ...

  4. [深度学习] 经典深度学习模型及其微调(Caffe)总结

    目录 经典模型 Caffe预训练模型 经典模型 LeNet https://blog.csdn.net/kaido0/article/details/53161684 AlexNet https:// ...

  5. LSF-SCNN:一种基于 CNN 的短文本表达模型及相似度计算的全新优化模型

    欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 本篇文章是我在读期间,对自然语言处理中的文本相似度问题研究取得的一点小成果.如果你对自然语言处理 (natural language proc ...

  6. [Pytorch]深度模型的显存计算以及优化

    原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...

  7. 聊聊GIS中的坐标系|再版 详细定义、计算及高程系统

    本篇讲坐标系统的详细定义,有关坐标系的变换公式,以及简单说说高程坐标系统. 本文约6000字,阅读时间建议45分钟.硬内容比较多,如有疏漏错误请指出,建议有兴趣的朋友进一步阅读. 作者:博客园/B站/ ...

  8. css盒模型宽高混合计算calc

    例如: .element{ width:calc(expression); } 兼容性:在IE9+.FF4.0+.Chrome19+.Safari6+都得到了较好支持,但是在移动端的支持不是很好. 其 ...

  9. DL开源框架Caffe | 模型微调 (finetune)的场景、问题、技巧以及解决方案

    转自:http://blog.csdn.net/u010402786/article/details/70141261 前言 什么是模型的微调?   使用别人训练好的网络模型进行训练,前提是必须和别人 ...

  10. 聊聊同步、异步、阻塞、非阻塞以及IO模型

    前言 在使用Netty改造手写RPC框架的时候,需要给大家介绍一些相关的知识,这样很多东西大家就可以看明白了,手写RPC是一个支线任务,后续重点仍然是Kubernetes相关内容. 阻塞与非阻塞 同步 ...

随机推荐

  1. 阿里巴巴开源大规模稀疏模型训练/预测引擎DeepRec

    ​简介:经历6年时间,在各团队的努力下,阿里巴巴集团大规模稀疏模型训练/预测引擎DeepRec正式对外开源,助力开发者提升稀疏模型训练性能和效果. ​ 作者 | 烟秋 来源 | 阿里技术公众号 经历6 ...

  2. 每次都需要解释大量指令?使用 PolarDB-X 向量化引擎

    简介: 向量化引擎为PolarDB-X的表达式计算带来了显著的性能提升. 介绍 PolarDB-X是阿里巴巴自研的云原生分布式数据库,采用了计算-存储分离的架构,其中计算节点承担着大量的表达式计算任务 ...

  3. [Go] go build 和 go install 的区别

    $ go build 源文件及其包依赖 编译成二进制. install 不仅执行build过程 而且会把编译的二进制放到 $GOPATH/bin/,包放到 $GOPATH/pkg/ Link:http ...

  4. [K8s] Kubernetes 集群部署管理方式对比, kops, kubeadm, kubespray

    kops 是官方出的 Kubernetes Operations,生产级 K8s 的安装.升级和管理. 可以看做是适用于集群的 kubectl,kops 可帮助您从命令行创建,销毁,升级和维护生产级, ...

  5. [FAQ] Solidity 合约销毁 ?

    仅创建者可以销毁合约的示例: address public owner; // When deploy contract constructor() public { owner = msg.send ...

  6. WPF 推荐一个剪贴板内容查看工具

    本文来安利大家一个好用的 Windows 剪贴板的内容查看工具 这是在 GitHub 上完全免费开源的应用,由 walterlv 开发的应用,详细请看 https://github.com/walte ...

  7. dotnet 使用 IndentedTextWriter 辅助生成代码时生成带缩进的内容

    随着源代码生成的越来越多的应用,自然也遇到了越来越多开发上的坑,例如源代码的缩进是一个绕不过去的问题.如果源代码生成是人类可见的代码,我期望生成的代码最好是比较符合人类编写代码的规范.为了能让人类在阅 ...

  8. 8.prometheus监控--监控Mysql8.0

    一.环境搭建 docker-compose安装mysql mkdir /data/mysql -p cd /data/mysql cat > docker-compose.yaml <&l ...

  9. DE10-Lite输入/出高/低电平说明

    DE10-Lite输入/出高/低电平说明 DE10-Lite实验板上有一些设备可以输入/出高/低电平,说明如下: HEX 7-segment LED displays (active low)共阳极 ...

  10. 【爬虫+情感判定+Top10高频词+词云图】"乌克兰"油管热评python舆情分析

    目录 一.分析背景 二.整体思路 三.代码讲解 3.1 爬虫采集 3.2 情感判定 3.3 Top10高频词 3.4 词云图 四.得出结论 五.同步视频演示 六.附完整源码 一.分析背景 乌克兰局势这 ...