Keras/Tensorflow训练逻辑研究
Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/
Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!
我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。
什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。
1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整
1.1 .主要核心代码
A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py
1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。
2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。
3)训练函数原型及重要参数解释
def fit_generator(self, generator, #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=,
workers=,
use_multiprocessing=False,
initial_epoch=)
B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py
1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。
2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。
3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作
2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)
回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。
def __init__(self):
self.validation_data = None def set_params(self, params):
self.params = params def set_model(self, model):
self.model = model def on_epoch_begin(self, epoch, logs=None):
pass def on_epoch_end(self, epoch, logs=None):
pass def on_batch_begin(self, batch, logs=None):
pass def on_batch_end(self, batch, logs=None):
pass def on_train_begin(self, logs=None):
pass def on_train_end(self, logs=None):
pass
A. keras.callbacks.BaseLogger
统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', )
self.seen += batch_size for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size
在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。
def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
logs[k] = self.totals[k] / self.seen
B. keras.callbacks.ModelCheckpoint
在on_epoch_end时会保存模型数据进入文件
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save +=
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save =
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > :
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > :
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > :
print('Epoch %05d: saving model to %s' % (epoch, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
C.keras.callbacks.History
主要记录每一次epoch训练的结果,结果包含loss以及acc的值
D. keras.callbacks.ProgbarLogger
这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。
3. 具体训练逻辑过程
A. 训练函数分析
a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3) callbacks.on_epoch_begin(epoch)
4) while steps_done < steps_per_epoch:
5) generator_output = next(output_generator) #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6) callbacks.on_batch_begin(batch_index, batch_logs)
7) outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8) callbacks.on_batch_end(batch_index, batch_logs)
end of while steps_done < steps_per_epoch
self.evaluate_generator(...) #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9) callbacks.on_epoch_end(epoch, epoch_logs) #在这个执行过程中实现模型数据的保存操作
end of while epoch < epochs
10) callbacks.on_train_end()
b. 特别介绍下train_on_batch
train_on_batch (keras中的trainning.py)
|_self._standardize_user_data
|_self._make_train_function
|_self.train_function (tensorflow的函数)
|_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)
B. 训练和验证的对比
a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)
日志如下:
/ [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450 第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率
b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果
self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)
1) while steps_done < steps:
2) generator_output = next(output_generator)
3) outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))
其中重点test_on_batch
test_on_batch(self, x, y, sample_weight=None)
|_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
|_self._make_test_function()
|_self.test_function(ins)
|_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)
c. train和test的重要区别,应该体现在下面的两个函数上
def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()] with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
raise RuntimeError('You must compile your model before using it.')
if self.test_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
# Return loss and metrics, no gradient updates.
# Does update the network states.
self.test_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
name='test_function',
**self._function_kwargs)
经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。
结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。
总结:
0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新
1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求
2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)
3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新
4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。
Keras/Tensorflow训练逻辑研究的更多相关文章
- keras&tensorflow+分布式训练︱实现简易视频内容问答框架
内容来源:Keras 之父讲解 Keras:几行代码就能在分布式环境训练模型 把 Keras API 直接整合入 TensorFlow 项目中,这样能与你的已有工作流无缝结合.至此,Keras 成为了 ...
- 【深度学习】keras + tensorflow 实现猫和狗图像分类
本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...
- 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型
人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...
- Keras模型训练的断点续训、早停、效果可视化
训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...
- VGG16等keras预训练权重文件的下载及本地存放
VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...
- 人工智能不过尔尔,基于Python3深度学习库Keras/TensorFlow打造属于自己的聊天机器人(ChatRobot)
原文转载自「刘悦的技术博客」https://v3u.cn/a_id_178 聊天机器人(ChatRobot)的概念我们并不陌生,也许你曾经在百无聊赖之下和Siri打情骂俏过,亦或是闲暇之余与小爱同学谈 ...
- [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建
这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。
自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...
随机推荐
- [转]JavaScript与元素间的抛物线轨迹运动
在张鑫旭的博客看到这个抛物线的小动画,觉得很感兴趣,转载一下方便研究~ 原文地址:http://www.zhangxinxu.com/wordpress/?p=3855 在页面上添加元素的位移动画,除 ...
- 2.4 CSS定位
前言 大部分人在使用selenium定位元素时,用的是xpath定位,因为xpath基本能解决定位的需求.css定位往往被忽略掉了,其实css定位也有它的价值,css定位更快,语法更简洁.这一篇css ...
- 学习quartz
https://www.w3cschool.cn/quartz_doc/quartz_doc-1xbu2clr.html
- fixed不能罩住下面的内容
fix的优先级并不是最高的,所以要设置z-index,比它下面的元素高就能遮住了
- 20155208徐子涵 2016-2017-2 《Java程序设计》第5周学习总结
20155208徐子涵 2016-2017-2 <Java程序设计>第5周学习总结 教材学习内容总结 第八章 异常处理 8.1 语法与继承结构 Java中所有错误都会被打包为对象,运用tr ...
- BZOJ 2002:Bounce 弹飞绵羊(分块)
2002: [Hnoi2010]Bounce 弹飞绵羊 Time Limit: 10 Sec Memory Limit: 259 MB Submit: 14944 Solved: 7598 [Su ...
- log4j.properties与db.properties
log4j.properties与db.properties db.driver=com.mysql.jdbc.Driver db.url=jdbc:mysql:///mybatis?useUnico ...
- (0-1)CSS 标签语法的属性
CSS text-decoration 属性 display display 属性规定元素应该生成的框的类型
- vue的指令绑定、事件、冒泡
a标签的属性绑定: v-once:就是第一次渲染什么就是什么,不会随着其他改变而改变,简言之就是绑定他不让他的值改变 防止跨站脚本攻击 如果你觉得安全的话,可以不要让变量的值显示成字符串 解决方法是: ...
- Atom编辑神器
最近喜欢上了Atom编辑神器,安装就不说了,重点讲配置. 一:软件配置 1.先将欢迎界面去掉,每次打开Atom的时候都会出现,实在是很烦人. 就在欢迎界面里面有个复选框,去掉选中就可以了. 2.让At ...