transformers 之Trainer对应的数据加载
基础信息说明
- 本文以Seq2SeqTrainer作为实例,来讨论其模型训练时的数据加载方式
- 预训练模型:opus-mt-en-zh
- 数据集:本地数据集
- 任务:en-zh 机器翻译
数据加载
Trainer的数据加载方式主要分为两种:基于torch.utils.data.Dataset的方式加载 和 基于huggingface自带的Datasets的方式(下文用huggingface / Datasets表示)加载。以下是一些需要注意的点:(1)Seq2SeqTrainer()的train_dataset和eval_dataset参数的所传实参应为字典类型;(2)该字典实参的keys应当覆盖模型运行所需要的数据参数(本文需要包括的有:'input_ids', 'attention_mask', 'labels');(3)使用huggingface / Datasets方法加载时,传给train_dataset和eval_dataset的字典实参中,多余的key(未在模型运行所需输入参数列表中)及其相关数据数,将会在训练之前被剔除。
torch.utils.data.Dataset
重载Dataset类(dataset.py)
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
class CDNDataset(Dataset):
def __init__(self, samples):
super(CDNDataset, self).__init__()
self.samples = samples
def __getitem__(self, ite):
res = {k_: v_[ite]for k_, v_ in self.samples.items()}
return res
def __len__(self):
return len(self.samples['labels'])
加载引用(main.py 后文代码同属于本文件)
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataset import CDNDataset
读取数据
# 读取训练集
with open('raw_data/txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
train_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
# 将tokenized的中文序列对应的input_ids作为输入数据的标签
train_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128,
padding=True,truncation=True)["input_ids"]
fr_en.close()
fr_zh.close()
train_data = CDNDataset(train_data)
# 读取验证集
with open('raw_data/test_txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/test_txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
dev_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
dev_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128,
padding=True,truncation=True)["input_ids"]
fr_en.close()
fr_zh.close()
dev_data = CDNDataset(dev_data)
huggingface / Datasets
修改main.py中数据集读取部分的代码
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
"""利用load_dataset()来读取数据:
- 该方法支持.txt、.csv、.json等文件格式
- 返回结果是一个字典类型
- 读取.txt文件时,若不指定名称,这key为"text", 且会返回文本中的样本数(段落数)
- 在读取.json文件时,若所有样本放在一个josn文件中,则返回的样本数为1(无法优雅地调用train_test_split()进行数据集分割),名称为默认名或者最层字典所 对应的keys;
- 将每个json文件仅存放一个样本,并把这些文件放在某一目录,可使利用load_dataset()正确计算出样本数。但该目录下每个.json文件命名风格要一致(例如:txt1.json、txt2.json、、、),文件名差异较大的话,系统会只读取某一类命名格式相近的文件中的数据。
"""
books = load_dataset("raw_data", data_dir='test_en', name='translation')
books = books["train"].train_test_split(test_size=0.15)
source_lang = "en"
target_lang = "zh"
prefix = "translate English to Chinese: " # 其实我也还没搞懂为啥要加这样一个前缀
def preprocess_function(examples):
inputs = [prefix + example[source_lang] for example in examples["translation"]]
targets = [example[target_lang] for example in examples["translation"]]
model_inputs = tokenizer(inputs, max_length=128, truncation=True)
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
tokenized_books = books.map(preprocess_function, batched=True)
模型及参数加载
tokenizer = AutoTokenizer.from_pretrained("opus-mt-en-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("opus-mt-en-zh")
#使用huggingface/Datasets方式加载数据时,可以用DataCollator达到批处理的效果
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) #用torch.utils.data.Dataset方式加载时,不需要
training_args = Seq2SeqTrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=2,
fp16=True,
)
模型训练
本文以Seq2SeqTrainer作为实例来进行介绍。
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=dev_data,
tokenizer=tokenizer,
data_collator=data_collator, #用torch.utils.data.Dataset方式加载时,此参数不需要
)
补充说明
Seq2SeqTrainer()中的train_dataset和eval_dataset参数仅支持torch.utils.data.Dataset和huggingface / Datasets类型的传入实参。
torch.utils.data.Dataset类型的实参传入Seq2SeqTrainer()后,会在后序过程直接调用torch.utils.data.DataLoader,与常规pytorch操作相同
huggingface / Datasets类型的实参传入Seq2SeqTrainer()后,在后序过程会,先剔除多余的键及其值。至于torch.utils.data.Dataset类型的实参中若包含多余的键及其值,程序会不会报错暂没有测试过。获取模型所需输入参数列表的程序如下:
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
即使传递给train_dataset和eval_dataset的数据时字典类型,也存在一种做法使得基于torch.utils.data.Dataset加载数据的方式会报异常
重载Dataset类(dataset.py)
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset class CDNDataset(Dataset):
def __init__(self, samples):
super(CDNDataset, self).__init__()
self.samples = samples def __getitem__(self, ite):
return self.samples[ite] def __len__(self):
return len(self.samples)
main.py中数据读取部分
train_data = []
with open('raw_data/txt_en.txt', 'r', encoding='utf-8') as fr_en,
open('raw_data/txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
for en_, zh_ in zip(fr_en, fr_zh):
data = tokenizer(en_.strip(), max_length=128, padding=True, truncation=True, return_tensors='pt')
data["labels"] = tokenizer(zh_.strip(), max_length=128, padding=True,
truncation=True, return_tensors='pt')['input_ids']
train_data.append(data)
fr_en.close()
fr_zh.close() train_data = CDNDataset(train_data)
dev_data = []
with open('raw_data/test_txt_en.txt', 'r', encoding='utf-8') as fr_en,
open('raw_data/test_txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
for en_, zh_ in zip(fr_en, fr_zh):
data = tokenizer(en_.strip(), max_length=128, padding=True, truncation=True, return_tensors='pt')
data["labels"] = tokenizer(zh_.strip(), max_length=128, padding=True,
truncation=True, return_tensors='pt')['input_ids']
dev_data.append(data)
fr_en.close()
fr_zh.close()
dev_data = CDNDataset(dev_data)
报错信息:
"Unable to create tensor, you should probably activate truncation and/or padding "
ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length.
说明:
现将前一个基于torch.utils.data.Dataset加载数据的方式的案例叫作method1,当前抛出异常的案例叫作method2,两者相比:
- dataset.py中__getitem__()的返回类型都是字典,每次也都是返回一个样本
- 在main.py中:
- method1将所有序列样本存入一个list中,然后对该list进行了一次tokenize,最后在CDNDataset类的__getitem__()中根据索引ite组合成一个样本的字典格式,并返回
- method2中是先对每个像本序列分别作tokenize,再将各个样本tokenize后得到的字典存入一个list,最后在CDNDataset类的__getitem__()中根据索引ite返回各个样本对应的字典
- 所报错误信息说没有进行padding 和 truncation, 但事实上我做了,故而不知道是啥问题,望各位大佬不吝赐教。谢过!
transformers 之Trainer对应的数据加载的更多相关文章
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
- JQuery插件:遮罩+数据加载中。。。(特点:遮你想遮,罩你想罩)
在很多项目中都会涉及到数据加载.数据加载有时可能会是2-3秒,为了给一个友好的提示,一般都会给一个[数据加载中...]的提示.今天就做了一个这样的提示框. 先去jQuery官网看看怎么写jQuery插 ...
- 如何评估ETL的数据加载时间
简述如何评估大型ETL数据加载时间. 答:评估一个大型的ETL的数据加载时间是一件很复杂的事情.数据加载分为两类,一类是初次加载,另一类是增量加载. 在数据仓库正式投入使用时,需要进行一次初次加载,而 ...
- 浅谈Entity Framework中的数据加载方式
如果你还没有接触过或者根本不了解什么是Entity Framework,那么请看这里http://www.entityframeworktutorial.net/EntityFramework-Arc ...
- 实现虚拟模式的动态数据加载Windows窗体DataGridView控件 .net 4.5 (一)
实现虚拟模式的即时数据加载Windows窗体DataGridView控件 .net 4.5 原文地址 :http://msdn.microsoft.com/en-us/library/ms171624 ...
- Android Volley和Gson实现网络数据加载
Android Volley和Gson实现网络数据加载 先看接口 1 升级接口 http://s.meibeike.com/mcloud/ota/cloudService POST请求 参数列表如下 ...
- Echarts通过Ajax实现动态数据加载
Echarts(3.x版)官网实例的数据都是静态的,实际使用中往往会要求从服务器端取数据进行动态显示,官网教程里给出的异步数据加载很粗略,下面就以官网最简单的实例为例子,详细演示如下过程:1.客户端通 ...
随机推荐
- 【深入浅出 Yarn 架构与实现】4-4 RM 管理 Application
在 YARN 中,Application 是指应用程序,它可能启动多个运行实例,每个运行实例由 -个 ApplicationMaster 与一组该 ApplicationMaster 启动的任务组成, ...
- Django框架之drf:5、反序列化器校验部分源码分析、断言、drf之请求与响应、视图组件介绍及两个视图基类、代码部分实战
Django框架之drf 目录 Django框架之drf 一.反序列化类校验部分源码解析 二.断言 三.drf之请求 1.Request能够解析的前端传入编码格式 2.Request类中的属性和方法 ...
- drf-api接口、测试工具postman
1.web应用模式 """ django是一个web框架,专门用来写web项目,之前学的bbs项目,图书管理系统,用的是前后端混合开发. ""&quo ...
- 【动画笔记】数据结构-AVL树的插入操作
本笔记前置知识: 二叉搜索(排序)树及其插入操作. 本文主要围绕AVL树的平衡因子.纸上做题思路.失衡类型(LL/RR/LR/RL).失衡调整方法.插入后回溯这几部分知识点展开. 注: 本笔记中的平衡 ...
- C-09\编译预处理
一.预处理 C语言在对源程序进行正常编译之前,会先对一些特殊的预处理命令作解释,产生一个新的源程序,该过程称为编译预处理 为了区分预处理命令和一般的C语句,所有预处理命令行都以"#" ...
- Vue27 scoped样式
https://www.jianshu.com/p/d80383251fc5 1 简介 当我们在组件中写样式,vue最后会把所有样式合在一起,如果样式名称重复的话就会有问题 style标签上加scop ...
- 从实现到原理,聊聊Java中的SPI动态扩展
原创:微信公众号 码农参上,欢迎分享,转载请保留出处. 八股文背多了,相信大家都听说过一个词,SPI扩展. 有的面试官就很喜欢问这个问题,SpringBoot的自动装配是如何实现的? 基本上,你一说是 ...
- ROS入门:小海龟实验
1.初试小海龟 1.roscore 2.rosrun turtlesim turtlesim_node 3.rosrun turtlesim turtle_teleop_key 2.发布话题控制小海龟 ...
- 【TS】object类型
object是一个对象,在ts中定义对象类型的语法为:let 变量名 :object = { } 在object类型中,对象内部定义的值是不受类型约束的,只要是一个object类型即可,例如: let ...
- Diffusers库的初识及使用
diffusers库的目标是: 将扩散模型(diffusion models)集中到一个单一且长期维护的项目中 以公众可访问的方式复现高影响力的机器学习系统,如DALLE.Imagen等 让开发人员可 ...