谷歌BERT预训练源码解析(三):训练过程
目录
前言
源码解析
主函数
自定义模型
遮蔽词预测
下一句预测
规范化数据集
前言
本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨。BERT针对两个任务同时训练。1.下一句预测。2.遮蔽词识别
下面介绍BERT的预训练模型run_pretraining.py是怎么训练的。
源码解析
主函数
训练过程主要用了estimator调度器。这个调度器支持自定义训练过程,将训练集传入之后自动训练。详情见注释
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
if not FLAGS.do_train and not FLAGS.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
tf.gfile.MakeDirs(FLAGS.output_dir)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Input Files ***")
for input_file in input_files:
tf.logging.info(" %s" % input_file)
tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig( #训练参数
cluster=tpu_cluster_resolver,
master=FLAGS.master,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
model_fn = model_fn_builder( #自定义模型,用于estimator训练
bert_config=bert_config,
init_checkpoint=FLAGS.init_checkpoint,
learning_rate=FLAGS.learning_rate,
num_train_steps=FLAGS.num_train_steps,
num_warmup_steps=FLAGS.num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu)
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator( #创建TPUEstimator
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
if FLAGS.do_train: #训练过程
tf.logging.info("***** Running training *****")
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
train_input_fn = input_fn_builder( #创建输入训练集
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=True)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
if FLAGS.do_eval: #验证过程
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
eval_input_fn = input_fn_builder( #创建验证集
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=False)
result = estimator.evaluate(
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
with tf.gfile.GFile(output_eval_file, "w") as writer:
tf.logging.info("***** Eval results *****")
for key in sorted(result.keys()):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
自定义模型
首先获取数据内容,传入到上一篇定义的模型中。对下一句预测任务取出模型的[CLS]结果。对遮蔽词预测任务取出模型的最后结果。然后分别计算loss值,最后将loss值相加。详情见注释
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
"""Returns `model_fn` closure for TPUEstimator."""
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
"""The `model_fn` for TPUEstimator."""
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
#获取数据内容
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
masked_lm_positions = features["masked_lm_positions"]
masked_lm_ids = features["masked_lm_ids"]
masked_lm_weights = features["masked_lm_weights"]
next_sentence_labels = features["next_sentence_labels"]
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
传入到Bert模型中。
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
#遮蔽预测的batch_loss,平均loss,预测概率矩阵
(masked_lm_loss,
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
bert_config, model.get_sequence_output(), model.get_embedding_table(),
masked_lm_positions, masked_lm_ids, masked_lm_weights)
#下一句预测的batch_loss,平均loss,预测概率矩阵
(next_sentence_loss, next_sentence_example_loss,
next_sentence_log_probs) = get_next_sentence_output(
bert_config, model.get_pooled_output(), next_sentence_labels)
#loss相加
total_loss = masked_lm_loss + next_sentence_loss
#获取所有变量
tvars = tf.trainable_variables()
initialized_variable_names = {}
scaffold_fn = None
#如果有之前保存的模型
if init_checkpoint:
(assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
if use_tpu:
def tpu_scaffold():
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
return tf.train.Scaffold()
scaffold_fn = tpu_scaffold
else:
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
#如果有之前保存的模型
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
output_spec = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = optimization.create_optimizer( #自定义好的优化器
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
output_spec = tf.contrib.tpu.TPUEstimatorSpec( #Estimator要求返回一个EstimatorSpec对象
mode=mode,
loss=total_loss,
train_op=train_op,
scaffold_fn=scaffold_fn)
#验证过程
elif mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
masked_lm_weights, next_sentence_example_loss,
next_sentence_log_probs, next_sentence_labels):
"""Computes the loss and accuracy of the model."""
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
[-1, masked_lm_log_probs.shape[-1]]) #概率矩阵转成[batch_size*max_pred_pre_seq,vocab_size]
masked_lm_predictions = tf.argmax(
masked_lm_log_probs, axis=-1, output_type=tf.int32) #取最大值位置为输出
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) #每句loss列表 [batch_size*max_pred_per_seq]
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
masked_lm_accuracy = tf.metrics.accuracy( #计算准确率
labels=masked_lm_ids,
predictions=masked_lm_predictions,
weights=masked_lm_weights)
masked_lm_mean_loss = tf.metrics.mean( #计算平均loss
values=masked_lm_example_loss, weights=masked_lm_weights)
next_sentence_log_probs = tf.reshape(
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
next_sentence_predictions = tf.argmax( #获取最大位置为输出
next_sentence_log_probs, axis=-1, output_type=tf.int32)
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
next_sentence_accuracy = tf.metrics.accuracy( #计算准确率
labels=next_sentence_labels, predictions=next_sentence_predictions)
next_sentence_mean_loss = tf.metrics.mean( 计算平均loss
values=next_sentence_example_loss)
return {
"masked_lm_accuracy": masked_lm_accuracy,
"masked_lm_loss": masked_lm_mean_loss,
"next_sentence_accuracy": next_sentence_accuracy,
"next_sentence_loss": next_sentence_mean_loss,
}
eval_metrics = (metric_fn, [
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
masked_lm_weights, next_sentence_example_loss,
next_sentence_log_probs, next_sentence_labels
])
output_spec = tf.contrib.tpu.TPUEstimatorSpec( #Estimator要求返回一个EstimatorSpec对象
mode=mode,
loss=total_loss,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
else:
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
return output_spec
return model_fn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
遮蔽词预测
输入BERT模型的最后一层encoder,输出遮蔽词预测任务的loss和概率矩阵。详情见注释
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
label_ids, label_weights):
#这里的input_tensor是模型中传回的最后一层结果 [batch_size,seq_length,hidden_size]。
#output_weights是词向量表 [vocab_size,embedding_size]
"""Get loss and log probs for the masked LM."""
#获取positions位置的所有encoder(即要预测的那些位置的encoder)
input_tensor = gather_indexes(input_tensor, positions) #[batch_size*max_pred_pre_seq,hidden_size]
with tf.variable_scope("cls/predictions"):
# We apply one more non-linear transformation before the output layer.
# This matrix is not used after pre-training.
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense( #传入一个全连接层 输出shape [batch_size*max_pred_pre_seq,hidden_size]
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size]
logits = tf.nn.bias_add(logits, output_bias) #加bias
log_probs = tf.nn.log_softmax(logits, axis=-1) #[batch_size*max_pred_pre_seq,vocab_size]
label_ids = tf.reshape(label_ids, [-1]) #[batch_size*max_pred_per_seq]
label_weights = tf.reshape(label_weights, [-1])
one_hot_labels = tf.one_hot( #[batch_size*max_pred_per_seq,vocab_size]
label_ids, depth=bert_config.vocab_size, dtype=tf.float32) #label id转one hot
# The `positions` tensor might be zero-padded (if the sequence is too
# short to have the maximum number of predictions). The `label_weights`
# tensor has a value of 1.0 for every real prediction and 0.0 for the
# padding predictions.
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) #[batch_size*max_pred_per_seq]
numerator = tf.reduce_sum(label_weights * per_example_loss) #[1] 一个batch的loss
denominator = tf.reduce_sum(label_weights) + 1e-5
loss = numerator / denominator #平均loss
return (loss, per_example_loss, log_probs)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
下一句预测
输入BERT模型CLS的encoder,输出下一句预测任务的loss和概率矩阵,详情见注释
def get_next_sentence_output(bert_config, input_tensor, labels):
#input_tensor shape [batch_size,hidden_size]
"""Get loss and log probs for the next sentence prediction."""
# Simple binary classification. Note that 0 is "next sentence" and 1 is
# "random sentence". This weight matrix is not used after pre-training.
with tf.variable_scope("cls/seq_relationship"):
output_weights = tf.get_variable(
"output_weights",
shape=[2, bert_config.hidden_size],
initializer=modeling.create_initializer(bert_config.initializer_range))
output_bias = tf.get_variable(
"output_bias", shape=[2], initializer=tf.zeros_initializer()) #[batch_size,hidden_size]
logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size,2]
logits = tf.nn.bias_add(logits, output_bias) #[batch_size,2]
log_probs = tf.nn.log_softmax(logits, axis=-1)
labels = tf.reshape(labels, [-1])
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) #[batch_size,2]
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) #[batch_size]
loss = tf.reduce_mean(per_example_loss) #[1]
return (loss, per_example_loss, log_probs)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
规范化数据集
Estimator要求模型的输入为特定格式(from_tensor_slices),所以要对数据进行类封装
def input_fn_builder(input_files,
max_seq_length,
max_predictions_per_seq,
is_training,
num_cpu_threads=4):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
name_to_features = {
"input_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"input_mask":
tf.FixedLenFeature([max_seq_length], tf.int64),
"segment_ids":
tf.FixedLenFeature([max_seq_length], tf.int64),
"masked_lm_positions":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_ids":
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
"masked_lm_weights":
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
"next_sentence_labels":
tf.FixedLenFeature([1], tf.int64),
}
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
if is_training:
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
d = d.repeat() #重复
d = d.shuffle(buffer_size=len(input_files)) #打乱
# `cycle_length` is the number of parallel files that get read.
cycle_length = min(num_cpu_threads, len(input_files))
# `sloppy` mode means that the interleaving is not exact. This adds
# even more randomness to the training pipeline.
d = d.apply(
tf.contrib.data.parallel_interleave( #生成嵌套数据集,并且输出其元素隔行交错
tf.data.TFRecordDataset,
sloppy=is_training,
cycle_length=cycle_length))
d = d.shuffle(buffer_size=100)
else:
d = tf.data.TFRecordDataset(input_files)
# Since we evaluate for a fixed number of steps we don't want to encounter
# out-of-range exceptions.
d = d.repeat()
# We must `drop_remainder` on training because the TPU requires fixed
# size dimensions. For eval, we assume we are evaluating on the CPU or GPU
# and we *don't* want to drop the remainder, otherwise we wont cover
# every sample.
d = d.apply(
tf.contrib.data.map_and_batch( #结构转换
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
num_parallel_batches=num_cpu_threads,
drop_remainder=True))
return d
return input_fn
---------------------
作者:保持一份率性
来源:CSDN
原文:https://blog.csdn.net/weixin_39470744/article/details/84619903
版权声明:本文为博主原创文章,转载请附上博文链接!

谷歌BERT预训练源码解析(三):训练过程的更多相关文章
- [源码解析] 分布式训练Megatron (1) --- 论文 & 基础
[源码解析] 分布式训练Megatron (1) --- 论文 & 基础 目录 [源码解析] 分布式训练Megatron (1) --- 论文 & 基础 0x00 摘要 0x01 In ...
- Celery 源码解析三: Task 对象的实现
Task 的实现在 Celery 中你会发现有两处,一处位于 celery/app/task.py,这是第一个:第二个位于 celery/task/base.py 中,这是第二个.他们之间是有关系的, ...
- Mybatis源码解析(三) —— Mapper代理类的生成
Mybatis源码解析(三) -- Mapper代理类的生成 在本系列第一篇文章已经讲述过在Mybatis-Spring项目中,是通过 MapperFactoryBean 的 getObject( ...
- 谷歌BERT预训练源码解析(一):训练数据生成
目录预训练源码结构简介输入输出源码解析参数主函数创建训练实例下一句预测&实例生成随机遮蔽输出结果一览预训练源码结构简介关于BERT,简单来说,它是一个基于Transformer架构,结合遮蔽词 ...
- 谷歌BERT预训练源码解析(二):模型构建
目录前言源码解析模型配置参数BertModelword embeddingembedding_postprocessorTransformerself_attention模型应用前言BERT的模型主要 ...
- ReactiveCocoa源码解析(三) Signal代码的基本实现
上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...
- ReactiveSwift源码解析(三) Signal代码的基本实现
上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...
- React的React.createRef()/forwardRef()源码解析(三)
1.refs三种使用用法 1.字符串 1.1 dom节点上使用 获取真实的dom节点 //使用步骤: 1. <input ref="stringRef" /> 2. t ...
- 第三十四节,目标检测之谷歌Object Detection API源码解析
我们在第三十二节,使用谷歌Object Detection API进行目标检测.训练新的模型(使用VOC 2012数据集)那一节我们介绍了如何使用谷歌Object Detection API进行目标检 ...
随机推荐
- tomcat9 gzip
我认为apr模式比较屌所以 <Connector port=" protocol="org.apache.coyote.http11.Http11AprProtocol&qu ...
- springmvc 使用了登录拦截器之后静态资源还是会被拦截的处理办法
解决办法 在拦截器的配置里加上静态资源的处理 参考https://www.jb51.net/article/103704.htm
- Java IO:为什么InputStream只能读一次
http://zhangbo-peipei-163-com.iteye.com/blog/2021879 InputStream的接口规范就是这么设计的. /** * Reads the next b ...
- C++ UNREFERENCED_PARAMETER函数的作用
新建win32 application程序,会有这样一段代码 int APIENTRY wWinMain(_In_ HINSTANCE hInstance, _In_opt_ HINSTANCE hP ...
- C++ std::vector指定位置插入
使用vector,必须加上:#include <vector> 1.初始化vector,一般有这几种方式: std::vector<std::wstring> v1; //创建 ...
- SpringBoot web获取请求数据【转】
SpringBoot web获取请求数据 一个网站最基本的功能就是匹配请求,获取请求数据,处理请求(业务处理),请求响应,我们今天来看SpringBoot中怎么获取请求数据. 文章包含的内容如下: 获 ...
- Spring Boot:Boot2.0版本整合Neo4j
前面介绍了Boot 1.5版本集成Neo4j,Boot 2.0以上版本Neo4j变化较大. 场景还是电影人员关系 Boot 2.0主要变化 GraphRepository在Boot2.0下不支持了,调 ...
- SDUT-3378_数据结构实验之查找六:顺序查找
数据结构实验之查找六:顺序查找 Time Limit: 1000 ms Memory Limit: 65536 KiB Problem Description 在一个给定的无序序列里,查找与给定关键字 ...
- OSI七层模型,作用及其对应的协议
物理层(Physical Layer):利用传输介质为数据链路层提供物理连接,实现比特流的透明传输 数据链路层(Data Link Layer):负责建立和管理节点间的链路 网络层(Network L ...
- MySQL数据库操作语句(补充1)(cmd环境运行)
一.字符串类型 enum枚举类型 /* 也叫做枚举类型,类似于单选! 如果某个字段的值只能从某几个确定的值中进行选择,一般就使用enum类型, 在定义的时候需要将该字段所有可能的选项都罗列出来: */ ...