Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/

Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!

我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。

什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。

1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整

1.1 .主要核心代码

A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。

2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。

3)训练函数原型及重要参数解释

def fit_generator(self, generator,        #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=,
workers=,
use_multiprocessing=False,
initial_epoch=)

B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。

2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。

3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作

2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)

回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。

   def __init__(self):
self.validation_data = None def set_params(self, params):
self.params = params def set_model(self, model):
self.model = model def on_epoch_begin(self, epoch, logs=None):
pass def on_epoch_end(self, epoch, logs=None):
pass def on_batch_begin(self, batch, logs=None):
pass def on_batch_end(self, batch, logs=None):
pass def on_train_begin(self, logs=None):
pass def on_train_end(self, logs=None):
pass

A. keras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', )
self.seen += batch_size for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size

在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
logs[k] = self.totals[k] / self.seen

B. keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save +=
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save =
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > :
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > :
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > :
print('Epoch %05d: saving model to %s' % (epoch, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)

C.keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

D. keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

3. 具体训练逻辑过程

A. 训练函数分析

a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3)         callbacks.on_epoch_begin(epoch)
4)         while steps_done < steps_per_epoch:
5)             generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6)             callbacks.on_batch_begin(batch_index, batch_logs)
7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8)             callbacks.on_batch_end(batch_index, batch_logs)
            end of while steps_done < steps_per_epoch
            self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
      end of while epoch < epochs
10) callbacks.on_train_end()

b. 特别介绍下train_on_batch
   train_on_batch (keras中的trainning.py)
        |_self._standardize_user_data
        |_self._make_train_function
        |_self.train_function (tensorflow的函数)
                        |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

B训练和验证的对比

a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)

日志如下:

/ [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450

第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率

b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果

self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

1) while steps_done < steps:
2)           generator_output = next(output_generator)
3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

其中重点test_on_batch

test_on_batch(self, x, y, sample_weight=None)
         |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
         |_self._make_test_function()
         |_self.test_function(ins)                    
                    |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

c. train和test的重要区别,应该体现在下面的两个函数上

def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()] with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
raise RuntimeError('You must compile your model before using it.')
if self.test_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
# Return loss and metrics, no gradient updates.
# Does update the network states.

self.test_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
name='test_function',
**self._function_kwargs)

经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。

结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。

总结:

0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新

1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求

2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)

3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新

4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。

Keras/Tensorflow训练逻辑研究的更多相关文章

  1. keras&tensorflow+分布式训练︱实现简易视频内容问答框架

    内容来源:Keras 之父讲解 Keras:几行代码就能在分布式环境训练模型 把 Keras API 直接整合入 TensorFlow 项目中,这样能与你的已有工作流无缝结合.至此,Keras 成为了 ...

  2. 【深度学习】keras + tensorflow 实现猫和狗图像分类

    本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...

  3. 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型

    人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...

  4. Keras模型训练的断点续训、早停、效果可视化

    训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...

  5. VGG16等keras预训练权重文件的下载及本地存放

    VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...

  6. 人工智能不过尔尔,基于Python3深度学习库Keras/TensorFlow打造属于自己的聊天机器人(ChatRobot)

    原文转载自「刘悦的技术博客」https://v3u.cn/a_id_178 聊天机器人(ChatRobot)的概念我们并不陌生,也许你曾经在百无聊赖之下和Siri打情骂俏过,亦或是闲暇之余与小爱同学谈 ...

  7. [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建

    这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...

  8. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  9. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

随机推荐

  1. wx小程序获取用户位置信息

    wx小程序内置的接口只能获取用户的坐标,通过微信位置服务获取用户地址: http://lbs.qq.com/qqmap_wx_jssdk/index.html 注:需要在key的设置中打开webSer ...

  2. JavaScript条件语句-5--if语句的嵌套

    JavaScript条件语句 学习目标 1.掌握length属性的应用 2.掌握if语句的嵌套 length 语法:string.length 功能:获取string字符串的长度 返回值:number ...

  3. C高级第三次作业

    C高级第三次作业(1) 6-1 输出月份英文名 1.设计思路 (1)算法: 第一步:定义整型变量n,字符指针s,输入一个数赋给n. 第二步:调用函数getmonth将值赋给s. 第三步:在函数getm ...

  4. pycharm远程服务器进行调试

    背景是这样的:我有一台远程的服务器,以及一台本地的电脑:现在我想用远程的服务器上的python编译器来运行代码,怎么办?通用的做法是ssh服务器,vim代码,之后python运行文件,但是如果遇到调试 ...

  5. 由testcase数据之分析

    一.获取data来源 1.利用openpyxl从excel表格获取数据,相较于xlrd,openpyxl可以将表格里的样式也传递过来的优势 xlrd  -----------------     ht ...

  6. 20165313 预备作业3 Linux安装及学习

    虚拟机安装 刚开始我觉得既然有了教程,安装虚拟机应该是很简单的事情,然而由于电脑本身系统地地问题,导致我数次安装失败,后来咨询了老师并查阅了资料,最终才安装好. 其中最主要的问题就是电脑虚拟化的修改. ...

  7. Spring Cloud 微服务实战

    Eureka 服务治理 Maven dependency 与spring boot的版本的对应 Finchley兼容Spring Boot 2.0.x,不兼容Spring Boot 1.5.x Dal ...

  8. 学习笔记TF013:卷积、跨度、边界填充、卷积核

    卷积运算,两个输入张量(输入数据和卷积核)进行卷积,输出代表来自每个输入的信息张量.tf.nn.conv2d完成卷积运算.卷积核(kernel),权值.滤波器.卷积矩阵或模版,filter.权值训练习 ...

  9. Eclipse maven 错误修正方法:An error occurred while filtering resources

    最近打开Eclipse后发现项目报红叉,解决办法如下: 1.eclipse中删除该项目(注意:不要删除代码) 2.cmd,进入到项目目录下,执行命令:mvn eclipse:clean 3.重新导入项 ...

  10. python简单实现目录对比

    [root@localhost python]# cat dircmptest.py #!/usr/bin/python import filecmp path1="/root/python ...