Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取
内容概要:
单一数据读取方式:
第一种:slice_input_producer()
# 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...][images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
第二种:string_input_producer()
# 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True) reader = tf.WholeFileReader() # 定义文件读取器key, value = reader.read(file_queue) # key:文件名;value:文件中的内容
!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。
!!!如果它不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量
!!!以上两种方法都可以生成文件名队列。
(随机)批量数据读取方式:
batchsize=2 # 每次读取的样本数量tf.train.batch(tensors, batch_size=batchsize)tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue
!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()
TFRecord文件的打包与读取
一、单一数据读取方式
第一种:slice_input_producer()
def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)
案例1:
import tensorflow as tf images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4] # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True) # 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
# [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True) data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
print(type(data)) # <class 'list'> with tf.Session() as sess:
# sess.run(tf.local_variables_initializer())
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator() # 线程的协调器
threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器 for i in range(10):
print(sess.run(data)) coord.request_stop()
coord.join(threads) """
运行结果:
[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image1.jpg', 1]
[b'image3.jpg', 3]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image3.jpg', 3]
"""
!!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],
!!!num_epochs设置
第二种:string_input_producer()
def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)
文件读取器
不同类型的文件对应不同的文件读取器,我们称为 reader对象;
该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;
reader = tf.TextLineReader() ### 一行一行读取,适用于所有文本文件 reader = tf.TFRecordReader() ### A Reader that outputs the records from a TFRecords file
reader = tf.WholeFileReader() ### 一次读取整个文件,适用图片
案例2:读取csv文件
iimport tensorflow as tf filename = ['data/A.csv', 'data/B.csv', 'data/C.csv'] file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2) # 生成文件名队列
reader = tf.WholeFileReader() # 定义文件读取器(一次读取整个文件)
# reader = tf.TextLineReader() # 定义文件读取器(一行一行的读)
key, value = reader.read(file_queue) # key:文件名;value:文件中的内容
print(type(file_queue)) init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
for i in range(6):
print(sess.run([key, value]))
break
except tf.errors.OutOfRangeError:
print('read done')
finally:
coord.request_stop()
coord.join(threads) """
reader = tf.WholeFileReader() # 定义文件读取器(一次读取整个文件)
运行结果:
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/A.csv', b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/B.csv', b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/C.csv', b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
"""
"""
reader = tf.TextLineReader() # 定义文件读取器(一行一行的读)
运行结果:
[b'data/B.csv:1', b'4.jpg,4']
[b'data/B.csv:2', b'5.jpg,5']
[b'data/B.csv:3', b'6.jpg,6']
[b'data/C.csv:1', b'7.jpg,7']
[b'data/C.csv:2', b'8.jpg,8']
[b'data/C.csv:3', b'9.jpg,9']
"""
案例3:读取图片(每次读取全部图片内容,不是一行一行)
import tensorflow as tf filename = ['1.jpg', '2.jpg']
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
reader = tf.WholeFileReader() # 文件读取器
key, value = reader.read(filename_queue) # 读取文件 key:文件名;value:图片数据,bytes with tf.Session() as sess:
tf.local_variables_initializer().run()
coord = tf.train.Coordinator() # 线程的协调器
threads = tf.train.start_queue_runners(sess, coord) for i in range(filename.__len__()):
image_data = sess.run(value)
with open('img_%d.jpg' % i, 'wb') as f:
f.write(image_data)
coord.request_stop()
coord.join(threads)
二、(随机)批量数据读取方式:
功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;
案例4:slice_input_producer() 与 batch()
import tensorflow as tf
import numpy as np images = np.arange(20).reshape([10, 2])
label = np.asarray(range(0, 10))
images = tf.cast(images, tf.float32) # 可以注释掉,不影响运行结果
label = tf.cast(label, tf.int32) # 可以注释掉,不影响运行结果 batchsize = 6 # 每次获取元素的数量
input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize) # 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
# image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10) with tf.Session() as sess:
coord = tf.train.Coordinator() # 线程的协调器
threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器
for cnt in range(2):
print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize))
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
print(image_batch_v, label_batch_v, label_batch_v.__len__()) coord.request_stop()
coord.join(threads) """
运行结果:
第1次获取数据,每次batch=6...
[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]] [0 1 2 3 4 5] 6
第2次获取数据,每次batch=6...
[[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]
[ 0. 1.]
[ 2. 3.]] [6 7 8 9 0 1] 6
"""
案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

import tensorflow as tf
import glob
import cv2 as cv def read_imgs(filename, picture_format, input_image_shape, batch_size=1):
"""
从本地批量的读取图片
:param filename: 图片路径(包括图片的文件名),[]
:param picture_format: 图片的格式,如 bmp,jpg,png等; string
:param input_image_shape: 输入图像的大小; (h,w,c)或[]
:param batch_size: 每次从文件队列中加载图片的数量; int
:return: batch_size张图片数据, Tensor
"""
global new_img
# 创建文件队列
file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True)
# 创建文件读取器
reader = tf.WholeFileReader()
# 读取文件队列中的文件
_, img_bytes = reader.read(file_queue)
# print(img_bytes) # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
# 对图片进行解码
if picture_format == ".bmp":
new_img = tf.image.decode_bmp(img_bytes, channels=1)
elif picture_format == ".jpg":
new_img = tf.image.decode_jpeg(img_bytes, channels=3)
else:
pass
# 重新设置图片的大小
# new_img = tf.image.resize_images(new_img, input_image_shape)
new_img = tf.reshape(new_img, input_image_shape)
# 设置图片的数据类型
new_img = tf.image.convert_image_dtype(new_img, tf.uint8) # return new_img
return tf.train.batch([new_img], batch_size) def main():
image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
print(type(image_batch))
# image_path = glob.glob(r'.\*.jpg')
# image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1) sess = tf.Session()
sess.run(tf.local_variables_initializer())
tf.train.start_queue_runners(sess=sess) image_batch = sess.run(image_batch)
print(type(image_batch)) # <class 'numpy.ndarray'> for i in range(image_batch.__len__()):
cv.imshow("win_"+str(i), image_batch[i])
cv.waitKey()
cv.destroyAllWindows() def start():
image_path = glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
print(type(image_batch)) # <class 'tensorflow.python.framework.ops.Tensor'> with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator() # 线程的协调器
threads = tf.train.start_queue_runners(sess, coord) # 开始在图表中收集队列运行器
image_batch = sess.run(image_batch)
print(type(image_batch)) # <class 'numpy.ndarray'> for i in range(image_batch.__len__()):
cv.imshow("win_"+str(i), image_batch[i])
cv.waitKey()
cv.destroyAllWindows() # 若使用 with 方式打开 Session,且没加如下2行语句,则会出错
# ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
# 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
coord.request_stop()
coord.join(threads) if __name__ == "__main__":
# main()
start()
从本地批量的读取图片案例
案列6:TFRecord文件打包与读取

def write_TFRecord(filename, data, labels, is_shuffler=True):
"""
将数据打包成TFRecord格式
:param filename: 打包后路径名,默认在工程目录下创建该文件;String
:param data: 需要打包的文件路径名;list
:param labels: 对应文件的标签;list
:param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
:return: None
"""
im_data = list(data)
im_labels = list(labels) index = [i for i in range(im_data.__len__())]
if is_shuffler:
np.random.shuffle(index) # 创建写入器,然后使用该对象写入样本example
writer = tf.python_io.TFRecordWriter(filename)
for i in range(im_data.__len__()):
im_d = im_data[index[i]] # im_d:存放着第index[i]张图片的路径信息
im_l = im_labels[index[i]] # im_l:存放着对应图片的标签信息 # # 获取当前的图片数据 方式一:
# data = cv2.imread(im_d)
# # 创建样本
# ex = tf.train.Example(
# features=tf.train.Features(
# feature={
# "image": tf.train.Feature(
# bytes_list=tf.train.BytesList(
# value=[data.tobytes()])), # 需要打包成bytes类型
# "label": tf.train.Feature(
# int64_list=tf.train.Int64List(
# value=[im_l])),
# }
# )
# )
# 获取当前的图片数据 方式二:相对于方式一,打包文件占用空间小了一半多
data = tf.gfile.FastGFile(im_d, "rb").read()
ex = tf.train.Example(
features=tf.train.Features(
feature={
"image": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[data])), # 此时的data已经是bytes类型
"label": tf.train.Feature(
int64_list=tf.train.Int64List(
value=[im_l])),
}
)
) # 写入将序列化之后的样本
writer.write(ex.SerializeToString())
# 关闭写入器
writer.close()
TFRecord文件打包案列

import tensorflow as tf
import cv2 def read_TFRecord(file_list, batch_size=10):
"""
读取TFRecord文件
:param file_list: 存放TFRecord的文件名,List
:param batch_size: 每次读取图片的数量
:return: 解析后图片及对应的标签
"""
file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True)
reader = tf.TFRecordReader()
_, ex = reader.read(file_queue)
batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5) feature = {
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
}
example = tf.parse_example(batch, features=feature) images = tf.decode_raw(example['image'], tf.uint8)
images = tf.reshape(images, [-1, 32, 32, 3]) return images, example['label'] def main():
# filelist = ['data/train.tfrecord']
filelist = ['data/test.tfrecord']
images, labels = read_TFRecord(filelist, 2)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord) try:
while not coord.should_stop():
for i in range(1):
image_bth, _ = sess.run([images, labels])
print(_) cv2.imshow("image_0", image_bth[0])
cv2.imshow("image_1", image_bth[1])
break
except tf.errors.OutOfRangeError:
print('read done')
finally:
coord.request_stop()
coord.join(threads)
cv2.waitKey(0)
cv2.destroyAllWindows() if __name__ == "__main__":
main()
TFReord文件的读取案列
Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取的更多相关文章
- SQL Server中批量替换数据
SQL Server数据库中批量替换数据的方法 SQL Server数据库操作中,我们可能会根据某写需要去批量替换数据,那么如何批量修改替换数据呢?本文我们就介绍这一部分内容,接下来就让我们一起来了解 ...
- C# 批量插入表SQLSERVER SqlBulkCopy往数据库中批量插入数据
#region 帮助实例:SQL 批量插入数据 多种方法 /// <summary> /// SqlBulkCopy往数据库中批量插入数据 /// </summary> /// ...
- 向mysql中批量插入数据的性能分析
MYSQL批量插入数据库实现语句性能分析 假定我们的表结构如下 代码如下 CREATE TABLE example (example_id INT NOT NULL,name VARCHAR( 5 ...
- .Net中批量添加数据的几种实现方法比较
在.Net中经常会遇到批量添加数据,如将Excel中的数据导入数据库,直接在DataGridView控件中添加数据再保存到数据库等等. 方法一:一条一条循环添加 通常我们的第一反应是采用for或for ...
- SharePoint自动化系列——通过PowerShell在SharePoint中批量做数据
转载请注明出自天外归云的博客园:http://www.cnblogs.com/LanTianYou/ PowerShell是基于.NET的一门脚本语言,对于SharePoint一些日常操作支持的很好. ...
- Hibernate 中批量处理数据
一.批量处理操作 批量处理数据是指在一个事务场景中处理大量数据.在应用程序中难以避免进行批量操作,Hibernate提供了以下方式进行批量处理数据: (1)使用HQL进行批量操作 数据库层面 ...
- C#读txt文件并写入二维数组中(txt数据行,列未知)
using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.I ...
- SQLSERVER数据库中批量导入数据的几种方法
第一:使用Select Into 语句 如果企业数据库都是采用SQL Server数据库的话,则可以利用select into语句实现数据的导入. select into语句的作用是把数据从另外一个数 ...
- MyBatis向数据库中批量插入数据
Foreach标签 foreach: collection:指定要遍历的集合; 表示传入过来的参数的数据类型.该参数为必选.要做 foreach 的对象,作为入参时,List 对象默认用 list 代 ...
随机推荐
- 又抓了一个导致频繁GC的鬼--数组动态扩容
概述 本周有个同事过来咨询一个比较诡异的gc问题,大概现象是,系统一直在做cms gc,但是老生代一直不降下去,但是执行一次jmap -histo:live之后,也就是主动触发一次full gc之后, ...
- Docker scratch 无法正常运行golang二进制程序的问题
使用Docker构建容器能够极大的降低运维成本,提高部署效率,同时非常方便对服务的平行扩展.然而在构建容器镜像过程中的,存在着一个难以避免的问题,就是如果使用常见的发行版本作为程序运行的基础环境,那么 ...
- C#中方法的静态和非静态
在代码中,给方法加上static就成为了一个静态的方法,然而静态方法是隶属于类的,由类名点出来! 不给方法加static就是一个非静态方法,非静态的方法,是隶属于对象的,需要把类实例化之后,用对象名去 ...
- spring-kafka之KafkaListener注解深入解读
简介 Kafka目前主要作为一个分布式的发布订阅式的消息系统使用,也是目前最流行的消息队列系统之一.因此,也越来越多的框架对kafka做了集成,比如本文将要说到的spring-kafka. Kafka ...
- String类练习
1.模拟一个trim方法,去除字符串两端的空格 2.将一个字符串进行反转.将字符串中指定部分进行反转 3.获取一个字符串在另一个字符串中出现的次数 4.获取两个字符串中最大相同子串 5.对字符串中字符 ...
- Black Hat Python之#2:TCP代理
在本科做毕设的时候就接触到TCP代理这东西,当时需要使用代理来对发送和收到的数据做修改,同时使用代理也让我对HTTP协议有了更深的了解. TCP Proxy用到的一个主要的东西就是socket.pro ...
- Spring Boot笔记(七) springboot 集成 JavaMail 实现邮箱认证
个人博客网:https://wushaopei.github.io/ (你想要这里多有) 一.JavaMail 1.什么是JavaMail? JavaMail,顾名思义,提供给开发者处理 电子邮 ...
- Java实现 LeetCode 556 下一个更大元素 III(数组的翻转)
556. 下一个更大元素 III 给定一个32位正整数 n,你需要找到最小的32位整数,其与 n 中存在的位数完全相同,并且其值大于n.如果不存在这样的32位整数,则返回-1. 示例 1: 输入: 1 ...
- Java实现蓝桥杯突击战
突击战 你有n个部下,每个部下需要完成一项任务.第i个部下需要你花Bi分钟交待任务,然后他会立刻独立地. 无间断地执行Ji分钟后完成任务.你需要选择交待任务的顺序, 使得所有任务尽早执行完毕(即最后一 ...
- Java实现 LeetCode 133 克隆图
133. 克隆图 给你无向 连通 图中一个节点的引用,请你返回该图的 深拷贝(克隆). 图中的每个节点都包含它的值 val(int) 和其邻居的列表(list[Node]). class Node { ...