keras fit_generator 并行
虽然已经走在 torch boy 的路上了, 还是把碰到的这个坑给记录一下
- 数据量较小时,我们可直接把整个数据集 load 到内存里,用 model.fit() 来拟合模型。
- 当数据集过大比如几十个 G 时,内存撑不下,需要用 model.fit_generator 的方式来拟合。
model.fit_generator 一般参数的配置参考官方文档就好,其中 generator, workers, use_multiprocessing 的使用有一些坑存在。
workers=0, use_multiprocessing=False
此时 generator 用一个普通的 generator去提供数据即可,类似官方提供的这种
def generate_arrays_from_file(path):
while True:
with open(path) as f:
for line in f:
# create numpy arrays of input data
# and labels, from each line in the file
x1, x2, y = process_line(line)
yield ({'input_1': x1, 'input_2': x2}, {'output': y})
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
steps_per_epoch=10000, epochs=10)
workers>0, use_multiprocessing=True
这时依然用一个 generator function 来做 generator在拟合的时候便会报错如下:
PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at
且当 use_multiprocessing=True 时,如果你使用的是 generator function, 代码会把你的数据copy几份分给不同的worker去处理,但我们希望的是把一份数据平均分拆成几份给多个worker去处理。
怎么解决上面两个问题? keras.utils.Sequence 可以做到
很简单,继承 keras.utils.Sequence 这个类,重写自己的 len(), getitem 即可。
class SequenceData(Sequence):
def __init__(self, filePaths, batch_size):
self.filePaths = filePaths[:100].copy()
self.batch_size = batch_size
self.Y = self.getY()
def __len__(self):
return len(self.Y) // self.batch_size
def __getitem__(self, index):
batch_X = np.zeros((self.batch_size,) + IMG_DIMS, dtype='float32')
batch_Y_ = self.Y[index*self.batch_size: (index+1)*self.batch_size].copy()
batch_Y_.reset_index(drop=True, inplace=True)
assert batch_Y_.shape[0] == self.batch_size
for index, rows in batch_Y_.iterrows():
try:
img = _load_img(rows['path'])
batch_X[index, :, :, :] = img.copy()
batch_Y_.loc[index, 'valid'] = 1
except:
batch_Y_.loc[index, 'valid'] = 0
traceback.print_exc()
batch_Y = to_categorical(batch_Y_['label'], classes_num)
return batch_X, batch_Y
def __iter__(self):
for item in (self[i] for i in range(len(self))):
yield item
def getY(self):
Y = pd.DataFrame(self.filePaths, columns=['path'])
Y['class'] = Y['path'].apply(lambda x: path2class(x))
Y['label'] = Y['class'].apply(lambda x: class2label[x])
Y = Y.sample(frac=1).reset_index(drop=True)
return Y
效果比较
- 样本量:1000张图片
- 模型: MobileNetV2
- epochs: 5
- CPU: 4核,3.4GHz
- GPU: None
可能数据量过小,并行的效果不是太明显。
| 数据读取方式 | workers | use_multiprocessing | 耗时/s |
|---|---|---|---|
| 内存读取 | 0 | True | 1797 |
| keras.utils.Sequence | 0 | False | 1475 |
| keras.utils.Sequence | 4 | True |
参考:
- https://zhuanlan.zhihu.com/p/32679425
- https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L305
keras fit_generator 并行的更多相关文章
- keras 入门整理 如何shuffle,如何使用fit_generator 整理合集
keras入门参考网址: 中文文档教你快速建立model keras不同的模块-基本结构的简介-类似xmind整理 Keras的基本使用(1)--创建,编译,训练模型 Keras学习笔记(完结) ke ...
- (转)The AlphaGo Replication Wiki
The AlphaGo Replication Wiki 摘自:https://github.com/Rochester-NRT/RocAlphaGo/wiki/01.-Home Contents : ...
- 『计算机视觉』Mask-RCNN_训练网络其三:训练Model
Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...
- [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
- keras系列︱利用fit_generator最小化显存占用比率/数据Batch化
本文主要参考两篇文献: 1.<深度学习theano/tensorflow多显卡多人使用问题集> 2.基于双向LSTM和迁移学习的seq2seq核心实体识别 运行机器学习算法时,很多人一开始 ...
- keras 学习笔记(一) ——— model.fit & model.fit_generator
from keras.preprocessing.image import load_img, img_to_array a = load_img('1.jpg') b = img_to_array( ...
- [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用
前言 是的,除了水报错文,我也来写点其他的.本文主要介绍Keras中以下三个函数的用法: fit()fit_generator()train_on_batch()当然,与上述三个函数相似的evalua ...
- keras训练函数fit和fit_generator对比,图像生成器ImageDataGenerator数据增强
1. [深度学习] Keras 如何使用fit和fit_generator https://blog.csdn.net/zwqjoy/article/details/88356094 ps:解决样本数 ...
- Keras函数——mode.fit_generator()
1 model.fit_generator(self,generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validati ...
随机推荐
- 【Linux】将ens33修改为eth0 网卡方法
1.编辑 grub 配置文件 vim /etc/sysconfig/grub # 其实是/etc/default/grub的软连接 # 为GRUB_CMDLINE_LINUX变量增加2个参数,添加的内 ...
- npm i 报错 'match' of undefined 错误以及删除node_modules失败
简单粗暴的解决办法就是一个字'删', 1.先把node_modules给删了 手动删除的话,window系统经常会有部分删不了,说需要个权限什么的,直接用rimraf 就能解决 先安装npm inst ...
- powershell中的cmdlet命令
Add-Computer 向域或工作组中添加计算机. Add-Content 向指定的项中添加内容,如向文件中添加字词. Add-History 向会话历史记录追加条目. Add-Member 向 W ...
- 给dtcms增加模板自动生成功能
作为dtcms的使用者你是不是像我一样,也在不停的修改模板之后要点击生成模板浪费了很多开发模板的时间? 那就跟我一起给dtcms增加一个开发者模式,当模板修改完成之后,直接刷新页面就能看到效果,而不再 ...
- 【故障公告】K8s CofigMap 挂载问题引发网站故障
今天凌晨我们用阿里云服务器自建的 kubernetes 集群出现突发异常情况,博客站点(blog-web)与博客 web api(blog-api)的 pod 无法正常启动(CrashLoopBack ...
- WIFI 国家码和信道划分
前言 网上百度了很多资料,都没有找到国家码对应支持哪些信道的资料,无奈只能qiang到谷歌,分享给大家完整的WIFI 国家码和信道划分. 安卓WIFI国家码的影响 android中设置wifi国家码的 ...
- Visual Studio中自定义代码段!
Visual Studio中自定义代码段! 第一步:在编辑器中进行快捷键的输入[ctrl + shift + p] 或者 点击 查看 第一个选项就是!请看下图 第二步:选择你要配置代码段的语言, 这里 ...
- postgres多知识点综合案例
使用到的知识点: 1.使用with临时存储sql语句,格式[with as xxx(), as xxx2() ]以减少代码: 2.使用round()取小数点后几位: 3.使用to_char()将时间格 ...
- 【题解】 CF767E Change-free
洛谷链接 这个题翻译忘了输入,我看的英语原文...... 首先,这是一道贪心题 我的大致方法:pair+堆优 题目分析: 从第一天开始,到最后一天,每天可以选择找钱或者不找钱. 如果不找钱,则零钱数m ...
- python3中zip对象的使用
zip(*iterables) zip可以将多个可迭代对象组合成一个迭代器对象,通过迭代取值,可以得到n个长度为m的元组.其中n为长度最短可迭代对象的元素个数,m为可迭代对象的个数.并且每个元组的第i ...