在使用kears训练model的时候,一般会将所有的训练数据加载到内存中,然后喂给网络,但当内存有限,且数据量过大时,此方法则不再可用。此博客,将介绍如何在多核(多线程)上实时的生成数据,并立即的送入到模型当中训练。 本篇文章由圆柱模板博主发布。

   先看一下还未改进的版本:

   

import numpy as np
from keras.models import Sequential
#载入全部的数据!!
X, y = np.load('some_training_set_with_labels.npy')
#设计模型
model = Sequential()
[...] #网络结构
model.compile()
# 在数据集上进行模型训练
model.fit(x=X, y=y)

  下面的结构将改变一次性载入全部数据的情况。接下来将介绍如何一步一步的构造数据生成器,此数据生成器也可应用在你自己的项目当中;复制下来,并根据自己的需求填充空白处。

在构建之前先定义统一几个变量,并介绍几个小tips,对我们处理大的数据量很重要。 
ID type为string,代表数据集中的某个样本。 
调整以下结构,编译处理样本和他们的label:

1.新建一个词典名叫 partition :

partition[‘train’] 为训练集的ID,type为list
partition[‘validation’] 为验证集的ID,type为list

  2.新建一个词典名叫 * labels * ,根据ID可找到数据集中的样本,同样可通过labels[ID]找到样本标签。 
举个例子: 
假设训练集包含三个样本,ID分别为id-1,id-2和id-3,相应的label分别为0,1,2。验证集包含样本ID id-4,标签为 1。此时两个词典partition和 labels分别如下:

partition
{'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}

  

labels
{'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}

  data/ 中为数据集文件。

数据生成器(data generator)

接下来将介绍如何构建数据生成器 DataGenerator ,DataGenerator将实时的对训练模型feed数据。 
接下来,将先初始化类。我们使此类继承自keras.utils.Sequence,这样我们可以使用多线程。

def __init__(self, list_IDs, labels, batch_size=32,
dim=(32,32,32), n_channels=1,
n_classes=10, shuffle=True):
'Initialization'
self.dim = dim
self.batch_size = batch_size
self.labels = labels
self.list_IDs = list_IDs
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.on_epoch_end()

  我们给了一些与数据相关的参数 dim,channels,classes,batch size ;方法 on_epoch_end 在一个epoch开始时或者结束时触发,shuffle决定是否在数据生成时要对数据进行打乱。

def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.list_IDs))
if self.shuffle == True:
np.random.shuffle(self.indexes)

  另一个数据生成核心的方法__data_generation 是生成批数据。

def __data_generation(self, list_IDs_temp):
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
# Initialization
X = np.empty((self.batch_size, *self.dim, self.n_channels))
y = np.empty((self.batch_size), dtype=int) # Generate data
for i, ID in enumerate(list_IDs_temp):
# Store sample
X[i,] = np.load('data/' + ID + '.npy') # Store class
y[i] = self.labels[ID] return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

  在数据生成期间,代码读取包含各个样本ID的代码ID.py.因为我们的代码是可以应用多线程的,所以可以采用更为复杂的操作,不用担心数据生成成为总体效率的瓶颈。 
另外,我们使用Keras的方法keras.utils.to_categorical对label进行2值化 
(比如,对6分类而言,第三个label则相应的变成 to [0 0 1 0 0 0]) 。

def __len__(self):
'Denotes the number of batches per epoch'
return int(np.floor(len(self.list_IDs) / self.batch_size))

  现在,当相应的index的batch被选到,则生成器执行_getitem_方法来生成它。

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] # Find list of IDs
list_IDs_temp = [self.list_IDs[k] for k in indexes] # Generate data
X, y = self.__data_generation(list_IDs_temp) return X, y

  

Keras神经网络data generators解决数据内存的更多相关文章

  1. mysql查询语句出现sending data耗时解决

    在执行一个简单的sql查询,表中数据量为14万 sql语句为:SELECT id,titile,published_at from spider_36kr_record where is_analyz ...

  2. 压缩Sqlite数据文件大小,解决数据删除后占用空间不变的问题

    最近有一网站使用Sqlite数据库作为数据临时性的缓存,对多片区进行划分 Sqlite数据库文件,每天大概新增近1万的数据量,起初效率有明显的提高,但历经一个多月后数据库文件从几K也上升到了近160M ...

  3. [转]怎样解决Myeclipse内存溢出?

    在用myeclipes10 开发 遇到了 内存溢出问题,百度了很久,这篇比较完善. 总结起来三个方面去检查 1)myeclipes的配置:myeclipes 10 的安装路径下 的myeclipse. ...

  4. JAVA 大数据内存耗用测试

    JAVA 大数据内存耗用测试import java.lang.management.ManagementFactory;import java.lang.management.MemoryMXBean ...

  5. 解决Windows内存问题的两个小工具RamMap和VMMap(这个更牛更好)

    来源:http://www.cr173.com/html/13006_1.html .net程序内存监测分配工具(CLR Profiler for .NET Framework 4)官方安装版 类型: ...

  6. Spark性能调优之解决数据倾斜

    Spark性能调优之解决数据倾斜 数据倾斜七种解决方案 shuffle的过程最容易引起数据倾斜 1.使用Hive ETL预处理数据    • 方案适用场景:如果导致数据倾斜的是Hive表.如果该Hiv ...

  7. SAS DATA步读取数据

    上面一节讲了SAS的基本概念,以及语法结构,这次主要讲解SAS DATA步读取数据.    1 ·列表输入    2 ·按列输入    3 ·格式化输入  使用DATA步读取数据的基本形式如下: DA ...

  8. [MapReduce_add_3] MapReduce 通过分区解决数据倾斜

    0. 说明 数据倾斜及解决方法的介绍与代码实现 1. 介绍 [1.1 数据倾斜的含义] 大量数据发送到同一个节点进行处理,造成此节点繁忙甚至瘫痪,而其他节点资源空闲 [1.2 解决数据倾斜的方式] 重 ...

  9. 解决Windows内存问题的两个小工具RamMap和VMMap

    解决Windows内存问题需要对操作系统的深入理解,同时对于如何运用Windows调试器或性能监控器要有工作认知.如果你正试着得到细节,诸如内核堆栈大小或硬盘内存消耗,你会需要调试器命令和内核数据架构 ...

随机推荐

  1. 学习数据结构Day2

    之前学习过了数组的静态实现方法,同时将数组的所有有可能实现的方法都统一实现了一遍,之后支持了泛型的相关 概念,接下来就是如何对数组进行扩容的操作也就是实现动态数组. private void resi ...

  2. java8新特性五-Stream

    继上次学习过Java8中的非常重要的Lambda表达式之后,接下来就要学习另一个也比较重要的知识啦,也就如标题所示:Stream,而它的学习是完全依赖于之前学习的Lambda表达式. Java 8 A ...

  3. PHP 获取星期

    <?php function getWeek($time = 0) { $week_array=array('日', '一', '二', '三', '四', '五', '六'); //先定义一个 ...

  4. Google深度学习开源框架TenseorFlow安装

    Google近期发布了TensorFlow,考录到Google出品,必属精品,估计这玩意会火,不过火钳刘明已经来不及了 今天才想着安装来试试 TensorFlow官网:https://www.tens ...

  5. 使用SnowFlake算法生成唯一ID

    转自:https://segmentfault.com/a/1190000007769660 考虑过的方法有 直接用时间戳,或者以此衍生的一系列方法 Mysql自带的uuid 以上两种方法都可以查到就 ...

  6. LeetCode 103. 二叉树的锯齿形层次遍历(Binary Tree Zigzag Level Order Traversal)

    103. 二叉树的锯齿形层次遍历 103. Binary Tree Zigzag Level Order Traversal 题目描述 给定一个二叉树,返回其节点值的锯齿形层次遍历.(即先从左往右,再 ...

  7. ERP解析外围系统json数据格式

    外围系统调用ERP的WebService接口,将数据以json格式传到ERP,ERP解析json 1.创建java source jsp,提供java方法解析json数据 create or repl ...

  8. [HAOI2008]硬币购物-题解

    传送门 解答 根据容斥原理 \[ \left|\bigcap_{i=1}^n \overline{S_i}\right| = |U| - \left|\bigcup_{i=1}^n S_i\right ...

  9. Java date日期类型,结束日期减去开始日期求两者时间差,精确到秒

    /** * @Author: * @Description: * @Date: 2019/4/10 19:01 * @Modified By: */ @Slf4j public class DateU ...

  10. Jmeter_数据库

    1.准备一个有测试数据表的mysql数据库 2.在测试计划面板点击“浏览..." 按钮,将你的JDBC驱动添加进来.         需要安装插件   mysql-connector-jav ...