TensorFlow Data模块
模块作用
tf.data api用于创建训练前导入数据和数据处理的pipeline,使得处理大规模数据,不同数据格式和复杂数据处理变的容易。
基本抽象
提供了两种基本抽象:Dataset和Iterator
Dataset
表示元素序列集合,每个元素包含一个或者多个Tensor对象,每个元素是一个样本。有两种方式可以创建Dataset。
- 从源数据创建,比如:Dataset.from_tensor_slices()
- 通过数据处理转换创建,比如 Dataset.map()/batch()
Iterator
用于从Dataset获取数据给训练作为输入,作为输入管道和训练代码的接口。最简单的例子是“one-shot iterator”,和一个dataset绑定并返回它的元素一次。通过调整初始化参数,可以实现一些复杂场景的需求,比如循环迭代训练集的元素。
基本机制
基本使用的流程如下:
1. 创建一个dataset
2. 做一些数据处理(map, batch)
3. 创建一个iterator提供给训练使用
数据结构
每个元素都有相同的数据结构,每个元素包含一个或者多个tensor,可以元组表示或者嵌套表示。每个tensor包含类型tf.DType和维度形状tf.TensorShape。每个tensor可以有名称,通过collections.namedtuple或者字典实现。
Dataset的数据处理接口能够支持任意数据结构。
创建Iterator
有了dataset以后,下一步就是创建一个iterator提供访问数据元素的接口,现在tf.data api支持4中iterator来实现不同程度的复杂场景。
- one-shot,
- initializable,
- reinitializable
- feedable.
one-shot
最简单,最常用,就是一次遍历所有元素,目前(2019/08)也是estimator唯一支持的iterator。
initializable
需要显示运行初始化操作,可以通过tf.placeholder参数化dataset。
reinitializalbe
可以从多个dataset多次初始化,当然需要相同数据结构。每次切换数据集,使用前需要初始化,
feedable
可以从多个dataset多次初始化,初始化完毕后,可以通过tf.placeholder随时切换数据集。
获取iterator的值
通过session.run(iterator.get_next())完成,若没有数据了则报异常:tf.errors.OutOfRangeError。
iterator.get_next()返回的是tf.Tensor对象,需要在session中执行才会有值。
保存iterator状态
save and restore the current state of the iterator (and, effectively, the whole input pipeline). 能将整个输入pipeline保存并恢复(原理是什么?)
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
读取输入数据
直接读取numpy的数据得到array,调用Dataset.from_tensor_slices
读取TFRecord data: dataset = tf.data.TFRecordDataset(filenames)
通过文件的方式,可以支持不能够全部导入内存中的场景。tf.data.TFRcordDataset将读取数据流作为整个输入流的一部分。
同时,文件名可以通过tf.placeholder传入。
这里所有的dataset的返回,也都是tensorflow中的operation,需要在session中计算得到值。
解析tf.Example
推荐情形下,输入需要从TFRecord文件格式读取TF Example的protocol buffer messages,每个TF Example包含一个或者多个features,输入管道需要将其转换为tensor。
典型列子
# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""), "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features) return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
对于图像,可以通过tf.image.decode_jpeg等方式提取和resize。
任意的python逻辑处理数据情形
在文本当中,比如有时候需要调用其他python库,比如jieba,这时候输入处理需要通过t'f.py_func()opeation完成。
Batching dataset elements
简单batch
调用Dataset.batch()实现,所有元素包含同样的数据结构。
带Padding的batch
当所有元素包含不同长度的数据结构时,可以通过padding使得单个batch数据长度一致。序列模型情形下。制定padding的shape为可变长度,比如(None,)来完成,会自动padding到单个batch的最大长度。当然,可以重写padding的值或者padding的长度。
训练流程
多批次遍历
tf.data提供了两种遍历数据的方式
想要遍历数据多次,可以使用Dataset.repeat(n)操作完成。遍历n次对应多少次epochs。若不提供参数,则将无限循环。同时,将不会给出一个批次结束的信号。
若想在每个批次结束后做处理,则需要在外围加循环,然后通过重复初始化迭代器完成,捕捉tf.errors.OutOfRangeError异常。
随机shuffle
dataset.shuffle()维持一个缓存区,随机取下一个数据。
高阶API
tf.train.MonitoredTrainingSession 接口简化了分布式Tensorflow的很多方面。MonitoredTrainingSession 用tf.error.OutOfRangeError来获取训练是否结束,所以推荐采用one-shot 迭代器。
用dataset作为estimator的输入函数时,直接将dataset返回,estimator会自动创建迭代器并初始化。
TensorFlow Data模块的更多相关文章
- jQuery1.9.1源码分析--数据缓存Data模块
jQuery1.9.1源码分析--数据缓存Data模块 阅读目录 jQuery API中Data的基本使用方法介绍 jQuery.acceptData(elem)源码分析 jQuery.data(el ...
- 读Zepto源码之Data模块
Zepto 的 Data 模块用来获取 DOM 节点中的 data-* 属性的数据,和储存跟 DOM 相关的数据. 读 Zepto 源码系列文章已经放到了github上,欢迎star: reading ...
- TensorFlow saved_model 模块
最近在学tensorflow serving 模块,一直对接口不了解,后面看到这个文章就豁然开朗了, 主要的困难在于 tf.saved_model.builder.SavedModelBuilde ...
- tensorflow data's save and load
note: if you'll load data,the data shape should be similar with saved data's shape. -- 中式英语,天下无敌 ...
- 服务器安装tensorflow导入模块报错Illegal instruction (core dumped)
在ubuntu上安装tensorflow后导入模块显示Illegal instruction (core dumped) 服务器的版本是Ubuntu 16.04.5 降低版本,成功导入模块 pip3 ...
- tensorflow mnist模块详解
tensorflow的官方文档是以mnist数据集为例子开始的.文档本身没有介绍tensorflow.contrib.learn.python.learn.datasets.mnist模块.要想用te ...
- 使用tensorflow.data.Dataset构造batch数据集(具体用法在下一篇博客介绍)
import tensorflow as tf import numpy as np def _parse_function(x): num_list = np.arange(10) return n ...
- Python pandas.io.data 模块迁移
这段时间用pandas做数据分析, import pandas.io.data as web 然后得到下面的错误提示 "The pandas.io.data module is moved ...
- data模块
这个模块原本应该存放Excel文件,提供utils目录下的config模块调用: 这里公司内部无法使用Excel读取数据,顾使用了ddt,其实里面就是.xslx文件
随机推荐
- IDEA永久使用
IDEA永久使用 一.在https://www.cnblogs.com/zyx110/p/10799387.html中下载下面图片中箭头所指的部分 下载完成后双击打开,除了以下图片提示内容,一路下一步 ...
- JAVA接口的继承与集合
复习 20190701 接口补充 一. java是单继承多实现 单继承: 一个类只能有一个父类 public class D extends D1 { } 2. 多实现 一个类可以同时实现多个接口 当 ...
- NOI 2011 兔农 题解
事先声明,本博客代码主要模仿accepoc,且仅针对一般如本博主一样的蒟蒻. 这道题不得不说数据良心,给了75分的水分,但剩下25分真心很难得到,因此我们就来讲一讲这剩下的25分. 首先,有数据可知他 ...
- 基于C#的机器学习--微基准测试和激活功能
本章我们将学习以下内容: l 什么是微基准测试 l 如何将它应用到代码中 l 什么是激活函数 l 如何绘制和基准测试激活函数 每个开发人员都需要有一个好的基准测试工具.质量基准无处不在;你们每 ...
- Apache struts2 Freemarker标签远程命令执行_CVE-2017-12611(S2-053)漏洞复现
Apache struts2 Freemarker标签远程命令执行_CVE-2017-12611(S2-053)漏洞复现 一.漏洞描述 Struts2在使用Freemarker模块引擎的时候,同时允许 ...
- 【题解】埃及分数-C++
Description 在古埃及,人们使用单位分数的和(形如1/a的, a是自然数)表示一切有理数. 如:2/3=1/2+1/6,但不允许2/3=1/3+1/3,因为加数中有相同的. 对于一个分数a/ ...
- SQLite的一些体会
SQLite遵循sql语法,所以如果接触过数据库,使用它进行增删改查几乎没障碍.在.net中,它与Mysql.sql server的类也相似,比如连接类名字是SQLiteConnection,不过它S ...
- 图像识别sift+bow+svm
本文概述 利用SIFT特征进行简单的花朵识别 SIFT算法的特点有: SIFT特征是图像的局部特征,其对旋转.尺度缩放.亮度变化保持不变性,对视角变化.仿射变换.噪声也保持一定程度的稳定性: SIFT ...
- 新手小白快速登天日记之安装python以及环境变量和pycharm解释器较为详细教程
Python解释器安装及环境变量配置 Python官方网站:www.python.org 首先打开这个网址,找到downloads选择,因为我是Windows所以下载windows的,因电脑而异 然后 ...
- typedef int a[10];怎么解释?
typedef int a[10]; a b[10]; 为什么分配400个字节的空间? int a[10];为什么分配了40个字节的空间? 问题:应该怎么解释typedef的这种行为呢?而如果换成是# ...