pytorch入门 - 修改huggingface大模型配置参数
介绍
Hugging Face的Transformers库提供了大量预训练模型,但有时我们需要修改这些模型的默认参数来适应特定任务。
本文将详细介绍如何修改BERT模型的最大序列长度(max_position_embeddings)参数,并解释相关原理和实现细节。
原理
BERT等Transformer模型对输入序列长度有固定限制,这主要由位置编码(position embeddings)决定。
原始BERT-base-chinese模型的max_position_embeddings为512,意味着它最多只能处理512个token的输入。当我们需要处理更长的文本时,必须修改这一参数。
修改过程涉及三个关键步骤:
- 调整模型配置中的max_position_embeddings值
- 替换位置嵌入层(position_embeddings)为新尺寸
- 初始化新位置嵌入层的权重(复制原有权重,其余随机初始化)
实现代码详解
下面我们逐行分析实现代码:
1. 数据集准备 (news_finetuing_data_set.py)
from datasets import load_dataset, load_from_disk
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, split):
# 指定CSV文件路径,支持train/test/validation三种分割
data_file = rf"cache\datasets\csv\THUCNewsText\{split}.csv"
self.dataset = load_dataset(
"csv",
data_files={split: data_file},
split=split if split in ["train", "test", "validation"] else "train",
)
def __len__(self):
return len(self.dataset) # 返回数据集样本数量
def __getitem__(self, idx):
return self.dataset[idx]["text"], self.dataset[idx]["label"] # 返回文本和标签
2. 模型修改 (news_finetuing_net.py)
from transformers import BertModel, BertConfig
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1. 加载预训练模型和配置
model = BertModel.from_pretrained(
"bert-base-chinese", cache_dir="./cache/bertbasechinese"
).to(device)
# 2. 修改max_position_embeddings配置
model.config.max_position_embeddings = 1500
# 3. 替换position_embeddings层
old_embeddings = model.embeddings.position_embeddings
new_embeddings = torch.nn.Embedding(1500, old_embeddings.embedding_dim)
# 拷贝原有权重
num = min(old_embeddings.weight.size(0), 1500)
new_embeddings.weight.data[:num, :] = old_embeddings.weight.data[:num, :]
model.embeddings.position_embeddings = new_embeddings
# 4. 冻结除position_embeddings外的所有参数
for name, param in pretrained.named_parameters():
if "embeddings.position_embeddings" in name:
param.requires_grad = True
else:
param.requires_grad = False
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.classifier = torch.nn.Linear(768, 10) # 添加分类头
def forward(self, input_ids, attention_mask, token_type_ids):
position_ids = (
torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
.unsqueeze(0)
.expand_as(input_ids)
)
outputs = pretrained(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
)
cls_output = outputs.last_hidden_state[:, 0] # 取[CLS] token的输出
out = self.classifier(cls_output)
return out
3. 训练过程 (news_finetuing_train.py)
import torch
from news_finetuing_data_set import MyDataset
from torch.utils.data import DataLoader
from news_finetuing_net import Model
from transformers import BertTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 100
# 加载分词器并设置最大长度
token = BertTokenizer.from_pretrained(
"bert-base-chinese",
cache_dir="./cache/tokenizer/bert-base-chinese",
)
token.model_max_length = 1500 # 设置分词器最大长度
def collate_fn(batch):
# 数据处理函数
sentes = [item[0] for item in batch]
labels = [item[1] for item in batch]
data = token.batch_encode_plus(
sentes,
truncation=True,
padding="max_length",
max_length=1500,
return_tensors="pt",
return_length=True,
)
# 返回模型需要的各种输入
return (
data["input_ids"],
data["attention_mask"],
data["token_type_ids"],
torch.LongTensor(labels),
)
# 创建数据集和DataLoader
train_dataset = MyDataset(split="train")
val_dateset = MyDataset(split="validation")
train_loader = DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
val_loader = DataLoader(
val_dateset,
batch_size=32,
shuffle=False,
drop_last=True,
collate_fn=collate_fn,
)
# 训练主循环
if __name__ == "__main__":
model = Model().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loss_func = torch.nn.CrossEntropyLoss()
for epoch in range(EPOCH):
model.train()
for step, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
# 数据移动到设备
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
labels = labels.to(device)
# 前向传播和反向传播
outputs = model(input_ids, attention_mask, token_type_ids)
loss = loss_func(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
if step % 5 == 0:
out = outputs.argmax(dim=1)
acc = (out == labels).sum().item() / len(labels)
print(f"Epoch: {epoch + 1}/{EPOCH}, Step: {step + 1}/{len(train_loader)}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
# 保存模型
torch.save(model.state_dict(), f"./model/news_finetuning_epoch_{epoch}.pth")
print(f"epoch {epoch} 保存成功")
关键点解释
模型修改部分
model.config.max_position_embeddings = 1500
- 修改配置中的最大位置嵌入数- 创建新的位置嵌入层时,我们保留了原始嵌入维度(
embedding_dim
),只扩展了位置数量 - 权重初始化策略是复制原有512个位置的权重,剩余位置使用随机初始化
训练策略
- 我们冻结了除位置嵌入外的所有BERT参数,只训练位置嵌入和新添加的分类头
- 这种策略在长文本微调中很常见,可以防止过拟合
数据处理
- 分词器也需要设置
model_max_length
以匹配新的序列长度 collate_fn
函数确保所有输入都被填充/截断到1500的长度
总结
本文详细介绍了如何修改Hugging Face模型的max_position_embeddings参数,包括原理说明和完整代码实现。这种方法可以扩展到其他参数的修改,为定制化预训练模型提供了参考。关键点在于正确修改配置、替换相应层并合理初始化参数。
pytorch入门 - 修改huggingface大模型配置参数的更多相关文章
- 大页(huge pages) 三大系列 ---计算大页配置参数
使用以下shell 脚本来计算大页配置参数,确保使用脚本实例之前的数据已经开始, 如果数据库的版本号11g,确认是否使用自己主动的内存管理(AMM) +++++++++++++++++++++++++ ...
- Docker(十七)-修改Docker容器启动配置参数
有时候,我们创建容器时忘了添加参数 --restart=always ,当 Docker 重启时,容器未能自动启动, 现在要添加该参数怎么办呢,方法有二: 1.Docker 命令修改 docker c ...
- 修改Nginx与Apache配置参数解决http状态码:413上传文件大小限制问题
一.修改Nginx上传文件大小限制 我们使用ngnix做web server的时候,nginx对上传文件的大小有限制,默认是1M. 当超过大小的时候会报413(too large)错误.这个时候我们要 ...
- 修改Docker容器启动配置参数
有时候,我们创建容器时忘了添加参数 --restart=always ,当 Docker 重启时,容器未能自动启动, 现在要添加该参数怎么办呢,方法有二: 1.Docker 命令修改 docker c ...
- pytorch中修改后的模型如何加载预训练模型
问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...
- [FreeRTOS入门] 1.CubeMX中FreeRTOS配置参数及理解
1.有关优先级 1.1 Configuration --> FreeRTOS MAX_PRIORITIES 设置任务优先级的数量:配置应用程序有效的优先级数目.任何数量的任务都可以共享一个优先级 ...
- Pytorch入门下 —— 其他
本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...
- 千亿参数开源大模型 BLOOM 背后的技术
假设你现在有了数据,也搞到了预算,一切就绪,准备开始训练一个大模型,一显身手了,"一朝看尽长安花"似乎近在眼前 -- 且慢!训练可不仅仅像这两个字的发音那么简单,看看 BLOOM ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- 对MySQL性能影响较大的五类配置参数
以下主要是对MySQL 性能影响关系紧密的五大配置参数的介绍. 一. 连接 连接通常来自Web 服务器,下面列出了一些与连接有关的参数,以及该如何设置它们. (一). ...
随机推荐
- laravel-echo-server 启动报错 [ioredis] Unhandled error event: ReplyError: NOAUTH Authentication required.
可以在 .env 文件加上以下配置 LARAVEL_ECHO_SERVER_REDIS_HOST= LARAVEL_ECHO_SERVER_REDIS_PASSWORD= LARAVEL_ECHO_S ...
- celery 启动显示警告信息“...whether broker connection retries are made during startup in Celery 6.0 and above...”
博客地址:https://www.cnblogs.com/zylyehuo/ # celery作为一个单独项目运行,在settings文件中设置 broker_connection_retry_on_ ...
- 阿里云ECS服务器Ubuntu下安装docker-ce技巧
官方文档 先来份Ubuntu 下安装 docker 的官方文档 -> Get Docker CE for Ubuntu 官方文档的安装方式是最靠谱的,但是对于国内的小伙伴来说墙是硬伤... 国内 ...
- 【ESP32】移植 Arduino 库到 idf 项目中
今天咱们要聊的内容非常简单,所以先扯点别的.上一篇水文中,老周没能将 TinyUSB 的源码编译进 Arduino 中,心有两百万个不甘,于是清明节的时候再试了一次,居然成功了,已经在 esp32 开 ...
- Codeforces Round 944 (Div. 4)
知识点模块 1. ai xor aj<=4 意味着两个数字的二进制位,只能有后两位的二进制位不同,因为如果第三位二进制位不同,就会出现异或的结果大于4 2.要有化曲为直的思想 学会把曲线上的坐标 ...
- 如何开发 MCP 服务?保姆级教程!
最近这段时间有个 AI 相关的概念特别火,叫 MCP,全称模型上下文协议(Model Context Protocol).这是由 Anthropic 推出的一项开放标准,目标是为大型语言模型和 AI ...
- JAVA基础之多线程三期--线程安全问题
一.线程安全问题就是指:多个线程并发访问同一个资源而发生安全性的问题, 线程安全问题都是由全局变量及静态变量引起的. 若每个线程中对全局变量.静态变量只有读操作,而无写 操作,一般来说,这个全局变量是 ...
- nginx配置代理指向Redis
stream { upstream redis { server 127.0.0.1:6379 max_fails=3 fail_timeout=30s; #*redis-addres*替换为真实地址 ...
- 2025年4月TIOBE指数
4 月头条:编程语言 Kotlin.Ruby 和 Swift 直到最近在 TIOBE 指数排名中都一直稳居前 20 的稳定位置.但如今它们似乎失去了发展动力,且很可能会逐渐过时.Kotlin 和 Sw ...
- K8s新手系列之Pod的重启策略
概述 K8s中Pod的重启策略具有确保服务连续性.保证任务完整性.提升资源利用效率.便于故障排查的作用 Pod的重启策略可以根据restartPolicy字段定义. 重启策略适用于pod对象中的所有容 ...