虽然已经走在 torch boy 的路上了, 还是把碰到的这个坑给记录一下

  • 数据量较小时,我们可直接把整个数据集 load 到内存里,用 model.fit() 来拟合模型。
  • 当数据集过大比如几十个 G 时,内存撑不下,需要用 model.fit_generator 的方式来拟合。

model.fit_generator 一般参数的配置参考官方文档就好,其中 generator, workers, use_multiprocessing 的使用有一些坑存在。

workers=0, use_multiprocessing=False

此时 generator 用一个普通的 generator去提供数据即可,类似官方提供的这种

def generate_arrays_from_file(path):
while True:
with open(path) as f:
for line in f:
# create numpy arrays of input data
# and labels, from each line in the file
x1, x2, y = process_line(line)
yield ({'input_1': x1, 'input_2': x2}, {'output': y}) model.fit_generator(generate_arrays_from_file('/my_file.txt'),
steps_per_epoch=10000, epochs=10)

workers>0, use_multiprocessing=True

这时依然用一个 generator function 来做 generator在拟合的时候便会报错如下:

PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at

且当 use_multiprocessing=True 时,如果你使用的是 generator function, 代码会把你的数据copy几份分给不同的worker去处理,但我们希望的是把一份数据平均分拆成几份给多个worker去处理。

怎么解决上面两个问题? keras.utils.Sequence 可以做到

很简单,继承 keras.utils.Sequence 这个类,重写自己的 len(), getitem 即可。

class SequenceData(Sequence):
def __init__(self, filePaths, batch_size):
self.filePaths = filePaths[:100].copy()
self.batch_size = batch_size
self.Y = self.getY() def __len__(self):
return len(self.Y) // self.batch_size def __getitem__(self, index):
batch_X = np.zeros((self.batch_size,) + IMG_DIMS, dtype='float32')
batch_Y_ = self.Y[index*self.batch_size: (index+1)*self.batch_size].copy()
batch_Y_.reset_index(drop=True, inplace=True)
assert batch_Y_.shape[0] == self.batch_size for index, rows in batch_Y_.iterrows():
try:
img = _load_img(rows['path'])
batch_X[index, :, :, :] = img.copy()
batch_Y_.loc[index, 'valid'] = 1
except:
batch_Y_.loc[index, 'valid'] = 0
traceback.print_exc()
batch_Y = to_categorical(batch_Y_['label'], classes_num)
return batch_X, batch_Y def __iter__(self):
for item in (self[i] for i in range(len(self))):
yield item def getY(self):
Y = pd.DataFrame(self.filePaths, columns=['path'])
Y['class'] = Y['path'].apply(lambda x: path2class(x))
Y['label'] = Y['class'].apply(lambda x: class2label[x])
Y = Y.sample(frac=1).reset_index(drop=True)
return Y

效果比较

  • 样本量:1000张图片
  • 模型: MobileNetV2
  • epochs: 5
  • CPU: 4核,3.4GHz
  • GPU: None

可能数据量过小,并行的效果不是太明显。

数据读取方式 workers use_multiprocessing 耗时/s
内存读取 0 True 1797
keras.utils.Sequence 0 False 1475
keras.utils.Sequence 4 True

参考:

keras fit_generator 并行的更多相关文章

  1. keras 入门整理 如何shuffle,如何使用fit_generator 整理合集

    keras入门参考网址: 中文文档教你快速建立model keras不同的模块-基本结构的简介-类似xmind整理 Keras的基本使用(1)--创建,编译,训练模型 Keras学习笔记(完结) ke ...

  2. (转)The AlphaGo Replication Wiki

    The AlphaGo Replication Wiki 摘自:https://github.com/Rochester-NRT/RocAlphaGo/wiki/01.-Home Contents : ...

  3. 『计算机视觉』Mask-RCNN_训练网络其三:训练Model

    Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...

  4. [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map

    Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...

  5. keras系列︱利用fit_generator最小化显存占用比率/数据Batch化

    本文主要参考两篇文献: 1.<深度学习theano/tensorflow多显卡多人使用问题集> 2.基于双向LSTM和迁移学习的seq2seq核心实体识别 运行机器学习算法时,很多人一开始 ...

  6. keras 学习笔记(一) ——— model.fit & model.fit_generator

    from keras.preprocessing.image import load_img, img_to_array a = load_img('1.jpg') b = img_to_array( ...

  7. [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

    前言 是的,除了水报错文,我也来写点其他的.本文主要介绍Keras中以下三个函数的用法: fit()fit_generator()train_on_batch()当然,与上述三个函数相似的evalua ...

  8. keras训练函数fit和fit_generator对比,图像生成器ImageDataGenerator数据增强

    1. [深度学习] Keras 如何使用fit和fit_generator https://blog.csdn.net/zwqjoy/article/details/88356094 ps:解决样本数 ...

  9. Keras函数——mode.fit_generator()

    1 model.fit_generator(self,generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validati ...

随机推荐

  1. Cloudera Manager添加主机节点

    为了监控方便,想把研发环境中的主机节点都纳入Cloudera Manager的管理中,这样在遇到问题时可方便的查看主机的硬件资源情况. 添加主机节点有多种方式,由于我是离线工作,所以选择rpm包的方式 ...

  2. SQL语句中 ` 的作用

    SQL语句中 ` 的作用 做攻防世界WEB区 supersqli 题目,在构建SQL语句时,遇到SQL语句中有 ` 时可以解析,没有则不能. 查阅资料得知,` 通常用来说明其中的内容是数据库名.表名. ...

  3. AmoebaNet:经费在燃烧,谷歌提出基于aging evolution的神经网络搜索 | AAAI 2019

    论文提出aging evolution,一个锦标赛选择的变种来优化进化算法,在NASNet搜索空间上,对比强化学习和随机搜索,该算法足够简洁,而且能够更快地搜索到更高质量的模型,论文搜索出的Amoeb ...

  4. C++ STL 栈和队列

    栈和队列 头文件 #include<queue> // 队列 #include<stack> //栈 定义方式 //参数就是数据类型 stack<int> s; q ...

  5. USB限流芯片,4.8A最大,过压关闭6V

    PW1503,PW1502是超低RDS(ON)开关,具有可编程的电流限制,以保护电源源于过电流和短路保护.它具有超温保护以及反向闭锁功能. PW1503,PW1502采用薄型(1毫米)5针薄型SOT2 ...

  6. MySQL库和表的操作

    MySQL库和表的操作 库操作 创建库 1.1 语法 CREATE DATABASE 数据库名 charset utf8; 1.2 数据库命名规则 可以由字母.数字.下划线.@.#.$ 区分大小写 唯 ...

  7. 基于Python的接口自动化-unittest测试框架和ddt数据驱动

    引言 在编写接口自动化用例时,我们一般针对一个接口建立一个.py文件,一条接口测试用例封装为一个函数(方法),但是在批量执行的过程中,如果其中一条出错,后面的用例就无法执行,还有在运行大量的接口测试用 ...

  8. 小白也能看懂的ACID与隔离级别

    前言 现如今JAVA开发工程师的数量越来越多,但大多数工程师平时做的工作都是简单的CRUD,当你一直处于这种舒适的环境中不追求进步的时候,如果哪一天你突然想要改变环境,换个工作,去与面试官当面聊技术的 ...

  9. Serverless对研发效能的变革和创新 云托管和Serverless应用差异

    https://mp.weixin.qq.com/s/J4RXtKanh3IMr4fY7t0nyQ Serverless对研发效能的变革和创新 杨皓然(不瞋) 阿里巴巴中间件 2020-10-23

  10. 【算法】数位 dp

    时隔多日,我终于再次开始写博客了!! 上午听了数位 dp,感觉没听懂,于是在网上进行一番愉 ♂ 快 ♀ 的学习后,写篇博来加深一下印象~~ 前置的没用的知识 数位 不同计数单位,按照一定顺序排列,它们 ...