本文主要参考两篇文献:

1、《深度学习theano/tensorflow多显卡多人使用问题集》

2、基于双向LSTM和迁移学习的seq2seq核心实体识别

运行机器学习算法时,很多人一开始都会有意无意将数据集默认直接装进显卡显存中,如果处理大型数据集(例如图片尺寸很大)或是网络很深且隐藏层很宽,也可能造成显存不足。

这个情况随着工作的深入会经常碰到,解决方法其实很多人知道,就是分块装入。以keras为例,默认情况下用fit方法载数据,就是全部载入。换用fit_generator方法就会以自己手写的方法用yield逐块装入。这里稍微深入讲一下fit_generator方法。

.

— fit_generator源码

def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                      verbose=1, callbacks=[],
                      validation_data=None, nb_val_samples=None,
                      class_weight=None, max_q_size=10, **kwargs):

.

— generator 该怎么写?

其中generator参数传入的是一个方法,validation_data参数既可以传入一个方法也可以直接传入验证数据集,通常我们都可以传入方法。这个方法需要我们自己手写,伪代码如下:

def generate_batch_data_random(x, y, batch_size):
    """逐步提取batch数据到显存,降低对显存的占用"""
    ylen = len(y)
    loopcount = ylen // batch_size
    while (True):
        i = randint(0,loopcount)
        yield x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size]

.

为什么推荐在自己写的方法中用随机呢?

  • 因为fit方法默认shuffle参数也是True,fit_generator需要我们自己随机打乱数据。
  • 另外,在方法中需要用while写成死循环,因为每个epoch不会重新调用方法,这个是新手通常会碰到的问题。

当然,如果原始数据已经随机打乱过,那么可以不在这里做随机处理。否则还是建议加上随机取数逻辑(如果数据集比较大则可以保证基本乱序输出)。深度学习中随机打乱数据是非常重要的,具体参见《深度学习Deep Learning》一书的8.1.3节:《Batch and Minibatch Algorithm》。(2017年5月25日补充说明)

.

调用示例:

model.fit_generator(self.generate_batch_data_random(x_train, y_train, batch_size),
    samples_per_epoch=len(y_train)//batch_size*batch_size,
    nb_epoch=epoch,
    validation_data=self.generate_valid_data(x_valid, y_valid,batch_size),
    nb_val_samples=(len(y_valid)//batch_size*batch_size),
    verbose=verbose,
    callbacks=[early_stopping])

这样就可以将对显存的占用压低了,配合第一部分的方法可以方便同时执行多程序。


.

来看看一个《基于双向LSTM和迁移学习的seq2seq核心实体识别》实战案例:

'''
gen_matrix实现从分词后的list来输出训练样本
gen_target实现将输出序列转换为one hot形式的目标
超过maxlen则截断,不足补0
'''
gen_matrix = lambda z: np.vstack((word2vec[z[:maxlen]], np.zeros((maxlen-len(z[:maxlen]), word_size))))
gen_target = lambda z: np_utils.to_categorical(np.array(z[:maxlen] + [0]*(maxlen-len(z[:maxlen]))), 5)

#从节省内存的角度,通过生成器的方式来训练
def data_generator(data, targets, batch_size):
    idx = np.arange(len(data))
    np.random.shuffle(idx)
    batches = [idx[range(batch_size*i, min(len(data), batch_size*(i+1)))] for i in range(len(data)/batch_size+1)]
    while True:
        for i in batches:
            xx, yy = np.array(map(gen_matrix, data[i])), np.array(map(gen_target, targets[i]))
            yield (xx, yy)

batch_size = 1024
history = model.fit_generator(data_generator(d['words'], d['label'], batch_size), samples_per_epoch=len(d), nb_epoch=200)
model.save_weights('words_seq2seq_final_1.model')

延伸一:edwardlib/observations 规范数据导入、数据Batch化

def generator(array, batch_size):
  """Generate batch with respect to array's first axis."""
  start = 0  # pointer to where we are in iteration
  while True:
    stop = start + batch_size
    diff = stop - array.shape[0]
    if diff <= 0:
      batch = array[start:stop]
      start += batch_size
    else:
      batch = np.concatenate((array[start:], array[:diff]))
      start = diff
    yield batch

To use it, simply write

from observations import cifar10
(x_train, y_train), (x_test, y_test) = cifar10("~/data")
x_train_data = generator(x_train, 256)

keras系列︱利用fit_generator最小化显存占用比率/数据Batch化的更多相关文章

  1. Linux显存占用无进程清理方法(附批量清理命令)

    在跑TensorFlow.pytorch之类的需要CUDA的程序时,强行Kill掉进程后发现显存仍然占用,这时候可以使用如下命令查看到top或者ps中看不到的进程,之后再kill掉: fuser -v ...

  2. 深度学习中GPU和显存分析

    刚入门深度学习时,没有显存的概念,后来在实验中才渐渐建立了这个意识. 下面这篇文章很好的对GPU和显存总结了一番,于是我转载了过来. 作者:陈云 链接:https://zhuanlan.zhihu. ...

  3. keras系列︱keras是如何指定显卡且限制显存用量

    keras在使用GPU的时候有个特点,就是默认全部占满显存. 若单核GPU也无所谓,若是服务器GPU较多,性能较好,全部占满就太浪费了. 于是乎有以下三种情况: - 1.指定GPU - 2.使用固定显 ...

  4. 我的Keras使用总结(5)——Keras指定显卡且限制显存用量,常见函数的用法及其习题练习

    Keras 是一个高层神经网络API,Keras是由纯Python编写而成并基于TensorFlow,Theano以及CNTK后端.Keras为支持快速实验而生,能够将我们的idea迅速转换为结果.好 ...

  5. keras系列︱Sequential与Model模型、keras基本结构功能(一)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...

  6. [Pytorch]深度模型的显存计算以及优化

    原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...

  7. MegEngine亚线性显存优化

    MegEngine亚线性显存优化 MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch siz ...

  8. 自制操作系统Antz(3)——进入保护模式 (中) 直接操作显存

    Antz系统更新地址: https://www.cnblogs.com/LexMoon/category/1262287.html Linux内核源码分析地址:https://www.cnblogs. ...

  9. 解决GPU显存未释放问题

    前言 今早我想用多块GPU测试模型,于是就用了PyTorch里的torch.nn.parallel.DistributedDataParallel来支持用多块GPU的同时使用(下面简称其为Dist). ...

随机推荐

  1. Vim 基本設置 – 使用Vim-plug管理插件 (3)【转】

    本文转载自:https://staryoru.github.io/vim-plugin-manager/ Vim中有很多非常好用的插件(plugin),對於這些插件的安裝.更新與移除等等,使用一個插件 ...

  2. TensorFlow和深度学习入门教程(TensorFlow and deep learning without a PhD)【转】

    本文转载自:https://blog.csdn.net/xummgg/article/details/69214366 前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络,并把 ...

  3. [BZOJ2599]Race

    Description 给一棵树,每条边有权.求一条简单路径,权值和等于K,且边的数量最小.N <= 200000, K <= 1000000 Input 第一行 两个整数 n, k 第二 ...

  4. git gc内存错误的解决方案

    Auto packing the repository for optimum performance. You may alsorun "git gc" manually. Se ...

  5. lucene 简介和实践 分享

    之前项目做了搜索的改造,使用lucene,公司内做了相关的技术分享,故先整理下ppt内容,后面会再把项目中的具体做法进行介绍 lucene 简介和实践  分享 搜索改造项目

  6. vue.js的一些事件绑定和表单数据双向绑定

    知识点: v-on:相当于: 例如:v-on:click==@click ,menthods事件绑定 v-on修饰符可以指定键盘事件 v-model进行表单数据的双向绑定 <template&g ...

  7. 不一样的入门:看C# Hello World的17种写法

    摘要:本文针对不同阶段.不同程度的C#学习者,介绍了C# Hello World的17种不同写法,希望会对大家有所帮助.(C# Hello World写法入门.C# Hello World写法进阶.C ...

  8. How to implement multiple constructor with different parameters in Scala

    Using scala is just another road, and it just like we fall in love again, but there is some pain you ...

  9. wampserver安装及安装中可能遇到的问题

    首先wampserver是windows apache Mysql PHP 集成开发环境,即在windows下的apache.php和mysql的服务器.因此wampserver是一个服务器端应用程序 ...

  10. 嵌入式--arm

    两年前的东西了,整理一下,说不定以后就会用到了. arm对于s3c2440的这个arm的驱动的整理. 其中包括:adc,beeper 蜂鸣器,key 按键,rtc ,timer定时器,UART等的驱动 ...