展示如何将数据输入到计算图中

Dataset可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。

数据集对象实例化:

dataset=tf.data.Dataset.from_tensor_slice(<data>)

迭代器对象实例化:

iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()

读取结束异常:如果一个dataset中的元素被读取完毕,再尝试sess.run(one_element)的话,会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据是一致的。

高维数据集的使用

tf.data.Dataset.from_tensor_slices真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。

dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError as e:
print('end~')

输出:

[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]

tuple组合数据

dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('end~')

输出:

(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...

数据集处理方法

Dataset支持一类特殊操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。常用的Transformation

  • map
  • batch
  • shuffle
  • repeat

其中,

  • map和python中的map一致,接受一个函数,Dataset中的每个元素都会作为这个函数的输入,并将函数返回值作为新的Dataset

    dataset=dataset.map(lambda x:x+1)

    注意:map函数可以使用num_parallel_calls参数并行化

  • batch就是将多个元素组成batch。

    dataset=tf.data.Dataset.from_tensor_slices(
    {
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.batch(2) # batch_size=2
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
    try:
    while True:
    print(one_element)
    except tf.errors.OutOfRangeError:
    print('end~')

    输出:

    {'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
    {'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
  • shuffle的功能是打乱dataset中的元素,它有个参数buffer_size,表示打乱时使用的buffer的大小,不应设置过小,推荐值1000.

    dataset=tf.data.Dataset.from_tensor_slices(
    {
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.shuffle(buffer_size=5)
    ###
    iterator=dataset.make_one_shot_iterator()
    one_element=iterator.get_next()
    with tf.Session() as sess:
    try:
    while True:
    print(one_element)
    except tf.errors.OutOfRangeError:
    print('end~')
  • repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch。假设原先的数据是一个epoch,使用repeat(2)可以使之变成2个epoch.

    dataset=tf.data.Dataset.from_tensor_slices({
    'a':np.array([1.,2.,3.,4.,5.]),
    'b':np.random.uniform(size=(5,2))
    })
    ###
    dataset=dataset.repeat(2) # 2epoch
    ###
    # iterator, one_element...

    注意:如果直接调用repeat()函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出tf.errors.OutOfRangeError异常。

模拟读入磁盘图片及其Label示例

def _parse_function(filename,label):  # 接受单个元素,转换为目标
img_string=tf.read_file(filename)
img_decoded=tf.image.decode_images(img_string)
img_resized=tf.image.resize_images(image_decoded,[28,28])
return image_resized,label filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function) # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)

更多Dataset创建方法

  • tf.data.TextLineDataset():函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。
  • tf.data.FixedLengthRecordDataset():函数输入一个文件列表和record_bytes参数,dataset中每一个元素是文件中固定字节数record_bytes的内容,可用来读取二进制保存的文件,如CIFAR10。
  • tf.data.TFRecordDataset():读取TFRecord文件,dataset中每一个元素是一个TFExample。

更多Iterator创建方法

最简单的创建Iterator方法是通过dataset.make_one_shot_iterator()创建一个iterator。

除了这种iterator之外,还有更复杂的Iterator:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

其中,initializable iterator方法要在使用前通过sess.run()进行初始化,initializable iterator还可用于读入较大数组。在使用tf.data.Dataset.from_tensor_slices(array)时,实际上发生的事情是将array作为一个tf.constants保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个placeholder取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。

features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})

Tensorflow内部读取机制

对于文件名队列,使用tf.train.string_input_producer()函数,tf.train.string_input_producer()还有两个重要参数,num_epochesshuffle

内存队列不需要我们建立,只需要使用reader对象从文件名队列中读取数据即可,使用tf.train.start_queue_runners()函数启动队列,填充两个队列的数据。

with tf.Session() as sess:
filenames=['A.jpg','B.jpg','C.jpg']
filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
reader=tf.WholeFileReader()
key,value=reader.read(filename_queue)
# tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
tf.local_variables_initializer().run()
threads=tf.train.start_queue_runners(sess=sess)
i=0
while True:
i+=1
image_data=sess.run(value)
with open('reader/test_%d.jpg'%i,'wb') as f:
f.write(image_data)

Tensorflow数据读取机制的更多相关文章

  1. 十图详解tensorflow数据读取机制(附代码)转知乎

    十图详解tensorflow数据读取机制(附代码) - 何之源的文章 - 知乎 https://zhuanlan.zhihu.com/p/27238630

  2. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  3. 十图详解tensorflow数据读取机制

    在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  4. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  5. 【转载】 十图详解tensorflow数据读取机制(附代码)

    原文地址: https://zhuanlan.zhihu.com/p/27238630 何之源 ​ 深度学习(Deep Learning) 话题的优秀回答者       --------------- ...

  6. tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...

  7. TensorFlow数据读取

    TensorFlow高效读取数据的方法 TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取 Tensorflow从文件读取数据 极客学院-数据读取 十 ...

  8. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  9. 详解Tensorflow数据读取有三种方式(next_batch)

    转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...

随机推荐

  1. MKNetworkKit的断点续传SIDownloader下载

    comefrom:http://cache.baiducontent.com/c?m=9f65cb4a8c8507ed4fece763105392230e54f73d6f8b9042238fce098 ...

  2. 开源库Fab-Transformation简单使用解析

    转载请注明出处王亟亟的大牛之路 相似于IPhone的悬浮按钮的操作,仅仅只是是固定的,当然经过自己的改动也能够动.这边仅仅是给伸手党一个福祉,外加加上一些自己的理解.让大家能够拿来就用.看了就懂,废话 ...

  3. SCM文章9类:外部中断示例程序

    JP3遇见P0口,JP5遇见P3口,P1接受该发光二极管,什么时候P1所有的都是高时,,全亮度发光二极管.因为外部中断0和1用同样的方法.这里只是外部中断0计划. #include<reg51. ...

  4. 在 Oracle 中新建 SDE 用户

    --1.创建用户(SDE)和密码(SDE) CREATE USER SDE IDENTIFIED BY SDE --2.创建表空间(SDE) CREATE TABLESPACE SDE DATAFIL ...

  5. 在项目中使用CLR规划

    1.创建自己的项目 2.对"解..."→参加→目→C#→数据库→SQL Server项目,例如以下图所看到的: 3.选择操作数据库 4.创建存储过程 5.代码(详见:CLR存储过程 ...

  6. java开发环境配置(windows下JDK7+tomcat7)

    參考原文:http://www.cnblogs.com/goto/archive/2012/11/16/2772683.html http://www.cnblogs.com/feilong35407 ...

  7. In partitioned databases, trading some consistency for availability can lead to dramatic improvements in scalability.

    In partitioned databases, trading some consistency for availability can lead to dramatic improvement ...

  8. 将memo转化为JPG输出,使用Memo1.PaintTo(Bitmap.Canvas)

    unit unit1; interface uses  Windows, Messages, SysUtils, Graphics, Controls, Forms, StdCtrls,  Class ...

  9. 初探js

    第一章   1.JS的位置 1-1.行间 1-2.内嵌 1-3.外联 2.JS的标签位置 页面中的代码在一般情况下会按从上到下的顺序,从左往右的顺序执行. 因此当JS放在了元素上面的时候,就不能正常执 ...

  10. Android中SQLite数据库操作(1)——使用SQL语句操作SQLite数据库

    下面是最原始的方法,用SQL语句操作数据库.后面的"Android中SQLite数据库操作(2)--SQLiteOpenHelper类"将介绍一种常用的android封装操作SQL ...