tf.data API可以建立复杂的输入管道。它可以从分布式文件系统中汇总数据,对每个图像数据施加随机扰动,随机选择图像组成一个批次训练。一个文本模型的管道可能涉及提取原始文本数据的符号,使用查询表将它们转换成嵌入标识,将不同长度的数据组成一个批次。tf.data API让处理大规模数据、不同格式数据和进行复杂变换更容易。

tf.data API引入了两个抽象机制。

(1)tf.data.Dataset 表示一个元素序列,每个元素包含一个或多个Tensor对象。比如,一个图像管道中,一个元素可能是单个训练样例,由一对tensor组成,包括图像数据和标签。有两种不同的方式来生成一个dataset:

1)生成一个源(source)(举例:Dataset.from_tensor_slices()),从一个或多个tf.Tensor对象中构造一个dataset。

2)应用一个变换(举例:Dataset.batch()),从一个或多个tf.data.Dataset对象中构造一个dataset。

(2)tf.data.Iterator提供了主要方法来从dataset中提取元素。通过Iterator.get_next()产生Dataset中下一个要执行的元素,这是输入管道和模型之间的一个接口。最简单的迭代器是"one-shot iterator", 这个迭代器和一个特定的Dataset联系,并只从中迭代一次。对于更多复杂的使用情况,Iterator.initializer操作允许重新初始化和参数化一个迭代器使用不同datasets,比如,在同样的程序中,迭代训练数据和验证数据多次。

1.基本机制

这节描述生成不同Dataset和Iterator对象的基础知识,和如何从中提取数据。

为了开始一个输入管道,首先需要定义一个源(source)。比如,从内存中的一些tensors中构造一个Dataset,可以使用tf.data.Dataset.from_tensors()或tf.data.Dataset.from_tensor_slices()。另外,如果输入数据是以推荐的TFRecord格式存储在硬盘中,可以构造tf.data.TFRecordDataset.

一旦有了Dataset对象,可以通过调用tf.data.Dataset的链方法将其变换成新的Dataset。比如,可以应用逐元素的变换如Dataset.map()(应用一个函数到每个元素),和多元素变换如Dataset.batch()。请参考tf.data.Dataset中完整的转换列表。

从Dataset中消耗值的最常用方式是,构建一个iterator对象,提供每次提供dataset中一个元素的获取(比如,调用Dataset.make_one_shot_iterator())。一个tf.data.Iterator提供两种操作:Iterator.initializer,用来初始化迭代器状态,Iterator.get_next(),返回表示下一个元素的tf.Tensor对象。取决于使用情况,可以使用不同类型的iterator,不同类型会在下面介绍。

Dataset结构

一个dataset包含了许多具有相同结构的元素,一个元素包含了一个或多个tf.Tensor对象,称为components。每个component有一个tf.DType代表元素类型,和一个tf.TensorShape代表每个元素的静态形状。Dataset.output_types和Dataset.output_shapes属性允许你检查dataset每个元素中每个component的类型和形状。这些属性的嵌套结构映射到每个元素的结构,可能是单个tensor,一个tensor元组,或一个嵌套的tensor元组。举例:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)" dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))" dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"

通常给出一个元素的每个component的名字会更方便,如果他们表示训练样本的不同特征。除了元组(tuples)外,可以使用collections.namedtuple或一个字典映射字符串到tensors,来表示Dataset的一个单个元素。

dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"

Dataset变换支持任何结构的datasets,当使用Dataset.map(), Dataset.flat_map(), 和Dataset.filter()变换时,这些变换对每个元素应用一个函数,元素结构决定了函数参数。

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

生成一个iterator

一旦建立了Dataset来表示你的输入数据,下一步是生活从呢个一个Iterator来获得数据集中的元素。tf.data API支持下列iterators, 复杂度依次递增。

  • one-shot
  • initializable
  • reinitializable
  • feedable

一个one-shot iterator是最简单形式的iterator, 只支持dataset中的一次迭代,不需要显示初始化。one-shot iterators处理几乎所有基于队列支持的输入管道情况,但是它们不支持参数。使用Dataset.range()例子:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() for i in range(100):
value = sess.run(next_element)
assert i == value

一个initializable iterator要求在使用之前运行一个显式的iterator.initializer操作。为了交换方便,它可以参数化定义dataset,使用一个或多个tf.placeholder() tensors 在初始化迭代器时被喂数据。

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() # Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value # Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value

一个reinitializable iterator可以使用多个不同的Dataset对象初始化。举例,你可能有一个训练输入管道,使用随机扰动来提升输入图像的泛化能力,和一个验证输入管道评估未修改数据的预测。这些管道通常会使用不同的Dataset对象,并且有相同的结构。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset) # Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element) # Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)

一个feedable iterator可以结合tf.placeholder使用,来选择每次调用tf.Session.run时使用什么Iterator, 通过熟悉的feed_dict机制。它提供了与reinitializable iterator同样的功能,但当你在迭代器间切换时不要求从数据集的开始初始化迭代器。举例,使用上例中同样的训练和测试样例,可以使用tf.data.Iterator.from_string_handle来定义一个feedable iterator,这允许你在两个数据集间切换。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50) # A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next() # You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle()) # Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle}) # Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})

从iterator中消耗值

Iterator.get_next()方法返回一个或多个tf.Tensor对象,对应迭代器的下一个元素。

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() # Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element) sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"

保存迭代器状态

tf.contrib.data.make_saveable_from_iterator函数生成一个SaveableObject,从一个迭代器中,这可以用来保存或还原迭代器的当前状态。这样生成的一个保存对象可以被加入到tf.train.Saver变量列表或tf.GraphKeys.SAVEABLE_OBJECTS collection中,以与tf.Variable相同的形式保存或还原。

# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator) # Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver() with tf.Session() as sess: if should_checkpoint:
saver.save(path_to_checkpoint) # Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)

2.读输入数据

如果所有的输入数据都适合内存,生成Dataset最简单的方式是将它们转换成tf.Tensor对象,使用Dataset.from_tensor_slices()

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"] # Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0] dataset = tf.data.Dataset.from_tensor_slices((features, labels))

上面的方式比较占据内存。作为替代,可以使用tf.placeholder()定义Dataset,当初始化Iterator时喂Numpy数组。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"] # Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0] features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape) dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator() sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})

消费TFRecord数据

tf.data API支持一系列文件格式,这样可以处理不适应内存的大型数据集。TFRecord文件格式是单个面向记录的二进制格式,许多tensorflow应用使用它作为训练数据。tf.data.TFRecordDataset类可以将一个或多个TFRecord文件作为内容输入管道。

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

filenames参数既可以是string, strings列表,或tf.Tensor of strings. 当有两个文件集合用于训练和验证目的时,可以使用tf.placeholder(tf.string)表示filenames,从合适的filenames中初始化迭代器。

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator() # You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation. # Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames}) # Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

消费文本数据

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

3.使用Dataset.map()预处理数据

Dataset.map(f)变换生成一个新的数据集,通过对输入数据集的每个元素应用函数f。map()函数通常应用在列表结构。

解析tf.Example协议缓冲消息

许多输入管道提取tf.train.Example协议缓冲消息,从TFRecord格式文件中。每个tf.train.Example记录包含一个或多个“features”,输入管道通常将这些特征转换为tensors.

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"] # Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

解码图像数据和resizing it

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label # A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...]) # `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...]) dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

使用tf.py_func()应用任意Python logic

某些时候,使用额外的Python库解析输入数据时,是有用的。这时,在Dataset.map()变换中调用tf.py_func()操作。

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
return image_decoded, label # Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...] dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

4. Batching dataset elements

简单的批处理

最简单的批处理形式是将数据集的n个 连续元素堆叠成单个元素。Dataset.batch()变换做这件事,和tf.stack()有相同的限制,对每个component i, 所有元素必须有相同的shape。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4) iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next() print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])

Batching tensors with padding

为了处理许多模型(比如序列模型)的输入数据有不同size的情况,Dataset.padded_batch()变换可以将不同形状的tensors指定一个或多个维度padding,来进行批处理。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None]) iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]

5. 训练工作流

处理多个epochs

最简单的处理方式是使用Dataset.repeat()变换。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

随机打乱输入数据

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

使用高级APIs

tensorflow Importing Data的更多相关文章

  1. Importing data in R 1

    目录 Importing data in R 学习笔记1 flat files:CSV txt文件 packages:readr read_csv() read_tsv read_delim() da ...

  2. 扩增子分析QIIME2-2数据导入Importing data

    # 激活工作环境 source activate qiime2-2017.8 # 建立工作目录 mkdir -p qiime2-importing-tutorial cd qiime2-importi ...

  3. How to use Data Iterator in TensorFlow

    How to use Data Iterator in TensorFlow one_shot_iterator initializable iterator reinitializable iter ...

  4. csharp:asp.net Importing or Exporting Data from Worksheets using aspose cell

    using System; using System.Data; using System.Configuration; using System.Collections; using System. ...

  5. Tutorial: Importing and analyzing data from a Web Page using Power BI Desktop

    In this tutorial, you will learn how to import a table of data from a Web page and create a report t ...

  6. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  7. TensorFlow Ops

    TensorFlow Ops 1. Fun with TensorBoard In TensorFlow, you collectively call constants, variables, op ...

  8. TF Boys (TensorFlow Boys ) 养成记(六)

    圣诞节玩的有点嗨,差点忘记更新.祝大家昨天圣诞节快乐,再过几天元旦节快乐. 来继续学习,在/home/your_name/TensorFlow/cifar10/ 下新建文件夹cifar10_train ...

  9. TF Boys (TensorFlow Boys ) 养成记(四)

    前面基本上把 TensorFlow 的在图像处理上的基础知识介绍完了,下面我们就用 TensorFlow 来搭建一个分类 cifar10 的神经网络. 首先准备数据: cifar10 的数据集共有 6 ...

随机推荐

  1. html单选框(性别选择)

    在写单选框时,如何实现只能同时只能选择一个radio. 将name设置为一样的数值:代码如下: <input class="myforms-3-2" type="r ...

  2. LeetCode14.最长公共前缀 JavaScript

    编写一个函数来查找字符串数组中的最长公共前缀. 如果不存在公共前缀,返回空字符串 "". 示例 1: 输入: ["flower","flow" ...

  3. vue 集成百度富文本编辑器

    <template> <div> <textarea style="display:none" id="editor_content&quo ...

  4. Ionic的项目结构(angluar js)

    Hybird HTML5 App(移动应用开发)之3.Ionic的项目结构 前面使用命令ionic start myapp下载了默认的Ionic应用程序,下面我们打开应用程序项目,来分析一下Ionic ...

  5. windows 开启 nginx 监听80 端口 以及 禁用 http 服务后,无法重启 HTTP 服务,提示 系统错误 123,文件目录、卷标出错

    1. 正常情况直接运行  start nginx.exe 不能开启成功,因为 80 端口被占用.提示: bind() to 0.0.0.0:80 failed (10013: An attempt w ...

  6. 剑指offer—二维数组中的查找

    题目描述 在一个二维数组中(每个一维数组的长度相同),每一行都按照从左到右递增的顺序排序,每一列都按照从上到下递增的顺序排序.请完成一个函数,输入这样的一个二维数组和一个整数,判断数组中是否含有该整数 ...

  7. 二、html篇

    1.<br/> 有时css实现换行比较麻烦,可以使用该标签进行换行. 2.<strong></strong>  <ins></ins>  & ...

  8. 几种常用的git命令

    1.合并代码出现冲突,用git status 查看冲突所在的文件 2. clone 指定分支分支的文件夹 git clone -b **** ***; 3.git merge 和 git rebase ...

  9. 搞笑入群二维码在线生成源码 php图片合成并添加文字水印

    在凤凰网看到一篇文章:微信群二维码也能“整人”,99%的好友会中招!感觉挺好玩,所以自己也想做一个! 冷静分析

  10. 微信小程序使用相机

    <view class="page-body"> <view class="page-body-wrapper"> <camera ...