BERT-Pytorch版本代码pipline梳理
最近在做BERT的fine-tune工作,记录一下阅读项目https://github.com/weizhepei/BERT-NER时梳理的训练pipline,该项目基于Google的Transformers代码构建
前置知识
bert的DataLoader简介(真的很简介)
https://zhuanlan.zhihu.com/p/384469908
yield介绍
https://www.runoob.com/w3cnote/python-yield-used-analysis.html
这是一种提高代码复用性的方法
带yield的函数被称为 generator(生成器),调用next()方法可使其执行至函数内部的yield处中断并返回一个迭代值
Pipeline
训练部分
① 运行build_dataset_tags.py将原始数据集处理为txt文本保存(生成原始数据集文本)
②数据流
注:“XX.py--->”代表该过程由XX.py发起
1、train.py--->class DataLoader[data_loader.py]--->train_data(d)
通过data_loader.py中的load_data,再调用load_sentences_tags
load_sentences_tags返回一个字典d,包含:
使用tokenizer对原始句子的token
token对应的id
token对应的tag
句子的长度
2、train.py--->train_and_evaluate(train_data, val_data)--->2个generator--->evaluate()[evaluate.py]
此处生成的两个生成器分别用于在训练和测试时以迭代方式获取batch数据

3、train.py--->evaluate(generator)[evaluate.py]--->batch_data, batch_token_starts, batch_tags--->将batch输入model[在train.py处实例化]中--->loss、batch_output、batch_tags--->计算出F1值返回给train_and_evaluate()
在得到F1值后,根据设置的参数决定是否满足停止训练的条件
数据迭代器
data_loader.py--->data_iterator(train/val/test_data)
--->计算会产生的batch的数量(由train/val/test_data中记录的句子长度size和class DataLoader中人为设置的batch_size参数决定)--->提取train/val/test_data中的sentences、tags
# 计算batch数
if data['size'] % self.batch_size == 0:
BATCH_NUM = data['size']//self.batch_size
else:
BATCH_NUM = data['size']//self.batch_size + 1
# one pass over data
# 提取一个batch,由batch_size个sentences构成
for i in range(BATCH_NUM):
# fetch sentences and tags
if i * self.batch_size < data['size'] < (i+1) * self.batch_size:
sentences = [data['data'][idx] for idx in order[i*self.batch_size:]]
if not interMode:
tags = [data['tags'][idx] for idx in order[i*self.batch_size:]]
else:
sentences = [data['data'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]
if not interMode:
tags = [data['tags'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]
--->计算batch中最大的句子长度--->将数据转换为np矩阵(numpy array)
--->将数据拷贝到另一个np矩阵,使得所有数据的长度与最大句子长度保持一致(即完成了padding)
# prepare a numpy array with the data, initialising the data with pad_idx
# batch_data的形状为:最长句子长度X最长句子长度(batch_len X batch_len),元素全为0
batch_data = self.token_pad_idx * np.ones((batch_len, max_subwords_len))
batch_token_starts = []
# copy the data to the numpy array
for j in range(batch_len):
cur_subwords_len = len(sentences[j][0])
if cur_subwords_len <= max_subwords_len:
batch_data[j][:cur_subwords_len] = sentences[j][0]
else:
batch_data[j] = sentences[j][0][:max_subwords_len]
token_start_idx = sentences[j][-1]
token_starts = np.zeros(max_subwords_len)
token_starts[[idx for idx in token_start_idx if idx < max_subwords_len]] = 1
batch_token_starts.append(token_starts)
max_token_len = max(int(sum(token_starts)), max_token_len)
--->将所有索引格式的(我理解就是numpy array形式的)数据转换为torch LongTensors
--->返回batch_data, batch_token_starts, batch_tags(这就是用于直接输入模型的数据)
BERT-Pytorch版本代码pipline梳理的更多相关文章
- PyTorch常用代码段整理合集
PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...
- Pytorch版本yolov3源码阅读
目录 Pytorch版本yolov3源码阅读 1. 阅读test.py 1.1 参数解读 1.2 data文件解析 1.3 cfg文件解析 1.4 根据cfg文件创建模块 1.5 YOLOLayer ...
- pytorch版本问题:AttributeError: 'module' object has no attribute '_rebuild_tensor_v2'
用pytorch加载训练好的模型的时候遇到了如下的问题: AttributeError: 'module' object has no attribute '_rebuild_tensor_v2' 到 ...
- PyTorch 常用代码段整理
基础配置 检查 PyTorch 版本 torch.__version__ # PyTorch version torch.version.cuda ...
- Caffe学习系列(二)Caffe代码结构梳理,及相关知识点归纳
前言: 通过检索论文.书籍.博客,继续学习Caffe,千里之行始于足下,继续努力.将自己学到的一些东西记录下来,方便日后的整理. 正文: 1.代码结构梳理 在终端下运行如下命令,可以查看caffe代码 ...
- OpenGL10-骨骼动画原理篇(3)-Shader版本代码已经上传
视频教程请关注 http://edu.csdn.net/lecturer/lecturer_detail?lecturer_id=440 接上一个例程OpenGL10-骨骼动画原理篇(2),对骨骼动画 ...
- (转)GitHub Desktop 拉取 GitHub上 Tag 版本代码
转自:GitHub Desktop 拉取 GitHub上 Tag 版本代码 一直在使用 GitHub Desktop 图形化 git 管理工具,统一项目框架版本时需要切换到ThinkPHP Tag 分 ...
- Gradle 如何打包 Spring Boot 如何不添加版本代码
在 Gradle 中如何在打包的 Jar 中不包含版本代码? 在 bootJar 中,使用下面的代码进行打包不包含版本代码. archiveFileName = "${archiveBase ...
- faster RCNN(keras版本)代码讲解(3)-训练流程详情
转载:https://blog.csdn.net/u011311291/article/details/81121519 https://blog.csdn.net/qq_34564612/artic ...
随机推荐
- GDAL重投影重采样像元配准对齐
研究通常会涉及到多源数据,需要进行基于像元的运算,在此之前需要对数据进行地理配准.空间配准.重采样等操作.那么当不同来源,不同分辨率的数据重采样为同一空间分辨率之后,各个像元不一一对应,有偏移该怎么办 ...
- java 编程基础 Class对象 反射 :数组操作java.lang.reflect.Array类
java.lang.reflect包下还提供了Array类 java.lang.reflect包下还提供了Array类,Array对象可以代表所有的数组.程序可以通过使 Array 来动态地创建数组, ...
- vue常用技巧-动态btn的封装
@1.要求: 1.点击某个按钮后激活active样式,其余按钮则为normal样式 2.要满足任意个数btn(btn个数不确定) @2.思路: 1.首先,btn个数不确定则意味着必须使用v-for循环 ...
- 【LeetCode】982. Triples with Bitwise AND Equal To Zero 解题报告(C++)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 日期 题目地址:https://leetcode.c ...
- 【LeetCode】963. Minimum Area Rectangle II 解题报告(Python)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 线段长+线段中心+字典 日期 题目地址:https: ...
- 【剑指Offer】二叉树的深度 解题报告(Python & C++)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 解题方法 日期 题目地址:https://www.nowcoder.co ...
- 1226 - One Unit Machine
1226 - One Unit Machine PDF (English) Statistics Forum Time Limit: 2 second(s) Memory Limit: 32 MB ...
- mybatis查询时使用基本数据类型接收报错-attempted to return null from a method with a primitive return type (int)
一.问题由来 自己在查看日志时发现日志中打印了一行错误信息为: 组装已经放养的宠物数据异常--->Mapper method 'applets.user.mapper.xxxMapper.xxx ...
- 【C++】关于new分配空间
1如果不使用new,则在函数结束时内存被回收,指针变成野指针 #include <iostream> using namespace std; struct Node { int val; ...
- 阿里云视觉智能开放平台的人脸1:N搜索的开源替代-Java版(文末赋开源地址)
一.人脸检测相关概念 人脸检测(Face Detection)是检测出图像中人脸所在位置的一项技术,是人脸智能分析应用的核心组成部分,也是最基础的部分.人脸检测方法现在多种多样,常用的技术或工具大 ...