tf.data
Dataset作为新的API,比以上两种方法的速度都快,并且使用难度要远远低于使用Queues。tf.data中包含了两个用于TensorFLow程序的接口:Dataset和Iterator。
Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道,原因如下:
Dataset
Dataset表示一个元素的集合,可以看作函数式编程中的 lazy list, 元素是tensor tuple。创建Dataset的方式可以分为两种,分别是:
Source
Apply transformation
Source
这里 source 指的是从tf.Tensor对象创建Dataset,常见的方法又如下几种:
tf.data.Dataset.from_tensors((features, labels))
tf.data.Dataset.from_tensor_slices((features, labels))
tf.data.TextLineDataset(filenames)
tf.data.TFRecordDataset(filenames)
作用分别为:
1.从一个tensor tuple创建一个单元素的dataset;
2.从一个tensor tuple创建一个包含多个元素的dataset;
3.读取一个文件名列表,将每个文件中的每一行作为一个元素,构成一个dataset;
4.读取硬盘中的TFRecord格式文件,构造dataset。
Apply transformation
第二种方法就是通过转化已有的dataset来得到新的dataset,TensorFLow tf.data.Dataset支持很多中变换,在这里介绍常见的几种:
dataset.map(lambda x: tf.decode_jpeg(x))
dataset.repeat(NUM_EPOCHS)
dataset.batch(BATCH_SIZE)
以上三种方式分别表示了:使用map对dataset中的每个元素进行处理,这里的例子是对图片数据进行解码;将dataset重复一定数目的次数用于多个epoch的训练;将原来的dataset中的元素按照某个数量叠在一起,生成mini batch。
将以上代码组合起来,我们可以得到一个常用的代码片段:
# 从一个文件名列表读取 TFRecord 构成 dataset
dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
# 处理 string,将 string 转化为 tf.Tensor 对象
dataset = dataset.map(lambda record: tf.parse_single_example(record))
# buffer 大小设置为 10000,打乱 dataset
dataset = dataset.shuffle(10000)
# dataset 将被用来训练 100 个 epoch
dataset = dataset.repeat(100)
# 设置 batch size 为 128
dataset = dataset.batch(128)
Iterator
定义好了数据集以后可以通过Iterator接口来访问数据集中的tensor tuple,iterator保持了数据在数据集中的位置,提供了访问数据集中数据的方法。
可以通过调用 dataset 的 make iterator 方法来构建 iterator。
替换了place_holder,直接在原来开始的x,y处使用.get_next(),然后在sess.run时加个while true,在try里面放sess.run,exception 放OutofRangeError:
X, y = dataset.get_next() while True:
try:
sess.run(accuracy)
except tf.errors.OutOfRangeError:
break
API 支持以下四种 iterator,复杂程度递增:
- one-shot
- initializable
- reinitializable
- feedable
one-shot
one-shot iterator 谁最简单的一种 iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot iterator 不支参数化。以下代码使用tf.data.Dataset.range生成数据集,作用与 python 中的 range 类似。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() for i in range(100):
value = sess.run(next_element)
assert i == value
initializable
Initializable iterator 要求在使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数,如:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() # Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value # Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
reinitializable
reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,如:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset) # 如果后面初始化的是这个,那么就将循环这个数据集 # Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element) # Initialize an iterator over the validation dataset.
sess.run(validation_init_op) # 替换init_op,相当于替换数据集
for _ in range(50):
sess.run(next_element)
feedable
feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。它提供了与 reinitilizable iterator 类似的功能,并且在切换数据集的时候不需要在开始的时候初始化iterator,还是上面的例子,通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator,达到切换数据集的目的:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50) # A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next() # You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle()) # Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle}) # Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
使用实例:
def get_encodes(x):
# x is `batch_size` of lines, each of which is a json object
samples = [json.loads(l) for l in x]
text = [s['fact'] for s in samples]
# get a client from available clients
bc_client = bc_clients.pop()
features = bc_client.encode(text)
# after use, put it back
bc_clients.append(bc_client)
labels = [0 for _ in text]
return features, labels data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size)
.map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name='bert_client'), num_parallel_calls=num_parallel_calls)
.map(lambda x, y: {'feature': x, 'label': y})
.make_one_shot_iterator().get_next())
tf.data的更多相关文章
- python3 zip 与tf.data.Data.zip的用法
###python自带的zip函数 与 tf.data.Dataset.zip函数 功能用法相似 ''' zip([iterator1,iterator2,]) 将可迭代对象中对应的元素打包成一个元祖 ...
- Tensorflow2(二)tf.data输入模块
代码和其他资料在 github 一.tf.data模块 数据分割 import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slice ...
- tf.data(二) —— 并行化 tf.data.Dataset 生成器
在处理大规模数据时,数据无法全部载入内存,我们通常用两个选项 使用tfrecords 使用 tf.data.Dataset.from_generator() tfrecords的并行化使用前文已经有过 ...
- tf.contrib.slim.data数据加载(1) reader
reader: 适用于原始数据数据形式的Tensorflow Reader 在库中parallel_reader.py是与reader相关的,它使用多个reader并行处理来提高速度,但文件中定义的类 ...
- TensorFlow走过的坑之---数据读取和tf中batch的使用方法
首先介绍数据读取问题,现在TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示" ...
- tf更新tensor/自定义层
修改Tensor特定位置的值 如 stack overflow 中提到的方案. TensorFlow不让你直接单独改指定位置的值,但是留了个歪门儿,就是tf.scatter_update这个方法,它可 ...
- TF常用知识
命名空间及变量共享 # coding=utf-8 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt; ...
- Tensorflow1.4 高级接口使用(estimator, data, keras, layers)
TensorFlow 高级接口使用简介(estimator, keras, data, experiment) TensorFlow 1.4正式添加了keras和data作为其核心代码(从contri ...
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
随机推荐
- linux导出Excel The maximum column width for an individual cell is 255 characters
linux环境到处Excel报错: The maximum column width for an individual cell is 255 characters 解决方案: for (int i ...
- Codeforces777A Shell Game 2017-05-04 17:11 59人阅读 评论(0) 收藏
A. Shell Game time limit per test 0.5 seconds memory limit per test 256 megabytes input standard inp ...
- C++中的乱七八糟问题
1 在编写的c++程序中,如果是窗口,有时会一闪就消失了,如果不想让其消失,在程序结尾处添加: #include“iostream.h” system("pause"); 分析 ...
- 给ubuntu系统的root设置密码:
首先输入:sudo passwd root 然后输入当前用户的密码,例如xiaomeige这个用户的密码为xiaomeige 然后输入希望给root账户设置的密码,例如密码也为root
- CF1096D Easy Problem(DP)
题意:给出一个字符串,去掉第i位的花费为a[i],求使字符串中子串不含hard的最小代价. 题解:这题的思路还是比较套路的, dp[i][kd]两维,kd=0表示不含d的最小花费,1表示不含rd ...
- Linux内存子系统及常用调优参数
1>内存子系统 1>组件: slab allocator buddy system kswapd pdflush 2>虚拟化环境: PA:进程地址: HA:虚拟机地址: ...
- 用NPOI操作EXCEL--巧妙使用Excel Chart
在NPOI中,本身并不支持Chart等高级对象的创建,但通过l模板的方式可以巧妙地利用Excel强大的透视和图表功能,请看以下例子. 首先建立模板文件,定义两列以及指向此区域的名称“sales”: 创 ...
- HTML/HTML5
HTML/HTML5 一.文档加载顺序. 文档入下,数字编号为加载顺序. <html><!--1--> <head><!--2--> <link ...
- OpenLayers在地图上显示统计图,饼图线状图柱状图,修复统计图跳动的问题
环境介绍 Openlayers ol.js v5.3.0 Highcharts highcharts.js v7.0.1 jquery jquery-3.3.1.js v3.3.1 显示效果 地图放大 ...
- Hibernate一级缓存测试分析
Hibernate 一级缓存测试分析 Hibernate的一级缓存就是指Session缓存,此Session非http的session会话技术,可以理解为JDBC的Connection,连接会话,Se ...