前言

是的,除了水报错文,我也来写点其他的。本文主要介绍Keras中以下三个函数的用法:

fit()
fit_generator()
train_on_batch()
当然,与上述三个函数相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不详细说了,举一反三嘛。

环境

本文的代码是在以下环境下进行测试的:

Windows 10
Python 3.6
TensorFlow 2.0 Alpha
异同

大家用Keras也就图个简单快捷,但是在享受简单快捷的时候,也常常需要些定制化需求,除了model.fit(),有时候model.fit_generator()和model.train_on_batch()也很重要。那么,这三个函数有什么异同呢?Adrian Rosebrock [1] 有如下总结:

当你使用.fit()函数时,意味着如下两个假设:

训练数据可以 完整地 放入到内存(RAM)里
数据已经不需要再进行任何处理了
这两个原因解释的非常好,之前我运行程序的时候,由于数据集太大(实际中的数据集显然不会都像 TensorFlow 官方教程里经常使用的 MNIST 数据集那样小),一次性加载训练数据到fit()函数里根本行不通:

history = model.fit(train_data, train_label) // Bomb!!!
1
于是我想,能不能先加载一个batch训练,然后再加载一个batch,如此往复。于是我就注意到了fit_generator()函数。什么时候该使用fit_generator函数呢?Adrian Rosebrock 的总结道:

内存不足以一次性加载整个训练数据的时候
需要一些数据预处理(例如旋转和平移图片、增加噪音、扩大数据集等操作)
在生成batch的时候需要更多的处理
对于我自己来说,除了数据集太大的缘故之外,我需要在生成batch的时候,对输入数据进行padding,所以fit_generator()就派上了用场。下面介绍如何使用这三种函数。

fit()函数

fit()函数其实没什么好说的,大家在看TensorFlow教程的时候已经见识过了。此外插一句话,tf.data.Dataset对不规则的序列数据真是不友好。

import tensorflow as tf

model = tf.keras.models.Sequential([
... // 你的模型
])

model.fit(train_x, // 训练输入
train_y, // 训练标签
epochs=5 // 训练5轮
)
1
2
3
4
5
6
7
8
9
10
fit_generator()函数

fit_generator()函数就比较重要了,也是本文讨论的重点。fit_generator()与fit()的主要区别就在一个generator上。之前,我们把整个训练数据都输入到fit()里,我们也不需要考虑batch的细节;现在,我们使用一个generator,每次生成一个batch送给fit_generator()训练。

def generator(x, y, b_size):
... // 处理函数

model.fit_generator(generator(train_x, train_y, batch_size),
step_per_epochs=np.ceil(len(train_x)/batch_size),
epochs=5
)
1
2
3
4
5
6
7
从上述代码中,我们发现有两处不同:

一个我们自定义的generator()函数,作为fit_generator()函数的第一个参数;
fit_generator()函数的step_per_epochs参数
自定义的generator()函数

该函数即是我们数据的生成器,在训练的时候,fit_generator()函数会不断地执行generator()函数,获取一个个的batch。

def generator(x, y, b_size):
"""Generates batch and batch and batch then feed into models.
Args:
x: input data;
y: input labels;
b_size: batch_size.
Yield:
(batch_x, batch_label): batched x and y.
"""
while 1: // 死循环
idx = ...
batch_x = ...
batch_y = ...
... // 任何你想要对这个`batch`中的数据执行的操作
yield (batch_x, batch_y)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
需要注意的是,不要使用return或者exit。

step_per_epochs参数

由于generator()函数的循环没有终止条件,fit_generator也不知道一个epoch什么时候结束,所以我们需要手动指定step_per_epochs参数,一般的数值即为len(y)//batch_size。如果数据集大小不能整除batch_size,而且你打算使用最后一个batch的数据(该batch比batch_size要小),此时使用np.ceil(len(y)/batch_size)。

keras.utils.Sequence类(2019年6月10日更新)

除了写generator()函数,我们还可以利用keras.utils.Sequence类来生成batch。先扔代码:

class Generator(keras.utils.Sequence):
def __init__(self, x, y, b_size):
self.x, self.y = x, y
self.batch_size = b_size

def __len__(self):
return math.ceil(len(self.y)/self.batch_size

def __getitem__(self, idx):
b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
... // 对`batch`的其余操作
return np.array(b_x), np.array(b_y)

def on_epoch_end(self):
"""执行完一个`epoch`之后,还可以做一些其他的事情!"""
...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
我们首先定义__init__函数,读取训练集数据,然后定义__len__函数,返回一个epoch中需要执行的step数(此时在fit_generator()函数中就不需要指定steps_per_epoch参数了),最后定义__getitem__函数,返回一个batch的数据。代码如下:

train_generator = Generator(train_x, train_y, batch_size)
val_generator = Generator(val_x, val_y, batch_size)

model.fit_generator(generator=train_generator,
epochs=3197747,
validation_data=val_generator
)
1
2
3
4
5
6
7
根据官方 [2] 的说法,使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。总之,使用keras.utils.Sequence也是很方便的啦!

train_on_batch()函数

train_on_batch()函数接受一个batch的输入和标签,然后开始反向传播,更新参数等。大部分情况下你都不需要用到train_on_batch()函数,除非你有着充足的理由去定制化你的模型的训练流程。

结语

本文到此结束啦!也不知道讲清楚没有,如果有疑问或者有错误,还请读者不吝赐教啦!

Reference

A. Rosebrock. (December 24, 2018). How to use Keras fit and fit_generator (a hands-on tutorial). Retrieved from https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
tf.keras.utils.Sequence. (July 10, 2019). Retrieved from https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/utils/Sequence
---------------------

[TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用的更多相关文章

  1. Anaconda安装tensorflow和keras(gpu版,超详细)

    本人配置:window10+GTX 1650+tensorflow-gpu 1.14+keras-gpu 2.2.5+python 3.6,亲测可行 一.Anaconda安装 直接到清华镜像网站下载( ...

  2. 『计算机视觉』Mask-RCNN_推断网络其二:基于ReNet101的FPN共享网络暨TensorFlow和Keras交互简介

    零.参考资料 有关FPN的介绍见『计算机视觉』FPN特征金字塔网络. 网络构架部分代码见Mask_RCNN/mrcnn/model.py中class MaskRCNN的build方法的"in ...

  3. 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现

    在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...

  4. windows安装TensorFlow和Keras遇到的问题及其解决方法

    安装TensorFlow在Windows上,真是让我心力交瘁,想死的心都有了,在Windows上做开发真的让人发狂. 首先说一下我的经历,本来也就是起初,网上说python3.7不支持TensorFl ...

  5. Ubuntu18.04 安装TensorFlow 和 Keras

    TensorFlow和Keras是当前两款主流的深度学习框架,Keras被采纳为TensorFlow的高级API,平时做深度学习任务,可以使用Keras作为深度学习框架,并用TensorFlow作为后 ...

  6. Anaconda 安装 tensorflow 和 keras

    说明:此操作是在 Anaconda Prompt 窗口完成的 CPU版 tensorflow 的安装. 1.用 conda 创建虚拟环境 tensorflow python=3.6 conda cre ...

  7. 成功解决 AttributeError: module 'tensorflow.python.keras.backend' has no attribute 'get_graph'

    在导入keras包时出现这个问题,是因为安装的tensorflow版本和keras版本不匹配,只需卸载keras,重新安装自己tensorflow对应的版本就OK了.可以在这个网址查看tensorfl ...

  8. tensorflow和keras的安装

    1 卸载tensorflow方法,在终端输入:  把protobuf删除了才能卸载干净. sudo pip uninstall protobuf sudo pip uninstall tensorfl ...

  9. win10+anaconda安装tensorflow和keras遇到的坑小结

    win10下利用anaconda安装tensorflow和keras的教程都大同小异(针对CPU版本,我的gpu是1050TI的MAX-Q,不知为啥一直没安装成功),下面简单说下步骤. 一 Anaco ...

随机推荐

  1. python 微服务方案

    介绍 使用python做web开发面临的一个最大的问题就是性能,在解决C10K问题上显的有点吃力.有些异步框架Tornado.Twisted.Gevent 等就是为了解决性能问题.这些框架在性能上有些 ...

  2. poj1182食物链(三类并查集)

    动物王国中有三类动物A,B,C,这三类动物的食物链构成了有趣的环形.A吃B, B吃C,C吃A. 现有N个动物,以1-N编号.每个动物都是A,B,C中的一种,但是我们并不知道它到底是哪一种. 有人用两种 ...

  3. .Net core 2.0 利用Attribute获取MVC Action来生成菜单

    最近在学习.net core的同时将老师的MVC5项目中的模块搬过来用,其中有一块就是利用Attribute来生成菜单. 一·首先定义Action实体 /// <summary> /// ...

  4. 洛谷 P1339 [USACO09OCT]热浪Heat Wave(dijkstra)

    题目链接 https://www.luogu.org/problemnew/show/P1339 最短路 解题思路 dijkstra直接过 注意: 双向边 memset ma数组要在读入之前 AC代码 ...

  5. 数据导出 写入到excle文件

    import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.poi ...

  6. Sunday 字符串匹配算法(C++实现)

    简介: Sunday算法是Daniel M.Sunday于1990年提出的一种字符串模式匹配算法.其核心思想是:在匹配过程中,模式串并不被要求一定要按从左向右进行比较还是从右向左进行比较,它在发现不匹 ...

  7. gulp run 报错 gulp[3192]: src\node_contextify.cc:628: Assertion `args[1]->IsString()' failed.

    由于把node升级到了10以上的版本 执行gulp rjs打包文件报错,错误如下: gulp[3192]: src\node_contextify.cc:628: Assertion `args[1] ...

  8. JVM(18)之 Class文件

    开发十年,就只剩下这套架构体系了! >>>   关于类加载机制的相关知识在前面的博文中暂时先讲那么多.中间留下了很多问题,从本篇博文开始,我们来一一解决.    从我们最陌生而又最熟 ...

  9. ELK + Filebeat日志分析系统安装

    之前搭建过elk,用于分析日志,无奈服务器资源不足,开了多个Logstash之后发现占用内存过高,于是现在改为Filebeat做日志收集,记录一下搭建过程和遇到问题的解决方案. 第一步 , 安装jdk ...

  10. JS小案例--全选和全不选列表功能的实现

    纯js代码实现列表全选和全不选的功能 <!DOCTYPE html> <html> <head lang="en"> <meta char ...