Keras/Tensorflow训练逻辑研究
Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/
Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!
我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。
什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。
1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整
1.1 .主要核心代码
A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py
1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。
2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。
3)训练函数原型及重要参数解释
def fit_generator(self, generator, #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=,
workers=,
use_multiprocessing=False,
initial_epoch=)
B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py
1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。
2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。
3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作
2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)
回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。
def __init__(self):
self.validation_data = None def set_params(self, params):
self.params = params def set_model(self, model):
self.model = model def on_epoch_begin(self, epoch, logs=None):
pass def on_epoch_end(self, epoch, logs=None):
pass def on_batch_begin(self, batch, logs=None):
pass def on_batch_end(self, batch, logs=None):
pass def on_train_begin(self, logs=None):
pass def on_train_end(self, logs=None):
pass
A. keras.callbacks.BaseLogger
统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', )
self.seen += batch_size for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size
在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。
def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
logs[k] = self.totals[k] / self.seen
B. keras.callbacks.ModelCheckpoint
在on_epoch_end时会保存模型数据进入文件
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save +=
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save =
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > :
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > :
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > :
print('Epoch %05d: saving model to %s' % (epoch, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
C.keras.callbacks.History
主要记录每一次epoch训练的结果,结果包含loss以及acc的值
D. keras.callbacks.ProgbarLogger
这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。
3. 具体训练逻辑过程
A. 训练函数分析
a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3) callbacks.on_epoch_begin(epoch)
4) while steps_done < steps_per_epoch:
5) generator_output = next(output_generator) #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6) callbacks.on_batch_begin(batch_index, batch_logs)
7) outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8) callbacks.on_batch_end(batch_index, batch_logs)
end of while steps_done < steps_per_epoch
self.evaluate_generator(...) #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9) callbacks.on_epoch_end(epoch, epoch_logs) #在这个执行过程中实现模型数据的保存操作
end of while epoch < epochs
10) callbacks.on_train_end()
b. 特别介绍下train_on_batch
train_on_batch (keras中的trainning.py)
|_self._standardize_user_data
|_self._make_train_function
|_self.train_function (tensorflow的函数)
|_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)
B. 训练和验证的对比
a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)
日志如下:
/ [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450 第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率
b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果
self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)
1) while steps_done < steps:
2) generator_output = next(output_generator)
3) outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))
其中重点test_on_batch
test_on_batch(self, x, y, sample_weight=None)
|_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
|_self._make_test_function()
|_self.test_function(ins)
|_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)
c. train和test的重要区别,应该体现在下面的两个函数上
def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()] with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
raise RuntimeError('You must compile your model before using it.')
if self.test_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
# Return loss and metrics, no gradient updates.
# Does update the network states.
self.test_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
name='test_function',
**self._function_kwargs)
经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。
结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。
总结:
0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新
1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求
2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)
3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新
4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。
Keras/Tensorflow训练逻辑研究的更多相关文章
- keras&tensorflow+分布式训练︱实现简易视频内容问答框架
内容来源:Keras 之父讲解 Keras:几行代码就能在分布式环境训练模型 把 Keras API 直接整合入 TensorFlow 项目中,这样能与你的已有工作流无缝结合.至此,Keras 成为了 ...
- 【深度学习】keras + tensorflow 实现猫和狗图像分类
本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...
- 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型
人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...
- Keras模型训练的断点续训、早停、效果可视化
训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...
- VGG16等keras预训练权重文件的下载及本地存放
VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...
- 人工智能不过尔尔,基于Python3深度学习库Keras/TensorFlow打造属于自己的聊天机器人(ChatRobot)
原文转载自「刘悦的技术博客」https://v3u.cn/a_id_178 聊天机器人(ChatRobot)的概念我们并不陌生,也许你曾经在百无聊赖之下和Siri打情骂俏过,亦或是闲暇之余与小爱同学谈 ...
- [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建
这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。
自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...
随机推荐
- js编写轮播图,广告弹框
1.轮播图 js编写轮播图,需要用到setInterval(计时器):先给一个div,里面放轮播图的图片,将轮播图的图片明明为相同样式的:如:banner1.jpg,banner2.jpg,banne ...
- Python socketserver ftp功能简单讲解
socketserver模块实现并发 为什么要讲socketserver?我们之前写的tcp协议的socket是不是一次只能和一个客户端通信,如果用socketserver可以实现和多个客户端通信.它 ...
- HDU 6143 17多校8 Killer Names(组合数学)
题目传送:Killer Names Problem Description > Galen Marek, codenamed Starkiller, was a male Human appre ...
- HDU 1251 统计难题(字典树 裸题 链表做法)
Problem Description Ignatius最近遇到一个难题,老师交给他很多单词(只有小写字母组成,不会有重复的单词出现),现在老师要他统计出以某个字符串为前缀的单词数量(单词本身也是自己 ...
- 1.selenium实战之从txt文档读取配置信息并执行登录
前置条件: 1.本机已搭建ECShop3.0网站 2.在脚本目录创建了user.txt文本如下: 目的:实现从txt中读取配置文件信息,本实战中,包含url地址.用户名.密码,然后进行ESChop的登 ...
- H5之localStorage,sessionStorage
在以前的时候也听说过一些h5缓存技术,具体也没有去使用过,就在前两三个礼拜我用了localStorage和sessionStorage这两个存储方式, 我使用这些存储技术,也是想减少访问服务器的请求, ...
- Linux sort命令详解
linux之sort用法 sort命令是帮我们依据不同的数据类型进行排序,其语法及常用参数格式: sort [-bcfMnrtk][源文件][-o 输出文件] 补充说明:sort可针对文本文件的内容, ...
- 实验吧—Web——WP之 简单的sql注入之2
直接打开解题连接: 既然是SQL注入,那么我们就要构造注入语句了,这个就要有耐心一个一个去尝试了 输入语句 1'and 1=1 # 和 1'and/**/1=1/**/#后 对比一下,发现是过滤掉了空 ...
- 快速挂载iso文件到虚拟机系统
在vm软件菜单栏那里选择vm,再选择弹出菜单最下面的设置,如图,找到实体机上的iso文件,保存. 这时候,在虚拟机ls /dev会发现有一个cdrom,这个就是我们的iso文件,不过我们还需要把它挂载 ...
- 文件访问控制列表facl
[root@bogon code]# getfacl a.c //获取文件a.c的文件访问控制列表 # file: a.c # owner: root # group: root user::rw- ...