使用 Transformers 在你自己的数据集上训练文本分类模型
最近实在是有点忙,没啥时间写博客了。趁着周末水一文,把最近用 huggingface transformers 训练文本分类模型时遇到的一个小问题说下。
背景
之前只闻 transformers 超厉害超好用,但是没有实际用过。之前涉及到 bert 类模型都是直接手写或是在别人的基础上修改。但这次由于某些原因,需要快速训练一个简单的文本分类模型。其实这种场景应该挺多的,例如简单的 POC 或是临时测试某些模型。
我的需求很简单:用我们自己的数据集,快速训练一个文本分类模型,验证想法。
我觉得如此简单的一个需求,应该有模板代码。但实际去搜的时候发现,官方文档什么时候变得这么多这么庞大了?还多了个 Trainer API?瞬间让我想起了 Pytorch Lightning 那个坑人的同名 API。但可能是时间原因,找了一圈没找到适用于自定义数据集的代码,都是用的官方、预定义的数据集。
所以弄完后,我决定简单写一个文章,来说下这原本应该极其容易解决的事情。
数据
假设我们数据的格式如下:
0 第一个句子
1 第二个句子
0 第三个句子
即每一行都是 label sentence 的格式,中间空格分隔。并且我们已将数据集分成了 train.txt 和 val.txt 。
代码
加载数据集
首先使用 datasets 加载数据集:
from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})
加载后的 dataset 是一个 DatasetDict 对象:
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 3
})
test: Dataset({
features: ['text'],
num_rows: 3
})
})
类似 tf.data ,此后我们需要对其进行 map ,对每一个句子进行 tokenize、padding、batch、shuffle:
def tokenize_function(examples):
labels = []
texts = []
for example in examples['text']:
split = example.split(' ', maxsplit=1)
labels.append(int(split[0]))
texts.append(split[1])
tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
tokenized['labels'] = labels
return tokenized
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)
根据数据集格式不同,我们可以在 tokenize_function 中随意自定义处理过程,以得到 text 和 labels。注意 batch_size 和 max_length 也是在此处指定。处理完我们便得到了可以输入给模型的训练集和测试集。
训练
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
你可以根据情况修改训练 batchsize per_device_train_batch_size 。
完整代码
完整代码见 GitHub。
END
使用 Transformers 在你自己的数据集上训练文本分类模型的更多相关文章
- (2) 用DPM(Deformable Part Model,voc-release4.01)算法在INRIA数据集上训练自己的人体检測模型
步骤一,首先要使voc-release4.01目标检測部分的代码在windows系统下跑起来: 參考在window下执行DPM(deformable part models) -(检測demo部分) ...
- 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)
基于深度学习和迁移学习的识花实践(转) 深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...
- [PocketFlow]解决TensorFLow在COCO数据集上训练挂起无输出的bug
1. 引言 因项目要求,需要在PocketFlow中添加一套PeleeNet-SSD和COCO的API,具体为在datasets文件夹下添加coco_dataset.py, 在nets下添加pelee ...
- CaffeExample 在CIFAR-10数据集上训练与测试
本文主要来自Caffe作者Yangqing Jia网站给出的examples. @article{jia2014caffe, Author = {Jia, Yangqing and Shelhamer ...
- 第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)
前面已经介绍了几种经典的目标检测算法,光学习理论不实践的效果并不大,这里我们使用谷歌的开源框架来实现目标检测.至于为什么不去自己实现呢?主要是因为自己实现比较麻烦,而且调参比较麻烦,我们直接利用别人的 ...
- NVIDIA GPUs上深度学习推荐模型的优化
NVIDIA GPUs上深度学习推荐模型的优化 Optimizing the Deep Learning Recommendation Model on NVIDIA GPUs 推荐系统帮助人在成倍增 ...
- Microsoft Dynamics CRM 2011 当您在 大型数据集上执行 RetrieveMultiple 查询很慢的解决方法
症状 当您在 Microsoft Dynamics CRM 2011 年大型数据集上执行 RetrieveMultiple 查询时,您会比较慢. 原因 发生此问题是因为大型数据集缓存 Retrieve ...
- 在Titanic数据集上应用AdaBoost元算法
一.AdaBoost 元算法的基本原理 AdaBoost是adaptive boosting的缩写,就是自适应boosting.元算法是对于其他算法进行组合的一种方式. 而boosting是在从原始数 ...
- TersorflowTutorial_MNIST数据集上简单CNN实现
MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...
- BP算法在minist数据集上的简单实现
BP算法在minist上的简单实现 数据:http://yann.lecun.com/exdb/mnist/ 参考:blog,blog2,blog3,tensorflow 推导:http://www. ...
随机推荐
- SVNKit操作SVN
系统集成SVN(SVNKit操作SVN) 网址:https://svnkit.com/documentation.html 文档:https://svnkit.com/javadoc/index.ht ...
- Centos安装后出现please make your choice from '1' to eter the license information spoke | 'q' to quit |'c' to continue |'r' to refresh
这是要求用户阅读或者接收协议: 解决方法:输入"1",按Enter键 阅读许可协议,输入"2",按Enter键 接受许可协议,输入"q" ...
- surfaceview+mediaplayer
public class VideosurfaceView extends SurfaceView implements SurfaceHolder.Callback, MediaPlayer.OnP ...
- spring Security 使用
1.pom文件引入 <dependency> <groupId>org.springframework.boot</groupId> <artifactId& ...
- PC端 图片宽度是百分比,动态设置图片高度为 6:9
我们知道图片宽度可以设置 百分比,但是高度要给一个固定值 不然不生效,并且产品要求图片显示必须是9:6,这开始确实难倒我了 后面想了一下用js 获取图片宽度 动态的计算高度就行了,超简单 se ...
- 1.win10安装centos虚拟机并设置允许远程
一.下载并安装 打开如下连接,下载VMware和CentOS7镜像安装好虚拟机 http://t.zoukankan.com/onlymate-p-9837651.html这个链接的镜像是7.0的,我 ...
- Rsync已过时?替代文件同步软件了解一下
随着企业结构分散化的不断扩大,企业内部和企业间的信息互动更加频繁.越来越多的企业要求内部各种业务数据在服务器.数据中心甚至云上能够有实时的同步留存.所以,企业需要文件同步软件,通过在两个或更多设备之间 ...
- Android开发数据库Sqlite
创建数据库 首先我们要了解这个类:SQLiteOpenHelper: 1.写一个类继承SQLiteOpenHelper 2.实现里面的方法,创建构造方法 参数解释: /** @param: conte ...
- SAP GGB0 校验
需求,针对财务凭证分配号的要求 在满足条件下进行必填校验 在需要的位置 创建确认 创建步骤,一般通过点击就可以形成需要的前提逻辑,也可以通过 设置->专门方式 来进行自定义编写. 如果前提条件是 ...
- Web Dynpro for ABAP(15):Print
3.20 Print WDA调用浏览器打印界面 1.创建Print按钮,绑定事件PRINT; 2.实现ONACTIONPRINT事件: method ONACTIONPRINT. DATA:l_api ...