展示如何将数据输入到计算图中

Dataset可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。

数据集对象实例化:

dataset=tf.data.Dataset.from_tensor_slice(<data>)

迭代器对象实例化:

iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()

读取结束异常:如果一个dataset中的元素被读取完毕,再尝试sess.run(one_element)的话,会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据是一致的。

高维数据集的使用

tf.data.Dataset.from_tensor_slices真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。

dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError as e:
print('end~')

输出:

[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]

tuple组合数据

dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('end~')

输出:

(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...

数据集处理方法

Dataset支持一类特殊操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。常用的Transformation

  • map
  • batch
  • shuffle
  • repeat

其中,

  • map和python中的map一致,接受一个函数,Dataset中的每个元素都会作为这个函数的输入,并将函数返回值作为新的Dataset

    dataset=dataset.map(lambda x:x+1)

    注意:map函数可以使用num_parallel_calls参数并行化

  • batch就是将多个元素组成batch。

    dataset=tf.data.Dataset.from_tensor_slices(
    {
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.batch(2) # batch_size=2
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
    try:
    while True:
    print(one_element)
    except tf.errors.OutOfRangeError:
    print('end~')

    输出:

    {'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
    {'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
  • shuffle的功能是打乱dataset中的元素,它有个参数buffer_size,表示打乱时使用的buffer的大小,不应设置过小,推荐值1000.

    dataset=tf.data.Dataset.from_tensor_slices(
    {
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.shuffle(buffer_size=5)
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
    try:
    while True:
    print(one_element)
    except tf.errors.OutOfRangeError:
    print('end~')
  • repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch。假设原先的数据是一个epoch,使用repeat(2)可以使之变成2个epoch.

    dataset=tf.data.Dataset.from_tensor_slices({
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.repeat(2) # 2epoch
    ###
    # iterator, one_element...

    注意:如果直接调用repeat()函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出tf.errors.OutOfRangeError异常。

模拟读入磁盘图片及其Label示例

def _parse_function(filename,label):  # 接受单个元素,转换为目标
img_string=tf.read_file(filename)
img_decoded=tf.image.decode_images(img_string)
img_resized=tf.image.resize_images(image_decoded,[28,28])
return image_resized,label filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function) # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)

更多Dataset创建方法

  • tf.data.TextLineDataset():函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。
  • tf.data.FixedLengthRecordDataset():函数输入一个文件列表和record_bytes参数,dataset中每一个元素是文件中固定字节数record_bytes的内容,可用来读取二进制保存的文件,如CIFAR10。
  • tf.data.TFRecordDataset():读取TFRecord文件,dataset中每一个元素是一个TFExample。

更多Iterator创建方法

最简单的创建Iterator方法是通过dataset.make_one_shot_iterator()创建一个iterator。

除了这种iterator之外,还有更复杂的Iterator:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

其中,initializable iterator方法要在使用前通过sess.run()进行初始化,initializable iterator还可用于读入较大数组。在使用tf.data.Dataset.from_tensor_slices(array)时,实际上发生的事情是将array作为一个tf.constants保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个placeholder取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。

features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})

Tensorflow内部读取机制

对于文件名队列,使用tf.train.string_input_producer()函数,tf.train.string_input_producer()还有两个重要参数,num_epochesshuffle

内存队列不需要我们建立,只需要使用reader对象从文件名队列中读取数据即可,使用tf.train.start_queue_runners()函数启动队列,填充两个队列的数据。

with tf.Session() as sess:
filenames=['A.jpg','B.jpg','C.jpg']
filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
reader=tf.WholeFileReader()
key,value=reader.read(filename_queue)
# tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
tf.local_variables_initializer().run()
threads=tf.train.start_queue_runners(sess=sess)
i=0
while True:
i+=1
image_data=sess.run(value)
with open('reader/test_%d.jpg'%i,'wb') as f:
f.write(image_data)

Tensorflow数据读取机制的更多相关文章

  1. 十图详解tensorflow数据读取机制(附代码)转知乎

    十图详解tensorflow数据读取机制(附代码) - 何之源的文章 - 知乎 https://zhuanlan.zhihu.com/p/27238630

  2. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  3. 十图详解tensorflow数据读取机制

    在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  4. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  5. 【转载】 十图详解tensorflow数据读取机制(附代码)

    原文地址: https://zhuanlan.zhihu.com/p/27238630 何之源 ​ 深度学习(Deep Learning) 话题的优秀回答者       --------------- ...

  6. tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...

  7. TensorFlow数据读取

    TensorFlow高效读取数据的方法 TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取 Tensorflow从文件读取数据 极客学院-数据读取 十 ...

  8. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  9. 详解Tensorflow数据读取有三种方式(next_batch)

    转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...

随机推荐

  1. web.xml 中的listener、 filter、servlet 加载顺序及其详解(转)

    在项目中总会遇到一些关于加载的优先级问题,近期也同样遇到过类似的,所以自己查找资料总结了下,下面有些是转载其他人的,毕竟人家写的不错,自己也就不重复造轮子了,只是略加点了自己的修饰. 首先可以肯定的是 ...

  2. 在视图上创建ListCtrl的做法

    作者:朱金灿 来源:http://blog.csdn.net/clever101 今天介绍下如何在一个视图上动态创建一个ListCtrl. 1.新建一个MFC的单文档工程,这里暂定名字为ListDem ...

  3. gradle命令学习

    概述 命令学习比较枯燥,全部是例子~ gradle版本 假设你的本地gradle已经安装配置完成.没有安装配置的,可以参考 gradle安装 C:\Users\yueling.DANGDANG> ...

  4. Android最新组件RecyclerView,替代ListView

    转载请注明出处:http://blog.csdn.net/allen315410/article/details/40379159 万众瞩目的android最新5.0版本号不久前已经正式公布了,对于我 ...

  5. Cordova热更新和App升级 - 简书

    原文:Cordova热更新和App升级 - 简书 公司的cordova项目前段时间增加了热更新功能,自己第一次做的时候在网上查找了很多资料,有的资料写的并不全面遇到了很多坑.因此总结一些在开发过程中遇 ...

  6. ServletContextListener接口用法

    ServletContextListener接口用于tomcat启动时自动加载函数,方法如下: 一.需加载的类必须实现ServletContextListener接口. 二.该接口中有两个方法必须实现 ...

  7. 64 位系统 vs2013 配置 OpenCV-3.1.0

    参考:64 位系统 vs2013 配置 opencv3.0 1. 环境准备 进入官网 http://opencv.org/,下载最新版本的 opencv(以本文 opencv-3.1.0 为例,.ex ...

  8. 面试问题:Vista与XP的Session 0与Session X的区别

    面试问题:Vista与XP的Session 0与Session X的区别 在XXXXX的一次面试中,笔试问题的题目曾提到Session 0.Session 1在Vista和Xp中的区别?现在把答案发上 ...

  9. angular中通过$location获取路径(参数)的写法

    以下获取与修改的 URL 以  ( http://172.16.0.88:8100/#/homePage?id=10&a=100  ) 为例 [一]获取 (不修改URL) //1.获取当前完整 ...

  10. 楼塔当天领袖acm心理(作为励志使用)

    楼主个人博客:吉尔博客 假期空闲的时候使用.这些年来GCJ.ACM,TopCoder 的一个号码的一重要的比赛的参与 回顾.GCJ2006 的回顾,今天时间上更早一些吧,我如今还清晰记得3 年 前.我 ...