Tensorflow学习-数据读取
Tensorflow数据读取方式主要包括以下三种
- Preloaded data:预加载数据
- Feeding: 通过Python代码读取或者产生数据,然后给后端
- Reading from file: 通过TensorFlow队列机制,从文件中直接读取数据
前两种方法比较基础而且容易理解,在Tensorflow入门教程、书本中经常可以见到,这里不再进行介绍。
在介绍Tensorflow第三种读取数据方法之前,介绍以下有关队列相关知识
Queue(队列)
队列是用来存放数据的,并且tensorflow中的Queue中已经实现了同步机制,所以我们可以放心的往里面添加数据还有读取数据。如果Queue中的数据满了,那么en_queue(队列添加元素)操作将会阻塞,如果Queue是空的,那么dequeue(队列抛出元素)操作就会阻塞.在常用环境中,一般是有多个en_queue线程同时像Queue中放数据,有一个dequeue操作从Queue中取数据。
Coordinator(协调管理器)
Coordinator主要是用来帮助管理多个线程,协调多线程之间的配合
# Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
#将coord传入到线程中,来帮助它们同时停止工作
def MyLoop(coord):
while not coord.should_stop():
...do something...
if ...some condition...:
coord.request_stop()
# Main thread: create a coordinator.
coord = tf.train.Coordinator()
# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]
# Start the threads and wait for all of them to stop.
for t in threads:
t.start()
coord.join(threads)
QueueRunner()
QueueRunner可以创建多个线程对队列(queue)进行插入(enqueue)操作,它是一个op,这些线程可以通过上述的Coordinator协调器来协调工作。
在深度学习中样本数据集有多种存储编码形式,以经典数据集Cifar-10为例,公开共下载的数据有三种存储方式:Bin(二进制)、Python以及Matlab版本。此外,我们常用的还有csv(天池竞赛、百度竞赛等)比较常见或txt等,当然对图片存储最为直观的还是可视化展示的TIF、PNG、JPG等。Tensorflow官方推荐使用他自己的一种文件格式叫TFRecord,具体实现及应用会在以后详细介绍。
从上图中可知,Tensorflow数据读取过程主要包括两个队列(FIFO),一个叫做文件队列,主要用作对输入样本文件的管理(可以想象,所有的训练数据一般不会存储在一个文件内,该部分主要完成对数据文件的管理);另一个叫做数据队列,如果对应的数据是图像可以认为该队列中的每一项都是存储在内存中的解码后的一系列图像像素值。
下面,我们分别新建3个csv文件->A.csv;B.csv;C.csv,每个文件下分别用X_i, y_i代表训练样本的数据及标注信息。


#-*- coding:gbk -*-
import tensorflow as tf
# 队列1:生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=2)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])
with tf.Session() as sess:
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(12):
e_val, l_val = sess.run([example, label])
print(e_val, l_val)
coord.request_stop()
coord.join(threads)
程序中,首先根据文件列表,通过tf.train.string_input_producer(filenames, shuffle=False)函数建立了一个对应的文件管理队列,其中shuffle=False表 示不对文件顺序进行打乱(True表示打乱,每次输出顺序将不再一致)。此外,还可通过设置第三个参数num_epochs来控制文件数据多少。
运行结果如下:

上段程序中,主要完成以下几方面工作:
- 针对文件名列表,建立对应的文件队列
- 使用reader读取对应文件数据集
- 解码数据集,得到样本example和标注label
感兴趣的读者可以打开tf.train.string_input_producer(...)函数,可以看到如下代码
"""
@compatibility(eager)
Input pipelines based on Queues are not supported when eager execution is
enabled. Please use the `tf.data` API to ingest data under eager execution.
@end_compatibility
"""
if context.in_eager_mode():
raise RuntimeError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
" instead.")
with ops.name_scope(name, "input_producer", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
element_shape = input_tensor.shape[1:].merge_with(element_shape)
if not element_shape.is_fully_defined():
raise ValueError("Either `input_tensor` must have a fully defined shape "
"or `element_shape` must be specified")
if shuffle:
input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
input_tensor = limit_epochs(input_tensor, num_epochs)
q = data_flow_ops.FIFOQueue(capacity=capacity,
dtypes=[input_tensor.dtype.base_dtype],
shapes=[element_shape],
shared_name=shared_name, name=name)
enq = q.enqueue_many([input_tensor])
queue_runner.add_queue_runner(
queue_runner.QueueRunner(
q, [enq], cancel_op=cancel_op))
if summary_name is not None:
summary.scalar(summary_name,
math_ops.to_float(q.size()) * (1. / capacity))
return q
可以看到该段代码主要完成以下工作:
- 创建队列Queue
- 创建线程enqueue_many
- 添加QueueRunner到collection中
- 返回队列Queue
数据解析
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])
这里,我们通过定义一个reader来读取每个数据文件内容,也可图中也展示了TensorFlow支持定义多个reader并且读取文件队列中文件内容,从而提供数据读取效率。然后,采用一个decoder_csv函数对读取的原始CSV文件内容进行解码,平时我们也可根据自己数据存储格式选择不同数据解码方式。在这里需要指出的是,上述程序中并没有用到图中展示的第二个数据队列,这是为什么呢。
实际上做深度学习or机器学习训练过程中,为了保证训练过程的高效性通常不采用单个样本数据给训练模型,而是采用一组N个数据(称作mini-batch),并把每组样本个数N成为batch-size。现在假设我们每组需要喂给模型N个数据,需通过N次循环读入内存,然后再通过GPU进行前向or返向传播运算,这就意味着GPU每次运算都需要一段时间等待CPU读取数据,从而大大降低了训练效率。而第二个队列(数据队列)就是为了解决这个问题提出来的,代码实现即为:tf.train.batch()和 tf.train.shuffle_batch,这两个函数的主要区别在于是否需要将列表中数据进行随机打乱。
#-*- coding:gbk -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])
#example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=16, capacity=200, min_after_dequeue=100, num_threads=2)
example_batch, label_batch = tf.train.batch([example,label], batch_size=8, capacity=200, num_threads=2)
#example_list = [tf.decode_csv(value, record_defaults=[['string'], ['string']])
# for _ in range(2)] # Reader设置为2
### 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
#example_batch, label_batch = tf.train.batch_join(
# example_list, batch_size=5)
init_local_op = tf.initialize_local_variables()
with tf.Session() as sess:
sess.run(init_local_op)
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(5):
# Retrieve a single instance:
e_val, l_val = sess.run([example_batch, label_batch])
print(e_val, l_val)
coord.request_stop()
coord.join(threads)
使用tf.train.batch()函数,每次根据自己定义大小会返回一组训练数据,从而避免了往内存中循环读取数据的麻烦,提高了效率。并且还可以通过设置reader个数,实现多线程高效地往数据队列(或叫内存队列)中填充数据,直到文件队列读完所有文件(或文件数据不足一个batch size)。
tf.train.batch()程序运行结果如下

注:tf.train.batch([example,label], batch_size=8, capacity=200, num_threads=2)参数中,capacity表示队列大小,每次读出数据后队尾会按顺序依次补充。num_treads=2表示两个线程(据说在一个reader下可达到最优),batch_size=8表示每次返回8组训练数据,即batch size大小。tf.train.shuffle_batch()比tf.train.bathc()多一个min_after_dequeue参数,意思是在每次抛出一个batch后,剩余数据样本不少于多少个。
Tensorflow学习-数据读取的更多相关文章
- AI学习---数据读取&神经网络
AI学习---数据读取&神经网络 fa
- Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例
紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...
- tensorflow之数据读取探究(2)
tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...
- tensorflow之数据读取探究(1)
Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...
- 关于Tensorflow 的数据读取环节
Tensorflow读取数据的一般方式有下面3种: preloaded直接创建变量:在tensorflow定义图的过程中,创建常量或变量来存储数据 feed:在运行程序时,通过feed_dict传入数 ...
- 机器学习: TensorFlow 的数据读取与TFRecords 格式
最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式: 1) 传送 (feeding): Pytho ...
- tensorflow学习--数据加载
文章主要来自Tensorflow官方文档,同时加入了自己的理解以及部分代码 数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每 ...
- 『TensorFlow』数据读取类_data.Dataset
一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...
- TensorFlow的数据读取机制
一.tensorflow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取的过程可以用下图来表示 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003 ...
随机推荐
- VS Code python初体验笔记
之前一直都是使用Notepad++来编写Python代码,后来想起来之前查资料的时候有个VS Code可以编写一些的脚本语言(js,node.js)甚至是高级编程语言(C#,PHP,JAVA,Pyth ...
- Java小技巧输出26个英文字母
相信有的童鞋写到过与字母有关的小东西,是否有写过全部的字母呢?26个这么多字母,一个个打会疯掉.所有咱们可以用一个小技巧使用for循环帮我们把26个字母自动搞出来,大家来瞅一眼把! 使用Java遍历2 ...
- eclipse下载指南
官网下载地址 下载https://www.eclipse.org/downloads/ 官网https://www.eclipse.org/ 最新版本 Eclipse OXYGEN Eclipse O ...
- kafka 集群
在kafka中,我们能创建多种类型的集群,一般如下: 单节点——单个broker集群 单节点——多broker集群 多结点——多broker集群 kafka集群中主要有五个组件: Topic:主题主要 ...
- python编码详解--转自(Alex的博客)
原文地址:http://www.cnblogs.com/alex3714/articles/7550940.html 编码回顾 在备编码相关的课件时,在知乎上看到一段关于Python编码的回答 这哥们 ...
- [译文]Domain Driven Design Reference(二)—— 让模型起作用
本书是Eric Evans对他自己写的<领域驱动设计-软件核心复杂性应对之道>的一本字典式的参考书,可用于快速查找<领域驱动设计>中的诸多概念及其简明解释. 其它本系列其它文章 ...
- JAVAEE——BOS物流项目13:Quartz入门案例、核心概念、cron 表达式的格式(了解)
1.quartz入门案例 本入门案例基于spring和quartz整合完成. 第一步:创建maven工程,导入spring和quartz相关依赖 第二步:创建任务类 第三步:在spring配置文件中配 ...
- Django模型层之字段查询参数及聚合函数
该系列教程系个人原创,并完整发布在个人官网刘江的博客和教程 所有转载本文者,需在顶部显著位置注明原作者及www.liujiangblog.com官网地址. 字段查询是指如何指定SQL WHERE子句的 ...
- iframe 背景透明设置--兼容IE
iframe标签添加: allowTransparency="true"属性. 子文件body背景设置透明: background-color: transparent;.
- java面试总结
一.java的集合框架 HashMap.HashTable.CurrentHashMap的底层数据结构与区别? CurrentHashMap与HashTable是如何保证线程安全的? ArrayLis ...