TFRecord 使用
tfrecord生成
import os
import xmltodict
import tensorflow as tf
import numpy as np
dir_path = 'F:\数据存储\VOCdevkit\VOC2012\Annotations'
dirs = os.listdir(dir_path)
imgs_dir = "F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = 'F:\数据存储\VOCdevkit\\voc2012.tfrecord'
classes = [
"background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
sess = tf.Session()
def get_and_resize_img(img_file):
'''
将图片设置为224*224的尺寸大小
返回图片,返回变化倍数,shape
'''
img = tf.read_file(imgs_dir + '/' + img_file)
img = tf.image.decode_jpeg(img)
shape_old = sess.run(img).shape
resized_img = tf.image.resize_images(img, [224, 224], method=0)
resized_img = sess.run(resized_img)
resized_img = np.asarray(resized_img, dtype='uint8')
resized_img_str = resized_img.tostring()
shape_new = resized_img.shape
# print(shape_new)
# print(shape_old)
# print('shape_old的长是width是维度1,height是维度0')
w_scale = shape_new[0] / shape_old[1]
h_scale = shape_new[1] / shape_old[0]
return resized_img_str, w_scale, h_scale, shape_new
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
for file in dirs:
i = i + 1
# if i > 1000:
# break
with open(dir_path + '/' + file) as xml_txt:
doc = xmltodict.parse(xml_txt.read())
img_file_name = file.split('.')[0]
resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')
img_obtain_classes = []
y_mins = []
x_mins = []
y_maxes = []
x_maxes = []
if type(doc['annotation']["object"]).__name__ == 'OrderedDict':
if doc['annotation']["object"]['name'] in classes:
img_obtain_classes.append(classes.index(doc['annotation']["object"]['name']))
y_mins.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymin'])))
x_mins.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmin'])))
y_maxes.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymax'])))
x_maxes.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmax'])))
else:
for one_object in doc['annotation']["object"]:
# ['annotation']["object"][0]["name"]
if one_object['name'] in classes:
img_obtain_classes.append(classes.index(one_object['name']))
y_mins.append(float(h_scale * int(one_object['bndbox']['ymin'])))
x_mins.append(float(w_scale * int(one_object['bndbox']['xmin'])))
y_maxes.append(float(h_scale * int(one_object['bndbox']['ymax'])))
x_maxes.append(float(w_scale * int(one_object['bndbox']['xmax'])))
# example = tf.train.Example(features=tf.train.Features(feature={
# 'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
# }
# ))
img_file_name = bytes(img_file_name, encoding='utf8')
example = tf.train.Example(features=tf.train.Features(feature={
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
}))
writer.write(example.SerializeToString())
writer.close()
sess.close()
print('ok')
tfrecord读取
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# import sys
#
# sys.path.append("..")
classes = [
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
# 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
# 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
# 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
# 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
# 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
# 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
def _parse_record(example_proto):
features = {
'filename': tf.FixedLenFeature([], tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'classes': tf.VarLenFeature(tf.int64),
'y_mins': tf.VarLenFeature(tf.float32),
'x_mins': tf.VarLenFeature(tf.float32),
'y_maxes': tf.VarLenFeature(tf.float32),
'x_maxes': tf.VarLenFeature(tf.float32),
'encoded': tf.FixedLenFeature((), tf.string)
}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def read_test(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_initializable_iterator()
max_value = tf.placeholder(tf.int64, shape=[])
with tf.Session() as sess:
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(2):
features = sess.run(iterator.get_next())
name = features['filename']
name = name.decode()
shape = features['shape']
classes = features['classes']
y_mins = features['y_mins']
x_mins = features['x_mins']
y_maxes = features['y_maxes']
x_maxes = features['x_maxes']
# name = name.decode()
img_data = features['encoded']
print(len(img_data))
print('=======')
print("shape", shape)
print("name", name)
print("classes", classes.values)
print("y_mins", y_mins.values)
print("x_mins", x_mins.values)
print("y_maxes", y_maxes.values)
print("x_maxes", x_maxes.values)
img_data = np.fromstring(img_data, dtype=np.uint8)
image_data = np.reshape(img_data, shape)
print("img_data", image_data)
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
# img_data = np.fromstring(img_data, dtype=np.uint8)
# image_data = np.reshape(img_data, shape)
#
# plt.figure()
# # 显示图片
plt.imshow(image_data)
plt.show()
read_test('F:\数据存储\VOCdevkit\\voc2012.tfrecord')
尺寸不固定矩阵的存储和读取
import json
import jieba
import tensorflow as tf
with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
dic = json.loads(file.read())
all_words_word2id = dic["all_words_word2id"]
stop_words = []
with open('./stop_words.txt', encoding='utf-8') as f:
line = f.readline()
while line:
stop_words.append(line[:-1])
line = f.readline()
stop_words = set(stop_words)
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))
dir_path = 'F:\\数据存储\新闻语料\\news2016zh_train.json'
dir_path_test = 'F:\\数据存储\新闻语料\\news2016zh_valid.json'
out_path = 'F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord'
def getCutSequnce(line):
# 使用jieba 进行中文分词
raw_words = list(jieba.cut(line, cut_all=False))
# 存储一句话的分词结果
raw_word_list = []
# 去除停用词
for word in raw_words:
if word not in stop_words and word not in ['www', 'com', 'http']:
raw_word_list.append(word)
return raw_word_list
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
with open(dir_path, encoding='utf-8') as txt:
one_dic = txt.readline()
while one_dic:
i = i + 1
if i > 10000:
break
if (i % 1000) == 0:
print(i)
one_dic_json = json.loads(one_dic)
title = one_dic_json['title']
content = one_dic_json['content']
if len(content) > 3000:
one_dic = txt.readline()
continue
one_dic = txt.readline()
if len(title) == 0 or len(content) == 0:
continue
title_list = getCutSequnce(title)
content_list = getCutSequnce(content)
title_list_index = []
for one in title_list:
try:
title_list_index.append(all_words_word2id[one])
except:
pass
content_list_index = []
for one_word in content_list:
try:
content_list_index.append(all_words_word2id[one_word])
except:
pass
example = tf.train.Example(features=tf.train.Features(feature={
'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title_list_index)),
'content': tf.train.Feature(int64_list=tf.train.Int64List(value=content_list_index))
}))
writer.write(example.SerializeToString())
import tensorflow as tf
import numpy as np
def _parse_record(example_proto):
features = {
'title': tf.VarLenFeature(tf.int64),
'content': tf.VarLenFeature(dtype=tf.int64)
}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def read_test(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
sess.run(iterator.initializer)
for i in range(5):
features = sess.run(iterator.get_next())
name = features['title']
content = features['content']
print("xx", content)
print("xx", np.array(content).shape)
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
read_test('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
统计数据条数
import tensorflow as tf
def total_sample(file_name):
sample_nums = 0
for record in tf.python_io.tf_record_iterator(file_name):
sample_nums += 1
return sample_nums
result = total_sample('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
print(result)
TFRecord 使用的更多相关文章
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...
- 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读
1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...
- 深度学习原理与框架-Tfrecord数据集的制作 1.tf.train.Examples(数据转换为二进制) 3.tf.image.encode_jpeg(解码图片加码成jpeg) 4.tf.train.Coordinator(构建多线程通道) 5.threading.Thread(建立单线程) 6.tf.python_io.TFR(TFR读入器)
1. 配套使用: tf.train.Examples将数据转换为二进制,提升IO效率和方便管理 对于int类型 : tf.train.Examples(features=tf.train.Featur ...
- 3. Tensorflow生成TFRecord
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- TFRecord文件的读写
前言在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而 ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- tfrecord
制作自己的TFRecord数据集,读取,显示及代码详解 http://blog.csdn.net/miaomiaoyuan/article/details/56865361
- 3 TFRecord样例程序实战
将图片数据写入Record文件 # 定义函数转化变量类型. def _int64_feature(value): return tf.train.Feature(int64_list=tf.train ...
随机推荐
- 「LibreOJ NOI Round #2」不等关系
「LibreOJ NOI Round #2」不等关系 解题思路 令 \(F(k)\) 为恰好有 \(k\) 个大于号不满足的答案,\(G(k)\) 表示钦点了 \(k\) 个大于号不满足,剩下随便填的 ...
- python3.7 64位中安装pygame1.9.3
1.我是用pip命令来安装的,首先,打开cmd,输入pip,检查电脑中有没有安装这个插件(一般python2.7以上自带pip工具) 2.更新pip工具的命令:python -m pip instal ...
- 【阿里云开发】- 安装MySQL数据库
我用的机器配置是 阿里云轻量服务器,系统:CentOS7.3,内存:2G,系统盘40G,1核. 在CentOS中默认安装有MariaDB,这个是MySQL的分支,但为了需要,还是要在系统中安装MySQ ...
- canvas教程(三) 绘制曲线
经过 canvas 教程(二) 绘制直线 我们知道了 canvas 的直线是怎么绘制的 而本次是给大家带来曲线相关的绘制 绘制圆形 在 canvas 中我们可以使用 arc 方法画一个圆 contex ...
- UCOSIII事件标志组
两种同步机制 "或"同步 "与"同步 使能 #define OS_CFG_FLAG_EN 1u /* Enable (1) or Disable (0) cod ...
- Winform开发1
VS的Winform开发中,TextBox可能拖过来的时候不能改变其高度,这就要在其属性Multiline为True.
- 聊Java中的任务调度的实现方法及比较
前言 任务调度是指基于给定时间点,给定时间间隔或者给定执行次数自动执行任务.本文由浅入深介绍四种任务调度的 Java 实现: Timer ScheduledExecutor 开源工具包 Quartz ...
- 【大数据技术能力提升_2】numpy学习
numpy学习 标签(空格分隔): numpy python 数据类型 5种类型:布尔值(bool),整数(int),无符号整数(uint).浮点(float).复数(complex) 支持的原始类型 ...
- Struts框架笔记04_拦截器_标签库
目录 1. Struts2的拦截器 1.1 拦截器概述 1.2 拦截器的实现原理 1.3 Struts的执行流程 1.4 拦截器入门 1.4.1 环境搭建 1.4.2 编写拦截器 1.4.3 配置拦截 ...
- Linux下环境变量设置 (转)
Linux下环境变量设置 1.在Windows 系统下,很多软件安装都需要配置环境变量,比如 安装 jdk ,如果不配置环境变量,在非软件安装的目录下运行javac 命令,将会报告找不到文件,类似的错 ...