以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues。前者使用起来比较灵活,可以利用Python处理各种输入数据,劣势也比较明显,就是程序运行效率较低;后面一种方法的效率较高,但是使用起来较为复杂,灵活性较差。

Dataset作为新的API,比以上两种方法的速度都快,并且使用难度要远远低于使用Queues。tf.data中包含了两个用于TensorFLow程序的接口:Dataset和Iterator。

Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道,原因如下:

Dataset

Dataset表示一个元素的集合,可以看作函数式编程中的 lazy list, 元素是tensor tuple。创建Dataset的方式可以分为两种,分别是:

Source

Apply transformation
Source
这里 source 指的是从tf.Tensor对象创建Dataset,常见的方法又如下几种:

tf.data.Dataset.from_tensors((features, labels))
tf.data.Dataset.from_tensor_slices((features, labels))
tf.data.TextLineDataset(filenames)
tf.data.TFRecordDataset(filenames)

作用分别为:

  1.从一个tensor tuple创建一个单元素的dataset;

  2.从一个tensor tuple创建一个包含多个元素的dataset;

  3.读取一个文件名列表,将每个文件中的每一行作为一个元素,构成一个dataset;

  4.读取硬盘中的TFRecord格式文件,构造dataset。

Apply transformation

第二种方法就是通过转化已有的dataset来得到新的dataset,TensorFLow tf.data.Dataset支持很多中变换,在这里介绍常见的几种:

dataset.map(lambda x: tf.decode_jpeg(x))
dataset.repeat(NUM_EPOCHS)
dataset.batch(BATCH_SIZE)

以上三种方式分别表示了:使用map对dataset中的每个元素进行处理,这里的例子是对图片数据进行解码;将dataset重复一定数目的次数用于多个epoch的训练;将原来的dataset中的元素按照某个数量叠在一起,生成mini batch。

将以上代码组合起来,我们可以得到一个常用的代码片段:

# 从一个文件名列表读取 TFRecord 构成 dataset
dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
# 处理 string,将 string 转化为 tf.Tensor 对象
dataset = dataset.map(lambda record: tf.parse_single_example(record))
# buffer 大小设置为 10000,打乱 dataset
dataset = dataset.shuffle(10000)
# dataset 将被用来训练 100 个 epoch
dataset = dataset.repeat(100)
# 设置 batch size 为 128
dataset = dataset.batch(128)

Iterator

定义好了数据集以后可以通过Iterator接口来访问数据集中的tensor tuple,iterator保持了数据在数据集中的位置,提供了访问数据集中数据的方法。

可以通过调用 dataset 的 make iterator 方法来构建 iterator。

替换了place_holder,直接在原来开始的x,y处使用.get_next(),然后在sess.run时加个while true,在try里面放sess.run,exception 放OutofRangeError:

X, y = dataset.get_next()

while True:
try:
sess.run(accuracy)
except tf.errors.OutOfRangeError:
break

API 支持以下四种 iterator,复杂程度递增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot

one-shot iterator 谁最简单的一种 iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot iterator 不支参数化。以下代码使用tf.data.Dataset.range生成数据集,作用与 python 中的 range 类似。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() for i in range(100):
value = sess.run(next_element)
assert i == value

initializable

Initializable iterator 要求在使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数,如:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() # Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value # Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,如:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)

next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset) # 如果后面初始化的是这个,那么就将循环这个数据集 # Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element) # Initialize an iterator over the validation dataset.
sess.run(validation_init_op) # 替换init_op,相当于替换数据集
for _ in range(50):
sess.run(next_element)

feedable

feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。它提供了与 reinitilizable iterator 类似的功能,并且在切换数据集的时候不需要在开始的时候初始化iterator,还是上面的例子,通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator,达到切换数据集的目的:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50) # A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle()) # Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle}) # Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})

使用实例:

def get_encodes(x):
# x is `batch_size` of lines, each of which is a json object
samples = [json.loads(l) for l in x]
text = [s['fact'] for s in samples]
# get a client from available clients
bc_client = bc_clients.pop()
features = bc_client.encode(text)
# after use, put it back
bc_clients.append(bc_client)
labels = [0 for _ in text]
return features, labels data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size)
.map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name='bert_client'), num_parallel_calls=num_parallel_calls)
.map(lambda x, y: {'feature': x, 'label': y})
.make_one_shot_iterator().get_next())

tf.data的更多相关文章

  1. python3 zip 与tf.data.Data.zip的用法

    ###python自带的zip函数 与 tf.data.Dataset.zip函数 功能用法相似 ''' zip([iterator1,iterator2,]) 将可迭代对象中对应的元素打包成一个元祖 ...

  2. Tensorflow2(二)tf.data输入模块

    代码和其他资料在 github 一.tf.data模块 数据分割 import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slice ...

  3. tf.data(二) —— 并行化 tf.data.Dataset 生成器

    在处理大规模数据时,数据无法全部载入内存,我们通常用两个选项 使用tfrecords 使用 tf.data.Dataset.from_generator() tfrecords的并行化使用前文已经有过 ...

  4. tf.contrib.slim.data数据加载(1) reader

    reader: 适用于原始数据数据形式的Tensorflow Reader 在库中parallel_reader.py是与reader相关的,它使用多个reader并行处理来提高速度,但文件中定义的类 ...

  5. TensorFlow走过的坑之---数据读取和tf中batch的使用方法

    首先介绍数据读取问题,现在TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示" ...

  6. tf更新tensor/自定义层

    修改Tensor特定位置的值 如 stack overflow 中提到的方案. TensorFlow不让你直接单独改指定位置的值,但是留了个歪门儿,就是tf.scatter_update这个方法,它可 ...

  7. TF常用知识

    命名空间及变量共享 # coding=utf-8 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt; ...

  8. Tensorflow1.4 高级接口使用(estimator, data, keras, layers)

    TensorFlow 高级接口使用简介(estimator, keras, data, experiment) TensorFlow 1.4正式添加了keras和data作为其核心代码(从contri ...

  9. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

随机推荐

  1. Spark-1.2.2部署

    1.安装Scala 1.1解压和安装 在Scala官网http://www.scala-lang.org/download/下载Scala安装包,然后解压.(注:JDK的版本最好是1.7及以上,否则S ...

  2. node nodemailer

    需求:通过nodejs发送邮件 一般都是用nodemailer这个模块.目前有0.7和1.0两个版本,网上的很多教程和代码,大都是按照0.7写的,自己做的时候需要注意看README.md 支持的ser ...

  3. delphi自带的SHA1算法

    delphi自带的SHA1算法 uses IdHashSHA, IdGlobal; function SHA1(Input: String): String; begin with TIdHashSH ...

  4. Sql Server 2008 压缩数据库日志文件

    第一步:将数据库设置为简单模式 选中数据库点右键->属性: 第二步:收缩数日志文件 1, 2,   第三步:将恢复模式改回为完整模式     如果你觉得用UI界面麻烦,那你就用SQL语句吧   ...

  5. Mongodb 与 SQL 语句对照表

    In addition to the charts that follow, you might want to consider the Frequently Asked Questions sec ...

  6. mysql免安装版 安装配置 (转)

    1. 下载MySQL Community Server 5.6.13 2. 解压MySQL压缩包     将以下载的MySQL压缩包解压到自定义目录下,我的解压目录是:     "D:\Pr ...

  7. Python - 更改pip源至国内镜像

    永久使用 [windows] 在用户名目录下创建一个目录 C:\Users\xxx\pip [linux] ~/.pip/pip.conf 新建pip.ini [global] index-url = ...

  8. 爬虫浅谈一:一个简单c#爬虫程序

    这篇文章只是简单展示一个基于HTTP请求如何抓取数据的文章,如觉得简单的朋友,后续我们再慢慢深入研究探讨. 图1: 如图1,我们工作过程中,无论平台网站还是企业官网,总少不了新闻展示.如某天产品经理跟 ...

  9. Xamarin.Forms第三方XAML预览工具-LiveXAML简单体验

    截至目前,Xamarin官方的Xaml Previewer工具仍然处于测试阶段,使用中也发现了各种不便,例如各种莫名其妙的渲染失败,或者提示需要编译项目才能渲染等等,复杂项目基本不可用, 完全没有体现 ...

  10. Python:Selenium+Webdriver安装

    本人小白一枚,今天在使用selenium+webdriver的时候遇到了一个小问题: WebDriverException: 'chromedriver' executable needs to be ...