Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/

Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!

我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。

什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。

1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整

1.1 .主要核心代码

A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。

2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。

3)训练函数原型及重要参数解释

def fit_generator(self, generator,        #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=,
workers=,
use_multiprocessing=False,
initial_epoch=)

B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。

2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。

3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作

2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)

回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。

   def __init__(self):
self.validation_data = None def set_params(self, params):
self.params = params def set_model(self, model):
self.model = model def on_epoch_begin(self, epoch, logs=None):
pass def on_epoch_end(self, epoch, logs=None):
pass def on_batch_begin(self, batch, logs=None):
pass def on_batch_end(self, batch, logs=None):
pass def on_train_begin(self, logs=None):
pass def on_train_end(self, logs=None):
pass

A. keras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', )
self.seen += batch_size for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size

在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
logs[k] = self.totals[k] / self.seen

B. keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save +=
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save =
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn('Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > :
print('Epoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch, self.monitor, self.best,
current, filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > :
print('Epoch %05d: %s did not improve' %
(epoch, self.monitor))
else:
if self.verbose > :
print('Epoch %05d: saving model to %s' % (epoch, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)

C.keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

D. keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

3. 具体训练逻辑过程

A. 训练函数分析

a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3)         callbacks.on_epoch_begin(epoch)
4)         while steps_done < steps_per_epoch:
5)             generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6)             callbacks.on_batch_begin(batch_index, batch_logs)
7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8)             callbacks.on_batch_end(batch_index, batch_logs)
            end of while steps_done < steps_per_epoch
            self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
      end of while epoch < epochs
10) callbacks.on_train_end()

b. 特别介绍下train_on_batch
   train_on_batch (keras中的trainning.py)
        |_self._standardize_user_data
        |_self._make_train_function
        |_self.train_function (tensorflow的函数)
                        |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

B训练和验证的对比

a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)

日志如下:

/ [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450

第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
12228s 表示此batch处理结束所花费的时间
loss:此epoch里面的平均损失值
acc:此epoch里面的平均准确率
val_loss:此epoch训练完后进行的evaluate得到的损失值
val_acc:此epoch训练完后进行的evaluate得到的正确率

b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果

self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

1) while steps_done < steps:
2)           generator_output = next(output_generator)
3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

其中重点test_on_batch

test_on_batch(self, x, y, sample_weight=None)
         |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
         |_self._make_test_function()
         |_self.test_function(ins)                    
                    |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

c. train和test的重要区别,应该体现在下面的两个函数上

def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()] with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
raise RuntimeError('You must compile your model before using it.')
if self.test_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
# Return loss and metrics, no gradient updates.
# Does update the network states.

self.test_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=self.state_updates,
name='test_function',
**self._function_kwargs)

经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。

结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。

总结:

0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新

1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求

2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)

3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新

4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。

Keras/Tensorflow训练逻辑研究的更多相关文章

  1. keras&tensorflow+分布式训练︱实现简易视频内容问答框架

    内容来源:Keras 之父讲解 Keras:几行代码就能在分布式环境训练模型 把 Keras API 直接整合入 TensorFlow 项目中,这样能与你的已有工作流无缝结合.至此,Keras 成为了 ...

  2. 【深度学习】keras + tensorflow 实现猫和狗图像分类

    本文主要是使用[监督学习]实现一个图像分类器,目的是识别图片是猫还是狗. 从[数据预处理]到 [图片预测]实现一个完整的流程, 当然这个分类在 Kaggle 上已经有人用[迁移学习](VGG,Resn ...

  3. 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型

    人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...

  4. Keras模型训练的断点续训、早停、效果可视化

    训练:model.fit()函数 fit(x=None, y=None, batch_size=None, epochs=, verbose=, callbacks=None, validation_ ...

  5. VGG16等keras预训练权重文件的下载及本地存放

    VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...

  6. 人工智能不过尔尔,基于Python3深度学习库Keras/TensorFlow打造属于自己的聊天机器人(ChatRobot)

    原文转载自「刘悦的技术博客」https://v3u.cn/a_id_178 聊天机器人(ChatRobot)的概念我们并不陌生,也许你曾经在百无聊赖之下和Siri打情骂俏过,亦或是闲暇之余与小爱同学谈 ...

  7. [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建

    这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...

  8. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  9. 自己搞了20万张图片100个分类,tensorflow训练23万次后。。。。。。

    自己搞了20万张图片100个分类,tensorflow训练23万次后...... 我自己把训练用的一张图片,弄乱之后做了一个预测 100个汉字,20多万张图片,tensorflow CNN训练23万次 ...

随机推荐

  1. js编写轮播图,广告弹框

    1.轮播图 js编写轮播图,需要用到setInterval(计时器):先给一个div,里面放轮播图的图片,将轮播图的图片明明为相同样式的:如:banner1.jpg,banner2.jpg,banne ...

  2. Python socketserver ftp功能简单讲解

    socketserver模块实现并发 为什么要讲socketserver?我们之前写的tcp协议的socket是不是一次只能和一个客户端通信,如果用socketserver可以实现和多个客户端通信.它 ...

  3. HDU 6143 17多校8 Killer Names(组合数学)

    题目传送:Killer Names Problem Description > Galen Marek, codenamed Starkiller, was a male Human appre ...

  4. HDU 1251 统计难题(字典树 裸题 链表做法)

    Problem Description Ignatius最近遇到一个难题,老师交给他很多单词(只有小写字母组成,不会有重复的单词出现),现在老师要他统计出以某个字符串为前缀的单词数量(单词本身也是自己 ...

  5. 1.selenium实战之从txt文档读取配置信息并执行登录

    前置条件: 1.本机已搭建ECShop3.0网站 2.在脚本目录创建了user.txt文本如下: 目的:实现从txt中读取配置文件信息,本实战中,包含url地址.用户名.密码,然后进行ESChop的登 ...

  6. H5之localStorage,sessionStorage

    在以前的时候也听说过一些h5缓存技术,具体也没有去使用过,就在前两三个礼拜我用了localStorage和sessionStorage这两个存储方式, 我使用这些存储技术,也是想减少访问服务器的请求, ...

  7. Linux sort命令详解

    linux之sort用法 sort命令是帮我们依据不同的数据类型进行排序,其语法及常用参数格式: sort [-bcfMnrtk][源文件][-o 输出文件] 补充说明:sort可针对文本文件的内容, ...

  8. 实验吧—Web——WP之 简单的sql注入之2

    直接打开解题连接: 既然是SQL注入,那么我们就要构造注入语句了,这个就要有耐心一个一个去尝试了 输入语句 1'and 1=1 # 和 1'and/**/1=1/**/#后 对比一下,发现是过滤掉了空 ...

  9. 快速挂载iso文件到虚拟机系统

    在vm软件菜单栏那里选择vm,再选择弹出菜单最下面的设置,如图,找到实体机上的iso文件,保存. 这时候,在虚拟机ls /dev会发现有一个cdrom,这个就是我们的iso文件,不过我们还需要把它挂载 ...

  10. 文件访问控制列表facl

    [root@bogon code]# getfacl a.c //获取文件a.c的文件访问控制列表 # file: a.c # owner: root # group: root user::rw- ...