pytorch入门 - 微调huggingface大模型
在自然语言处理(NLP)领域,预训练语言模型如BERT已经成为主流。HuggingFace提供的Transformers库让我们能够方便地使用这些强大的模型。
本文将详细介绍如何使用PyTorch微调HuggingFace上的BERT模型,包括原理讲解、代码实现和逐行解释。
1. 微调原理
1.1 什么是微调(Fine-tuning)
微调是指在预训练模型的基础上,针对特定任务进行少量训练的过程。
BERT等预训练模型已经在大规模语料上学习了通用的语言表示能力,通过微调,我们可以将这些知识迁移到特定任务上。
1.2 BERT模型结构
BERT模型主要由以下部分组成:
- 嵌入层(Embedding Layer)
- 多层Transformer编码器
- 池化层(Pooler)
在微调时,我们通常会在BERT的输出上添加一个任务特定的分类头(Classification Head)。
1.3 神经元数量计算
在我们的模型中,分类头是一个全连接层,其神经元数量计算如下:
输入维度:768 (BERT最后一层隐藏状态维度)
输出维度:2 (二分类任务)
参数数量 = (输入维度 × 输出维度) + 输出维度(偏置项)
= (768 × 2) + 2 = 1538
2. 代码实现
2.1 数据集处理 (finetuing_my_dataset.py)
from datasets import load_dataset, load_from_disk # 导入HuggingFace的数据集加载工具
from torch.utils.data import Dataset # 导入PyTorch的数据集基类
class MydataSet(Dataset): # 自定义数据集类,继承自PyTorch的Dataset
def __init__(self, split): # 初始化方法,split指定数据集划分
save_path = r".\cache\datasets\lansinuote\ChnSentiCorp\train" # 数据集路径
self.dataset = load_from_disk(save_path) # 从磁盘加载数据集
# 根据split参数选择数据集划分
if split == "train":
self.dataset = self.dataset["train"]
elif split == "test":
self.dataset = self.dataset["test"]
elif split == "validation":
self.dataset = self.dataset["validation"]
else:
raise ValueError("split must be one of 'train', 'test', or 'validation'")
def __len__(self): # 返回数据集大小
return len(self.dataset)
def __getitem__(self, idx): # 获取单个样本
return self.dataset[idx]["text"], self.dataset[idx]["label"] # 返回文本和标签
if __name__ == "__main__": # 测试代码
dataset = MydataSet(split="validation") # 创建验证集实例
for i in range(50): # 打印前50个样本
print(dataset[i])
print(dataset) # 打印数据集信息
print(dataset[0]) # 打印第一个样本
2.2 模型定义 (finetuing_net.py)
from transformers import BertModel # 导入BERT模型
import torch # 导入PyTorch
# 设置设备(GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练BERT模型
cache_dir = "./cache/bertbasechinese" # 缓存目录
pretrained = BertModel.from_pretrained(
"bert-base-chinese", # 中文BERT模型
cache_dir=cache_dir
).to(device) # 移动到指定设备
class Model(torch.nn.Module): # 自定义模型类
def __init__(self):
super(Model, self).__init__() # 调用父类初始化
# 定义分类头: 768维输入, 2维输出(二分类)
self.fc = torch.nn.Linear(768, 2)
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
# 冻结BERT参数,不计算梯度
with torch.no_grad():
outputs = pretrained(
input_ids=input_ids, # 输入token IDs
attention_mask=attention_mask, # 注意力掩码
token_type_ids=token_type_ids, # 句子类型IDs
)
# 使用[CLS]标记的隐藏状态作为分类特征
cls_output = outputs.last_hidden_state[:, 0] # 形状(batch_size, 768)
logits = self.fc(cls_output) # 通过分类头
out = logits.softmax(dim=-1) # softmax归一化
return out
2.3 训练过程 (finetuing_train.py)
import torch
from finetuing_my_dataset import MydataSet
from torch.utils.data import DataLoader
from finetuing_net import Model
from transformers import BertTokenizer
from torch.optim import AdamW
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 100 # 训练轮数
# 加载分词器
token = BertTokenizer.from_pretrained(
"bert-base-chinese",
cache_dir="./cache/tokenizer/bert-base-chinese"
)
def collate_fn(batch): # 数据批处理函数
sentes = [item[0] for item in batch] # 提取文本
labels = [item[1] for item in batch] # 提取标签
# 使用分词器处理文本
data = token.batch_encode_plus(
sentes,
truncation=True, # 截断过长的文本
max_length=350, # 最大长度350
padding=True, # 自动填充
return_tensors="pt", # 返回PyTorch张量
return_length=True, # 返回长度信息
)
# 提取编码后的数据
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = torch.LongTensor(labels) # 转换标签为LongTensor
return input_ids, attention_mask, token_type_ids, labels
# 创建训练数据集和数据加载器
train_dataset = MydataSet(split="train")
train_dataloader = DataLoader(
train_dataset,
batch_size=32, # 批大小32
shuffle=True, # 打乱数据
drop_last=True, # 丢弃最后不完整的批次
collate_fn=collate_fn, # 使用自定义批处理函数
)
if __name__ == "__main__":
model = Model().to(device) # 初始化模型
optimizer = AdamW(model.parameters(), lr=1e-5) # 优化器
loss_func = torch.nn.CrossEntropyLoss() # 损失函数
model.train() # 设置为训练模式
for epoch in range(EPOCH): # 训练循环
for step, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_dataloader):
# 移动数据到设备
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
loss = loss_func(outputs, labels) # 计算损失
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每5步打印训练信息
if step % 5 == 0:
out = outputs.argmax(dim=1) # 预测类别
acc = (out == labels).sum().item() / len(labels) # 计算准确率
print(f"Epoch: {epoch + 1}/{EPOCH}, Step: {step + 1}/{len(train_dataloader)}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
# 保存模型
torch.save(model.state_dict(), f"./model/{epoch}finetuned_model_new.pth")
print(epoch, "参数保存成功")
2.4 测试过程 (finetuing_test.py)
import torch
from finetuing_my_dataset import MydataSet
from torch.utils.data import DataLoader
from finetuing_net import Model
from transformers import BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载分词器
token = BertTokenizer.from_pretrained(
"bert-base-chinese",
cache_dir="./cache/tokenizer/bert-base-chinese"
)
def collate_fn(batch): # 与训练时相同的批处理函数
sentes = [item[0] for item in batch]
labels = [item[1] for item in batch]
data = token.batch_encode_plus(
sentes,
truncation=True,
max_length=350,
padding=True,
return_tensors="pt",
return_length=True,
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = torch.LongTensor(labels)
return input_ids, attention_mask, token_type_ids, labels
# 创建测试数据集和数据加载器
train_dataset = MydataSet(split="test")
train_dataloader = DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
if __name__ == "__main__":
acc = 0 # 正确预测数
total = 0 # 总样本数
model = Model().to(device) # 初始化模型
model.load_state_dict(torch.load("./model/3finetuned_model.pth")) # 加载训练好的模型
model.eval() # 设置为评估模式
for step, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_dataloader):
# 移动数据到设备
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
labels = labels.to(device)
# 前向传播(不计算梯度)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
out = outputs.argmax(dim=1) # 预测类别
acc += (out == labels).sum().item() # 累加正确预测数
total += len(labels) # 累加总样本数
print(acc / total) # 输出准确率
2.5 交互式预测 (finetuing_run.py)
import torch
from finetuing_net import Model
from transformers import BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 类别名称
names = [
"负向评价", # 类别0
"正向评价", # 类别1
]
model = Model().to(device) # 初始化模型
# 加载分词器
token = BertTokenizer.from_pretrained(
"bert-base-chinese",
cache_dir="./cache/tokenizer/bert-base-chinese"
)
def collate_fn(data): # 单样本处理函数
sentes = []
sentes.append(data) # 将输入文本加入列表
# 使用分词器处理文本
data = token.batch_encode_plus(
sentes,
truncation=True,
padding="max_length",
max_length=350,
return_tensors="pt",
return_length=True,
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
return input_ids, attention_mask, token_type_ids
def test():
model.load_state_dict(torch.load("./model/2finetuned_model.pth")) # 加载训练好的模型
model.eval() # 设置为评估模式
while True: # 交互式循环
text = input("请输入文本:") # 获取用户输入
if text == "q": # 输入q退出
print("退出测试")
break
# 处理输入文本
input_ids, attention_mask, token_type_ids = collate_fn(text)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
# 预测(不计算梯度)
with torch.no_grad():
outputs = model(input_ids, attention_mask, token_type_ids)
out = outputs.argmax(dim=1) # 预测类别
print("模型预测", names[out], "\n") # 输出预测结果
if __name__ == "__main__":
test() # 启动测试
3. 关键点解析
3.1 数据处理流程
- 数据集加载:使用HuggingFace的
load_from_disk加载预处理好的数据集 - 文本编码:使用
BertTokenizer将文本转换为模型可接受的输入格式 - 批处理:
collate_fn函数负责将多个样本打包成一个批次
3.2 模型结构
- 预训练BERT:固定参数,仅作为特征提取器
- 分类头:可训练的全连接层,将BERT输出映射到任务特定的类别空间
3.3 训练策略
- 优化器选择:使用AdamW优化器,适合Transformer模型
- 学习率:较小的学习率(1e-5)避免破坏预训练学到的知识
- 评估指标:准确率和交叉熵损失
4. 总结
本文详细介绍了如何使用PyTorch微调HuggingFace上的BERT模型,包括:
- 数据集处理与加载
- 模型定义与微调策略
- 训练、测试和交互式预测的实现
- 关键代码的逐行解释
通过微调预训练模型,我们可以在相对较小的数据集上获得良好的性能,这是现代NLP应用中的常用技术。
pytorch入门 - 微调huggingface大模型的更多相关文章
- Pytorch入门下 —— 其他
本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍
DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍 1. 概述 近日来,ChatGPT及类似模型引发了人工智能(AI)领域的一场风潮. 这场风潮对数字世 ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 总结笔记 | 深度学习之Pytorch入门教程
笔记作者:王博Kings 目录 一.整体学习的建议 1.1 如何成为Pytorch大神? 1.2 如何读Github代码? 1.3 代码能力太弱怎么办? 二.Pytorch与TensorFlow概述 ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- 千亿参数开源大模型 BLOOM 背后的技术
假设你现在有了数据,也搞到了预算,一切就绪,准备开始训练一个大模型,一显身手了,"一朝看尽长安花"似乎近在眼前 -- 且慢!训练可不仅仅像这两个字的发音那么简单,看看 BLOOM ...
- 第一章:PyTorch 入门
第一章:PyTorch 入门 1.1 Pytorch 简介 1.1.1 PyTorch的由来 1.1.2 Torch是什么? 1.1.3 重新介绍 PyTorch 1.1.4 对比PyTorch和Te ...
随机推荐
- 大模型评测之幻觉检测hallucination_evaluation_model
大背景: 2025开年deepseek铺天盖地的新闻 参会代表已经表明,年度主线就是以AI为基础 Manus于3月初横空出世 国内各种模型竞赛的现状,只要是和科技沾边的公司不可能没有大模型,哪怕是里三 ...
- zstd压缩算法概述与基本使用
本文仅关注zstd的使用,并不关心其算法的具体实现 并没有尝试使用zstd的所有功能模式,但是会简单介绍每种模式的应用场景,用到的时候去查api吧 step 0:why zstd? zstd是face ...
- Golang 入门 : 转换
Go中数学运算和比较运算要求包含的值具有相同的类型.如果不是的话,则在尝试运行代码时会报错. 为变量分配新值也是如此.如果所赋值的类型与变量的声明类型不匹配,也会报错. 解决方法是使用转换,它允许你将 ...
- 人工智能-A*算法-最优路径搜索实验
上次学会了<A*算法-八数码问题>,初步了解了A*算法的原理,本次再用A*算法完成一个最优路径搜索实验. 一.实验内容1. 设计自己的启发式函数.2. 在网格地图中,设计部分障碍物.3. ...
- 办公自动化-批量更新tar包内文件
最近工作有点忙,学习的时间也少了,为了提高工作效率,有时候我们需要自己写一些提高办公处理效率给的工具或者脚本或者程序. 比如,我目前遇到的一个事项,需要更新很多个tar包文件,把tar包内的某个文件替 ...
- useSyncExternalStore 的应用
我们是袋鼠云数栈 UED 团队,致力于打造优秀的一站式数据中台产品.我们始终保持工匠精神,探索前端道路,为社区积累并传播经验价值. 本文作者:修能 学而不思则罔,思而不学则殆 . --- <论语 ...
- Destination host unreachable 一般解决办法
症状: 上网各类应用基本正常,但是在命令行下使用ping命令,无论任何地址,均反馈Destination host unreachable. 分析: 输入命令arp -a可以看到网关的MAC地址正常解 ...
- 扫盲ASM
在进行程序跟踪时,会出现汇编.由于ASM盲,所以添加不少烦恼.有烦恼得想办法解决.对,扫盲ASM. 这里是教材,感觉大白话很好理解(感谢 http://www.ruanyifeng.com/blog/ ...
- Devops工程师需要具备的10项技能
Facebook.Amazon和Microsoft等公司正在大量使用DevOps技术来确保软件的一致交付,DevOps的的工作机会和所需要的技能集也是越来越多. 在这里,我们将讨论Devops工程师需 ...
- python之“if __name__=="__main__"”的代表的意思和用法
创建下方脚本A def print_sum(a): print(a) print_sum(20) if __name__=="__main__": print("test ...