最近实在是有点忙,没啥时间写博客了。趁着周末水一文,把最近用 huggingface transformers 训练文本分类模型时遇到的一个小问题说下。

背景

之前只闻 transformers 超厉害超好用,但是没有实际用过。之前涉及到 bert 类模型都是直接手写或是在别人的基础上修改。但这次由于某些原因,需要快速训练一个简单的文本分类模型。其实这种场景应该挺多的,例如简单的 POC 或是临时测试某些模型。

我的需求很简单:用我们自己的数据集,快速训练一个文本分类模型,验证想法。

我觉得如此简单的一个需求,应该有模板代码。但实际去搜的时候发现,官方文档什么时候变得这么多这么庞大了?还多了个 Trainer API?瞬间让我想起了 Pytorch Lightning 那个坑人的同名 API。但可能是时间原因,找了一圈没找到适用于自定义数据集的代码,都是用的官方、预定义的数据集。

所以弄完后,我决定简单写一个文章,来说下这原本应该极其容易解决的事情。

数据

假设我们数据的格式如下:

0 第一个句子
1 第二个句子
0 第三个句子

即每一行都是 label sentence 的格式,中间空格分隔。并且我们已将数据集分成了 train.txtval.txt

代码

加载数据集

首先使用 datasets 加载数据集:

from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})

加载后的 dataset 是一个 DatasetDict 对象:

DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 3
})
test: Dataset({
features: ['text'],
num_rows: 3
})
})

类似 tf.data ,此后我们需要对其进行 map ,对每一个句子进行 tokenize、padding、batch、shuffle:

def tokenize_function(examples):
labels = []
texts = []
for example in examples['text']:
split = example.split(' ', maxsplit=1)
labels.append(int(split[0]))
texts.append(split[1])
tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
tokenized['labels'] = labels
return tokenized tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

根据数据集格式不同,我们可以在 tokenize_function 中随意自定义处理过程,以得到 text 和 labels。注意 batch_sizemax_length 也是在此处指定。处理完我们便得到了可以输入给模型的训练集和测试集。

训练

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()

你可以根据情况修改训练 batchsize per_device_train_batch_size

完整代码

完整代码见 GitHub

END

使用 Transformers 在你自己的数据集上训练文本分类模型的更多相关文章

  1. (2) 用DPM(Deformable Part Model,voc-release4.01)算法在INRIA数据集上训练自己的人体检測模型

    步骤一,首先要使voc-release4.01目标检測部分的代码在windows系统下跑起来: 參考在window下执行DPM(deformable part models) -(检測demo部分) ...

  2. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  3. [PocketFlow]解决TensorFLow在COCO数据集上训练挂起无输出的bug

    1. 引言 因项目要求,需要在PocketFlow中添加一套PeleeNet-SSD和COCO的API,具体为在datasets文件夹下添加coco_dataset.py, 在nets下添加pelee ...

  4. CaffeExample 在CIFAR-10数据集上训练与测试

    本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...

  5. 第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)

    前面已经介绍了几种经典的目标检测算法,光学习理论不实践的效果并不大,这里我们使用谷歌的开源框架来实现目标检测.至于为什么不去自己实现呢?主要是因为自己实现比较麻烦,而且调参比较麻烦,我们直接利用别人的 ...

  6. NVIDIA GPUs上深度学习推荐模型的优化

    NVIDIA GPUs上深度学习推荐模型的优化 Optimizing the Deep Learning Recommendation Model on NVIDIA GPUs 推荐系统帮助人在成倍增 ...

  7. Microsoft Dynamics CRM 2011 当您在 大型数据集上执行 RetrieveMultiple 查询很慢的解决方法

    症状 当您在 Microsoft Dynamics CRM 2011 年大型数据集上执行 RetrieveMultiple 查询时,您会比较慢. 原因 发生此问题是因为大型数据集缓存 Retrieve ...

  8. 在Titanic数据集上应用AdaBoost元算法

    一.AdaBoost 元算法的基本原理 AdaBoost是adaptive boosting的缩写,就是自适应boosting.元算法是对于其他算法进行组合的一种方式. 而boosting是在从原始数 ...

  9. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

  10. BP算法在minist数据集上的简单实现

    BP算法在minist上的简单实现 数据:http://yann.lecun.com/exdb/mnist/ 参考:blog,blog2,blog3,tensorflow 推导:http://www. ...

随机推荐

  1. Vulnhub 靶场 DARKHOLE: 2

    Vulnhub 靶场 DARKHOLE: 2 前期准备 前期准备: 靶机地址:https://www.vulnhub.com/entry/darkhole-2,740/ kali攻击机ip:192.1 ...

  2. spring Security 使用

    1.pom文件引入 <dependency> <groupId>org.springframework.boot</groupId> <artifactId& ...

  3. 1.CD冷却效果

    CD冷却效果.. 一.将需要用到的图片复制到 PS 中做去色处理,将图片保存为 PNG 格式.如下 二.将制作好的图片导入 Unity 中,做成图集 三.在虚拟按键上添加 UI - Image 制作 ...

  4. WC2023 游记

    不是很会写游记,随便写写吧. 一些附件 讲课资料合集(压缩后 \(\rm 31MB\))太大了,可以去 U 群下载. 由于后面很多乐子,我把相关内容打包成 zip 上传上来了. 乐子合集下载链接.(这 ...

  5. 7.项目结构的构建和提交到gitee

    创建微服务模块 以商城项目的产品模块为例 点击Next,然后倒入依赖的包,Spring Web 然后在选择一个微服务和微服务之间调用需要的包:OpenFeign 导入这两个微服务的组件就行,后面需要用 ...

  6. vulnhub:Victim01靶机

    kali:192.168.111.111 靶机:192.168.111.170 信息收集 端口扫描 nmap -A -v -sV -T5 -p- --script=http-enum 192.168. ...

  7. 多个el-table在使用v-if在同一页面切换渲染时相互影响的解决办法

    解决办法 给每个el-table设置一个唯一的key值,如: <el-table key='uniqueName' ></el-table> <el-table key= ...

  8. 爬取白鲸nft排名前25项目,持有nft大户地址数据。

    https://moby.gg/rankings?tab=Market SELECT address '钱包地址', COUNT (1) '持有nft项目数', SUM (balance) '持有nf ...

  9. K8s网络策略

    Network Policy(网络策略) 默认情况下,k8s集群网络是没有任何限制的,Pod可以和任何其他Pod通信,在某些场景下需要做网络控制,减少网络面的攻击,提高安全性,就会用到网络策略(Net ...

  10. aspx页面,Page_Load 无人进入,解决

    又一次copy放的错误,今天必须记录一下. 当你不需要走后台时候,ready 就有限制了. ready放的位置有问题.下面是错误示范. <script src="../ToExamin ...