机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。

最普通的正常情况

首先我们看看最普通的情况:

# 创建0-10的数据集,每个batch取个数。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]

由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

datasets.batch(batch_size)与迭代次数的关系

但是如果上面for循环次数超过2会怎么样呢?也就是说如果 循环次数*批数量 > 数据集数量 会怎么样?我们试试看:

dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
>>==for i in range(3):==<<
value = sess.run(next_element)
print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1277 try: ...
...省略若干信息...
... OutOfRangeError (see above for traceback): End of sequence
[[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]

可以知道超过范围了,所以报错了。

datasets.repeat()

为了解决上述问题,repeat方法登场。还是直接看例子吧:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

可以知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,所以这次即使for循环次数是4也依旧能正常读取数据,并且都能完整把数据读取出来。同理,如果把for循环次数设置为大于4,那么也还是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?所以更简便的办法就是对repeat方法不设置重复次数,效果见如下:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
for i in range(6):
value = sess.run(next_element)
print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

此时无论for循环多少次都不怕啦~~

datasets.shuffle(buffer_size)

仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。

另外shuffle前需要设置buffer_size:

  • 不设置会报错,
  • buffer_size=1:不打乱顺序,既保持原序
  • buffer_size越大,打乱程度越大,演示效果见如下代码:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)

输出结果:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的顺序很重要,一般建议是最开始执行shuffle操作,因为如果是先执行batch操作的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

MARSGGBO♥原创







2018-8-5

Tensorflow datasets.shuffle repeat batch方法的更多相关文章

  1. TensorFlow高效读取数据的方法——TFRecord的学习

    关于TensorFlow读取数据,官网给出了三种方法: 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据. 从文件读取数据:在TensorFlow图的起 ...

  2. 【tf.keras】tensorflow datasets,tfds

    一些最常用的数据集如 MNIST.Fashion MNIST.cifar10/100 在 tf.keras.datasets 中就能找到,但对于其它也常用的数据集如 SVHN.Caltech101,t ...

  3. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  4. TensorFlow指定CPU和GPU方法

    TensorFlow指定CPU和GPU方法 TensorFlow 支持 CPU 和 GPU.它也支持分布式计算.可以在一个或多个计算机系统的多个设备上使用 TensorFlow. TensorFlow ...

  5. [TensorFlow] Introduction to TensorFlow Datasets and Estimators

    Datasets and Estimators are two key TensorFlow features you should use: Datasets: The best practice ...

  6. Tensorflow高效读取数据的方法

    最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码 ...

  7. TensorFlow加载图片的方法

    方法一:直接使用tensorflow提供的函数image = tf.gfile.FastGFile('PATH')来读取一副图片: import matplotlib.pyplot as plt; i ...

  8. tensorflow中的参数初始化方法

    1. 初始化为常量 tf中使用tf.constant_initializer(value)类生成一个初始值为常量value的tensor对象. constant_initializer类的构造函数定义 ...

  9. TensorFlow 常见错误与解决方法——长期不定时更新

    1. TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a builtin_function_or_method ...

随机推荐

  1. BZOJ4259 残缺的字符串(FFT)

    两个串匹配时相匹配的位置位置差是相同的,那么翻转一个串就变成位置和相同,卷积的形式. 考虑如何使用卷积体现两个位置能否匹配.一个暴力的思路是每次只考虑一种字符,将其在一个串中设为1,并在另一个串中将不 ...

  2. hadoop MapReduce 入门

    原创播客,如需转载请注明出处.原文地址:http://www.cnblogs.com/crawl/p/7687120.html ------------------------------------ ...

  3. Goldbach`s Conjecture LightOJ - 1259 (素数打表 哥德巴赫猜想)

    题意: 就是哥德巴赫猜想...任意一个偶数 都可以分解成两个(就是一对啦)质数的加和 输入一个偶数求有几对.. 解析: 首先! 素数打表..因为 质数 + 质数 = 偶数 所以 偶数 - 质数 = 质 ...

  4. HGOI 20190303 题解

    /* 记一串数字真难. 5435 今天比赛又是hjcAK的一天. 今天开题顺序是312,在搞T1之前搞了T3 昨天某谷月赛真是毒瘤. 但是讲评的同学不错,起码T4看懂了... 构造最优状态然后DP的思 ...

  5. [luogu4201][bzoj1063]设计路线【树形DP】

    题目描述 Z国坐落于遥远而又神奇的东方半岛上,在小Z的统治时代公路成为这里主要的交通手段.Z国共有n座城市,一些城市之间由双向的公路所连接.非常神奇的是Z国的每个城市所处的经度都不相同,并且最多只和一 ...

  6. [hgoi#2019/2/24]玄学考试

    感想 对于这次考试,真的不想说什么了,太玄学了!!! t1输出比标准输出长,这是什么操作???难道要关文件???但是交到oj上又A掉了.这是什么操作. t2还好,没有出什么意外...但是要吐槽一下出题 ...

  7. 【转】 cJSON 源码解析

    关于cjson的介绍和使用方法就不在这里介绍了,详情请查看上一篇博客cjson使用方法. JSON的内存结构像广义表,可以认为是有层次的双向链表. cJSON程序中的细节点如下: 大量宏替换 大量静态 ...

  8. Java -- JDBC_DAO 设计模式

    DAO:Date Access Object 实现代码模块化,更加有利于代码的维护和升级. DAO 可以被子类继承或者直接使用. 访问数据信息的类,包含对数据的CRUD(create read upd ...

  9. 内置窗口 pyqt5

    1.使用Qt Designer设计三个窗口 注意:在主窗口中需要添加一个girdLayout 2.创建**.py from PyQt5.QtWidgets import QMainWindow, QA ...

  10. [python网络编程]使用scapy修改源IP发送请求

    Python爬虫视频教程零基础小白到scrapy爬虫高手-轻松入门 https://item.taobao.com/item.htm?spm=a1z38n.10677092.0.0.482434a6E ...