bert,albert的快速训练和预测
随着预训练模型越来越成熟,预训练模型也会更多的在业务中使用,本文提供了bert和albert的快速训练和部署,实际上目前的预训练模型在用起来时都大致相同。
基于不久前发布的中文数据集chineseGLUE,将所有任务分成四大类:文本分类,句子对判断,实体识别,阅读理解。同类可以共享代码,除上面四个任务之外,还加了一个learning to rank ,基于pair wise的方式的任务,代码见:https://github.com/jiangxinyang227/bert-for-task。
具体使用见readme
模型定义在每个项目下的model.py文件中,直接调用bert和albert的源码modeling.py将预训练模型引入,将预训练模型作为encoder部分,也可以只作为embedding层,再自己定义encoder部分,总之可以非常方便的接入下游任务网络层,尤其是当你只想使用预训练模型作为embedding层时,我们需要自己些encoder部分。
bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)
model = modeling.BertModel(config=bert_config,
is_training=self.__is_training,
input_ids=self.input_ids,
input_mask=self.input_masks,
token_type_ids=self.segment_ids,
use_one_hot_embeddings=False)
output_layer = model.get_pooled_output()
hidden_size = output_layer.shape[-1].value
if self.__is_training:
# I.e., 0.1 dropout
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
with tf.name_scope("output"):
output_weights = tf.get_variable(
"output_weights", [self.__num_classes, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02))
output_bias = tf.get_variable(
"output_bias", [self.__num_classes], initializer=tf.zeros_initializer())
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
self.predictions = tf.argmax(logits, axis=-1, name="predictions")
在训练时加载预训练的参数值来初始化预训练模型的变量,具体在trainer.py文件中
tvars = tf.trainable_variables()
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
tvars, self.__bert_checkpoint_path)
print("init bert model params")
tf.train.init_from_checkpoint(self.__bert_checkpoint_path, assignment_map)
print("init bert model params done")
sess.run(tf.variables_initializer(tf.global_variables()))
在预测时可以直接实例化predict.py文件中的Predictor类就会加载checkpoint模型文件,调用类中的predict方法就可以进行预测,在不需要考虑模型代码加密,模型优化等情况下,可以直接线上部署。
import json
from predict import Predictor
with open("config/tnews_config.json", "r") as fr:
config = json.load(fr)
predictor = Predictor(config)
text = "歼20座舱盖上的两条“花纹”是什么?"
res = predictor.predict(text)
print(res)
bert,albert的快速训练和预测的更多相关文章
- ResNet网络的训练和预测
ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...
- YOLO2 (3) 快速训练自己的目标
1快速训练自己的目标 在 YOLO2 (2) 测试自己的数据 中记录了完整的训练自己数据的过程. 训练时目标只有一类 car. 如果已经执行过第一次训练,改过一次配置文件,之后仍然训练同样的目标还是只 ...
- 机器学习使用sklearn进行模型训练、预测和评价
cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...
- 初识Sklearn-IrisData训练与预测
笔记:机器学习入门---鸢尾花分类 Sklearn 本身就有很多数据库,可以用来练习. 以 Iris 的数据为例,这种花有四个属性,花瓣的长宽,茎的长宽,根据这些属性把花分为三类:山鸢尾花Setosa ...
- 【HEVC帧间预测论文】P1.1 基于运动特征的HEVC快速帧间预测算法
基于运动特征的 HEVC 快速帧间预测算法/Fast Inter-Frame Prediction Algorithm for HEVC Based on Motion Features <HE ...
- Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测
3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...
- Tensorflow训练和预测中的BN层的坑
以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...
- fcn训练及预测tgs数据集
一.背景 kaggle上有这样一个题目,关于盐份预测的语义分割题目.TGS Salt Identification Challenge | Kaggle https://www.kaggle.com ...
- siftflow-fcn32s训练及预测
一.说明 SIFT Flow 是一个标注的语义分割的数据集,有两个label,一个是语义分类(33类),另一个是场景标签(3类). Semantic and geometric segmentatio ...
随机推荐
- aiohttp_spider
aiohttp_spider_def: import asyncio import re import aiohttp import aiomysql from pyquery import PyQu ...
- selenium环境搭建:
环境搭建 基于python3和selenium3做自动化测试,俗话说:工欲善其事必先利其器:没有金刚钻就不揽那瓷器活,磨刀不误砍柴工,因此你必须会搭建基本的开发环境,掌握python基本的语法和一个I ...
- 2019 SDN上机第二次作业
2019 SDN上机第二次作业 1.利用mininet创建如下拓扑,要求拓扑支持OpenFlow 1.3协议,主机名.交换机名以及端口对应正确,请给出拓扑Mininet执行结果,展示端口连接情况 1. ...
- goroutine使用
Goroutine是建立在线程之上的轻量级的抽象.它允许我们以非常低的代价在同一个地址空间中并行地执行多个函数或者方法.相比于线程,它的创建和销毁的代价要小很多,并且它的调度是独立于线程的.在gola ...
- oracle存储过程中循环游标,变量的引用
创建出错时使用: show errors查看具体的错误提示 一. 存储过程中的一个循环及变量引用示例: create or replace procedure my_proiscursor cur i ...
- PHP输出函数
1.print()输出 header('Content-Type:text/html;charset=utf-8'); print ("最近想学习PHP,大家推荐哪个学校好点?/n" ...
- bzoj4520 K远点对
题目链接 思路 这个"\(K\)远"点对一直理解成了距离第\(K\)大的点对\(233\). 要求第\(K\)远,那么我们只要想办法求出来最远的\(K\)个点对就可以了. 用一个大 ...
- python数据分析教程大全
第一篇:Anaconda安装和使用 第二篇:Jupyter norebook使用 第三篇:pandas教程 第四篇:numpy教程 第五篇:Matplotlib教程 第六篇:实战项目 期待吗?(微笑脸 ...
- webrtc笔记(2): 1对1实时视频/语音通讯原理概述
开始正文之前,先思考1个问题:2个处于不同网络环境的(具备摄像头/麦克风多媒体设备的)浏览器,要实现点对点的实时视频/语音通讯,难点在哪? 至少得先搞定下面2个问题: 1.彼此要了解对方支持的媒体格式 ...
- 【08月20日】A股滚动市净率PB历史新低排名
2010年01月01日 到 2019年08月20日 之间,滚动市净率历史新低排名. 上市三年以上的公司,2019年08月20日市净率在30以下的公司. 来源:A股滚动市净率(PB)历史新低排名. 1 ...