Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/

Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!

我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。

什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。

1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整

1.1 .主要核心代码

A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。

2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。

3)训练函数原型及重要参数解释

def fit_generator(self, generator,        #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=,
workers=,
use_multiprocessing=False,
initial_epoch=)

B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。

2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。

3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作

2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)

回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。

   def __init__(self):
self.validation_data = None def set_params(self, params):
self.params = params def set_model(self, model):
self.model = model def on_epoch_begin(self, epoch, logs=None):
pass def on_epoch_end(self, epoch, logs=None):
pass def on_batch_begin(self, batch, logs=None):
pass def on_batch_end(self, batch, logs=None):
pass def on_train_begin(self, logs=None):
pass def on_train_end(self, logs=None):
pass

A. keras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', )
self.seen += batch_size for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size

在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
logs[k] = self.totals[k] / self.seen

B. keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save +=
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save =
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > :
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > :
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > :
print('Epoch %05d: saving model to %s' % (epoch, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)

C.keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

D. keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

3. 具体训练逻辑过程

A. 训练函数分析

a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3)         callbacks.on_epoch_begin(epoch)
4)         while steps_done < steps_per_epoch:
5)             generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6)             callbacks.on_batch_begin(batch_index, batch_logs)
7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8)             callbacks.on_batch_end(batch_index, batch_logs)
            end of while steps_done < steps_per_epoch
            self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
      end of while epoch < epochs
10) callbacks.on_train_end()

b. 特别介绍下train_on_batch
   train_on_batch (keras中的trainning.py)
        |_self._standardize_user_data
        |_self._make_train_function
        |_self.train_function (tensorflow的函数)
                        |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

B训练和验证的对比

a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)

日志如下:

/ [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450

第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率

b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果

self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

1) while steps_done < steps:
2)           generator_output = next(output_generator)
3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

其中重点test_on_batch

test_on_batch(self, x, y, sample_weight=None)
         |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
         |_self._make_test_function()
         |_self.test_function(ins)                    
                    |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

c. train和test的重要区别,应该体现在下面的两个函数上

def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()] with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
raise RuntimeError('You must compile your model before using it.')
if self.test_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
# Return loss and metrics, no gradient updates.
# Does update the network states.

self.test_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
name='test_function',
**self._function_kwargs)

经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。

结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。

总结:

0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新

1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求

2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)

3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新

4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。

Keras/Tensorflow训练逻辑研究的更多相关文章

  1. keras&tensorflow+分布式训练︱实现简易视频内容问答框架

    内容来源:Keras 之父讲解 Keras:几行代码就能在分布式环境训练模型 把 Keras API 直接整合入 TensorFlow 项目中,这样能与你的已有工作流无缝结合.至此,Keras 成为了 ...

  2. 【深度学习】keras + tensorflow 实现猫和狗图像分类

    本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...

  3. 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型

    人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...

  4. Keras模型训练的断点续训、早停、效果可视化

    训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...

  5. VGG16等keras预训练权重文件的下载及本地存放

    VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...

  6. 人工智能不过尔尔,基于Python3深度学习库Keras/TensorFlow打造属于自己的聊天机器人(ChatRobot)

    原文转载自「刘悦的技术博客」https://v3u.cn/a_id_178 聊天机器人(ChatRobot)的概念我们并不陌生,也许你曾经在百无聊赖之下和Siri打情骂俏过,亦或是闲暇之余与小爱同学谈 ...

  7. [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建

    这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...

  8. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  9. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

随机推荐

  1. [转]JavaScript与元素间的抛物线轨迹运动

    在张鑫旭的博客看到这个抛物线的小动画,觉得很感兴趣,转载一下方便研究~ 原文地址:http://www.zhangxinxu.com/wordpress/?p=3855 在页面上添加元素的位移动画,除 ...

  2. 2.4 CSS定位

    前言 大部分人在使用selenium定位元素时,用的是xpath定位,因为xpath基本能解决定位的需求.css定位往往被忽略掉了,其实css定位也有它的价值,css定位更快,语法更简洁.这一篇css ...

  3. 学习quartz

    https://www.w3cschool.cn/quartz_doc/quartz_doc-1xbu2clr.html

  4. fixed不能罩住下面的内容

    fix的优先级并不是最高的,所以要设置z-index,比它下面的元素高就能遮住了

  5. 20155208徐子涵 2016-2017-2 《Java程序设计》第5周学习总结

    20155208徐子涵 2016-2017-2 <Java程序设计>第5周学习总结 教材学习内容总结 第八章 异常处理 8.1 语法与继承结构 Java中所有错误都会被打包为对象,运用tr ...

  6. BZOJ 2002:Bounce 弹飞绵羊(分块)

    2002: [Hnoi2010]Bounce 弹飞绵羊 Time Limit: 10 Sec  Memory Limit: 259 MB Submit: 14944  Solved: 7598 [Su ...

  7. log4j.properties与db.properties

    log4j.properties与db.properties db.driver=com.mysql.jdbc.Driver db.url=jdbc:mysql:///mybatis?useUnico ...

  8. (0-1)CSS 标签语法的属性

    CSS text-decoration 属性 display display 属性规定元素应该生成的框的类型

  9. vue的指令绑定、事件、冒泡

    a标签的属性绑定: v-once:就是第一次渲染什么就是什么,不会随着其他改变而改变,简言之就是绑定他不让他的值改变 防止跨站脚本攻击 如果你觉得安全的话,可以不要让变量的值显示成字符串 解决方法是: ...

  10. Atom编辑神器

    最近喜欢上了Atom编辑神器,安装就不说了,重点讲配置. 一:软件配置 1.先将欢迎界面去掉,每次打开Atom的时候都会出现,实在是很烦人. 就在欢迎界面里面有个复选框,去掉选中就可以了. 2.让At ...