TFRecord 的使用
什么是 TFRecord
PS:这段内容摘自 http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocolbuffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriterclass写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。
从TFRecords文件中读取数据,
可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocolbuffer)解析为张量。
MNIST的例子就使用了convert_to_records 所构建的数据。
请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py,
代码
adjust_pic.py
单纯的转换图片大小
- # -*- coding: utf-8 -*-
- import tensorflow as tf
- def resize(img_data, width, high, method=0):
- return tf.image.resize_images(img_data,[width, high], method)
pic2tfrecords.py
将图片保存成TFRecord
- # -*- coding: utf-8 -*-
- # 将图片保存成 TFRecord
- import os.path
- import matplotlib.image as mpimg
- import tensorflow as tf
- import adjust_pic as ap
- from PIL import Image
- SAVE_PATH = 'data/dataset.tfrecords'
- def _int64_feature(value):
- return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
- def _bytes_feature(value):
- return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
- def load_data(datafile, width, high, method=0, save=False):
- train_list = open(datafile,'r')
- # 准备一个 writer 用来写 TFRecord 文件
- writer = tf.python_io.TFRecordWriter(SAVE_PATH)
- with tf.Session() as sess:
- for line in train_list:
- # 获得图片的路径和类型
- tmp = line.strip().split(' ')
- img_path = tmp[0]
- label = int(tmp[1])
- # 读取图片
- image = tf.gfile.FastGFile(img_path, 'r').read()
- # 解码图片(如果是 png 格式就使用 decode_png)
- image = tf.image.decode_jpeg(image)
- # 转换数据类型
- # 因为为了将图片数据能够保存到 TFRecord 结构体中,所以需要将其图片矩阵转换成 string,所以为了在使用时能够转换回来,这里确定下数据格式为 tf.float32
- image = tf.image.convert_image_dtype(image, dtype=tf.float32)
- # 既然都将图片保存成 TFRecord 了,那就先把图片转换成希望的大小吧
- image = ap.resize(image, width, high)
- # 执行 op: image
- image = sess.run(image)
- # 将其图片矩阵转换成 string
- image_raw = image.tostring()
- # 将数据整理成 TFRecord 需要的数据结构
- example = tf.train.Example(features=tf.train.Features(feature={
- 'image_raw': _bytes_feature(image_raw),
- 'label': _int64_feature(label),
- }))
- # 写 TFRecord
- writer.write(example.SerializeToString())
- writer.close()
- load_data('train_list.txt_bak', 224, 224)
tfrecords2data.py
从TFRecord中读取并保存成图片
- # -*- coding: utf-8 -*-
- # 从 TFRecord 中读取并保存图片
- import tensorflow as tf
- import numpy as np
- SAVE_PATH = 'data/dataset.tfrecords'
- def load_data(width, high):
- reader = tf.TFRecordReader()
- filename_queue = tf.train.string_input_producer([SAVE_PATH])
- # 从 TFRecord 读取内容并保存到 serialized_example 中
- _, serialized_example = reader.read(filename_queue)
- # 读取 serialized_example 的格式
- features = tf.parse_single_example(
- serialized_example,
- features={
- 'image_raw': tf.FixedLenFeature([], tf.string),
- 'label': tf.FixedLenFeature([], tf.int64),
- })
- # 解析从 serialized_example 读取到的内容
- images = tf.decode_raw(features['image_raw'], tf.uint8)
- labels = tf.cast(features['label'], tf.int64)
- with tf.Session() as sess:
- # 启动多线程
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- # 因为我这里只有 2 张图片,所以下面循环 2 次
- for i in range(2):
- # 获取一张图片和其对应的类型
- label, image = sess.run([labels, images])
- # 这里特别说明下:
- # 因为要想把图片保存成 TFRecord,那就必须先将图片矩阵转换成 string,即:
- # pic2tfrecords.py 中 image_raw = image.tostring() 这行
- # 所以这里需要执行下面这行将 string 转换回来,否则会无法 reshape 成图片矩阵,请看下面的小例子:
- # a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩阵
- # b = a.tostring()
- # # 下面这行的输出是 32,即: 2*2 之后还要再乘 8
- # # 如果 tostring 之后的长度是 2*2=4 的话,那可以将 b 直接 reshape([2, 2]),但现在的长度是 2*2*8 = 32,所以无法直接 reshape
- # # 同理如果你的图片是 500*500*3 的话,那 tostring() 之后的长度是 500*500*3 后再乘上一个数
- # print len(b)
- #
- # 但在网上有很多提供的代码里都没有下面这一行,你们那真的能 reshape ?
- image = np.fromstring(image, dtype=np.float32)
- # reshape 成图片矩阵
- image = tf.reshape(image, [224, 224, 3])
- # 因为要保存图片,所以将其转换成 uint8
- image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
- # 按照 jpeg 格式编码
- image = tf.image.encode_jpeg(image)
- # 保存图片
- with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:
- f.write(sess.run(image))
- load_data(224, 224)
train_list.txt_bak 中的内容如下:
image_1093.jpg 13
image_0805.jpg 10
TFRecord 的使用的更多相关文章
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...
- 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读
1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...
- 深度学习原理与框架-Tfrecord数据集的制作 1.tf.train.Examples(数据转换为二进制) 3.tf.image.encode_jpeg(解码图片加码成jpeg) 4.tf.train.Coordinator(构建多线程通道) 5.threading.Thread(建立单线程) 6.tf.python_io.TFR(TFR读入器)
1. 配套使用: tf.train.Examples将数据转换为二进制,提升IO效率和方便管理 对于int类型 : tf.train.Examples(features=tf.train.Featur ...
- 3. Tensorflow生成TFRecord
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- TFRecord文件的读写
前言在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而 ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- tfrecord
制作自己的TFRecord数据集,读取,显示及代码详解 http://blog.csdn.net/miaomiaoyuan/article/details/56865361
- 3 TFRecord样例程序实战
将图片数据写入Record文件 # 定义函数转化变量类型. def _int64_feature(value): return tf.train.Feature(int64_list=tf.train ...
随机推荐
- webservice 的简单实现
1.什么是webservice: 服务端整出一些资源让客户端访问(获取数据) 一个跨语言.跨平台的规范2.作用:跨平台调用.跨语言调用.远程调用 3.什么时候使用webservice: 1.新旧系统之 ...
- Day2-T3
原题目 Describe:质数问题 code: #pragma GCC optimize(2) #include<bits/stdc++.h> #define KKK 1200 using ...
- git使用代理
在使用git科隆一个repo的时候,因为这个repo的子模块是托管在google上的,还是因为gfw导致子模块科隆不下来 只好使用代理了,那么怎么配置git使用代理呢 代码如下 因为我用的是ss所以这 ...
- quartz详解3:quartz数据库集群-锁机制
http://blog.itpub.NET/11627468/viewspace-1764753/ 一.quartz数据库锁 其中,QRTZ_LOCKS就是Quartz集群实现同步机制的行锁表,其表结 ...
- mysql四种事务隔离级别
mysql事务并发问题 ACID什么的就不啰嗦了.mysql多个事务并发的时候,可能会出现如下问题: 1. 更新丢失 即两个事务同时更新某一条数据,后执行的更新操作会覆盖先执行的更新操作,导致先执行的 ...
- Aspectj切入点语法定义
例如定义切入点表达式 execution (* com.sample.service.impl..*.*(..)) execution()是最常用的切点函数,其语法如下所示: 整个表达式可以分为五个 ...
- 转载电子发烧友网---STM32的IO口灌入电流和输出驱动电流
刚开始学习一款单片机的时候一般都是从操作IO口开始的,所以我也一样,先是弄个流水灯. 刚开始我对STM32的认识不够,以为是跟51单片机类似,可以直接操作端口,可是LED灯却没反应,于是乎,仔细查看资 ...
- Kettle无法下载以及点击无反应的问题
最开始用于解决MySQL转移数据到ORACLE的问题,尝试了几种方法. 1.直接从Mysql导出csv文件.这种方式最直接简单,但是问题是数据大的话,容易出现数据对不齐的情况,导入这个时候就会出现错误 ...
- [NOI2019]弹跳(KD-Tree)
被jump送退役了,很生气. 不过切了这题也进不了队,行吧. 退役后写了一下,看到二维平面应该就是KD树,然后可以在KD树上做最短路,然后建立堆和KDTree.然后每次更新则是直接把最短路上的节点删掉 ...
- centos通过yum安装php
1.添加php的yum软件仓库 sudo rpm -Uvh https://mirror.webtatic.com/yum/el6/latest.rpm 2.安装php相关软件,执行过程中全部选择ye ...