https://keras.io/zh/layers/core/


keras使用稀疏输入进行训练

2018.06.14 12:55:46字数 902阅读 760

稀疏矩阵

稀疏矩阵是指矩阵中数值为0的元素数目远远多于非0元素的数目,在实际中遇到的大矩阵基本都是稀疏的。如果使用普通的ndarray存储稀疏矩阵,会有很大的内存浪费。在python中我们可以使用scipy中的sparse模块存储这些矩阵,但是在用keras搭建神经网络使用这些矩阵作为神经网络的输入时,则需要做一些处理才能使用sparse格式的数据。


方法一、使用keras函数式API中的参数实现

keras的Sequential顺序模型是不支持稀疏输入的,如果非要用Sequential模型,可以参考方法二。在使用函数式API模型时,Input层初始化时有一个sparse参数,用来指明要创建的占位符是否是稀疏的,如图:

 
Input的参数,可以用sparse来指明是否是稀疏的输入数据

在使用时也很直接,一个参数就可以搞定:

ipt_layer = Input((shape, ), sparse=True)

网络的定义过程和常规方法没有什么区别,后边compile、fit等操作也都没有变化。不过目前这么用有一个问题,就是指定的batch_size不生效,不管设置多大的batch_size,训练的时候都是按照batch_size为1来进行,可能是人家觉得都用稀疏数据了,数据肯定大到可怕,用大一些batch会引入内存问题吧。如果要使用指定的batch_size来训练稀疏数据,或者需要调整batch_size,可以参考方法二。

方法二、使用生成器方法实现

还有一种方法可以实现,是使用生成器的方法,最早看到这个方法是在stackoverflow上,参考链接

这种方法是利用生成器配合keras模型的fit_generator来实现,核心代码如下:

# batch_generator
def batch_generator(x, y, batch_size):
number_of_batches = x.shape[0]//batch_size
counter = 0
shuffle_index = np.arange(x.shape[0])
np.random.shuffle(shuffle_index)
x = x[shuffle_index, :]
y = y[shuffle_index, :]
while 1:
index_batch = shuffle_index[batch_size*counter: batch_size*(counter+1)]
x_batch = x[index_batch, :].todense()
y_batch = y[index_batch, :].todense()
counter += 1
yield(np.array(x_batch), np.array(y_batch))
if counter >= number_of_batches:
np.random.shuffle(shuffle_index)
counter = 0 # fit时要先根据batch_size和样本总量计算一下总共的steps_per_epoch
train_steps = x.shape[0]//batch_size
# 在fit时使用fit_generator
model.fit_generator(generator=batch_generator(x, y, batch_size), steps_per_epoch=train_steps......)

除了生成器函数,这里需要注意的是在fit之前先要计算每个epoch需要训练多少个step。

在用这个方法进行训练的时候,对于validation数据,有几种场景区分:

  • 如果比较大,也可以使用这个生成器,直接将fit_generator的validation_data这个参数设置为生成器并且使用对应的验证数据即可;
  • 如果数据不大,可以选择把所有的validation数据都todense转为常规的ndarray;
  • 另外如果在训练中使用tensorboard,并且histogram_freq参数设置不为0,那么验证数据就不能使用生成器来生成了,必须转为ndarray才可以。

方法总结

时间就是金钱,在多数场景下,推荐使用方法一,节省生命。但如果对于需要调整batch_size或者铁了头要用Sequential模型的,方法二是比较好的选择,鉴于方法二对于tensorboard不是很友好,所以建议在使用方法二的时候不要在验证集上也使用生成器。

对于稀疏的输入,上边的方法应该可以解决大部分问题了,不过有一些输出也是稀疏的情况,虽然训练过程跟着batch_size走,不会有什么影响,但在需要大规模predict的时候,比如要对几千万上亿条数据进行预测,目前还没有很好的办法能够直接输出稀疏格式存储的数据。


Keras:的更多相关文章

  1. keras:InternalError: Failed to create session

    如题,keras出现以上错误,解决办法: 找到占用gpu的进程: nvidia-smi -q 杀死这些进程即可: xxxxx

  2. [机器学习] keras:MNIST手写数字体识别(DeepLearning 的 HelloWord程序)

    深度学习界的Hello Word程序:MNIST手写数字体识别 learn from(仍然是李宏毅老师<机器学习>课程):http://speech.ee.ntu.edu.tw/~tlka ...

  3. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  4. 深度学习框架: Keras官方中文版文档正式发布

    今年 1 月 12 日,Keras 作者 François Chollet‏ 在推特上表示因为中文读者的广泛关注,他已经在 GitHub 上展开了一个 Keras 中文文档项目.而昨日,Françoi ...

  5. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  6. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  7. 【TensorFlow 3】mnist数据集:与Keras对比

    在TF1.8之后Keras被当作为一个内置API:tf.keras. 并且之前的下载语句会报错. mnist = input_data.read_data_sets('MNIST_data',one_ ...

  8. 安装Keras

    在cmd窗口运行代码: pip install keras -U --pre 安装Keras: 进入Python环境,运行import keras,检验是否成功安装.

  9. 学习笔记TF054:TFLearn、Keras

    元框架(metaframework). TFLearn.模块化深度学习框架,更高级API,快速实验,完全透明兼容. TFLearn实现AlexNet.https://github.com/tflear ...

随机推荐

  1. LINUX上安装JDK+tomcat+mysql操作笔记

    1.环境准备: 1-1.centos 64位(本人的虚拟机安装此系统),安装步骤和网络配置已经在前两篇记录. 1-2.JDK 版本1.8 1-3.tomcat压缩包 1-4.CRT远程连接工具(可用其 ...

  2. layUI学习第二日:非模块化方法使用layUI

    layUI采用非模块化方式(即所有模块一次性加载),操作示例代码如下(如果问怎么创建项目和工具,参考layUI学习第一日的步骤): 运行的结果如下: 运行的显示不会太持久,过几秒就会消失,具体封装的代 ...

  3. Vue 知识点个人总结

    Vue 脚手架 脚手架 3 的版本 ---- webpack 4 cnpm install -g @vue/cli-----全局安装组件 vue create myapp-----命令行创建项目 或者 ...

  4. CSP2019 树上的数 题解

    题面 这是一道典型的部分分启发正解的题. 所以我们先来看两个部分分. Part 1 菊花图 这应该是除了暴力以外最好想的一档部分分了. 如上图(节点上的数字已省略),如果我们依次删去边(2)(1)(3 ...

  5. Batch Normalization、Layer Normalization、Instance Normalization、Group Normalization、Switchable Normalization比较

    深度神经网络难训练一个重要的原因就是深度神经网络涉及很多层的叠加,每一层的参数变化都会导致下一层输入数据分布的变化,随着层数的增加,高层输入数据分布变化会非常剧烈,这就使得高层需要不断适应低层的参数更 ...

  6. yii2关联表

    asArray()这个方法很好用,返回数组是1版本想要的形式,这种方式有种tp框架的感觉

  7. 如何用 Python 给照片换色

    最近遇到了一个需求,就是对图片进行色彩风格转换,让一个物体可以以各种不同的色彩来呈现. 比如一个红色的苹果,我想把它转化成绿色,这可怎么办呢?本来想的解决方案是先识别边界,然后对边界内区域进行色彩替换 ...

  8. ROS下多雷达融合算法

    有些小车车身比较长,如果是一个激光雷达,顾前不顾后,有比较大的视野盲区,这对小车导航定位避障来说都是一个问题,比如AGV小车, 所有想在小车前后各加一个雷达,那问题是ROS的建图或者定位导航都只是支持 ...

  9. 打印对象(__str__()和__repr__())

    当打印一个类的实例时,返回的字符串是对象的地址信息,如<__main__.Student object at 0x109afb310>,很不好看 可通过在类内定义__str__(),这样打 ...

  10. pytest框架之pytest-html报告生成

    一.关于安装 pytest-html属于pytest的一个插件,使用它需要先安装 pip install pytest-html pytest可以生成多种样式的结果: 生成JunitXML格式的测试报 ...