TFRecord 使用
tfrecord生成
import os
import xmltodict
import tensorflow as tf
import numpy as np
dir_path = 'F:\数据存储\VOCdevkit\VOC2012\Annotations'
dirs = os.listdir(dir_path)
imgs_dir = "F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = 'F:\数据存储\VOCdevkit\\voc2012.tfrecord'
classes = [
"background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
sess = tf.Session()
def get_and_resize_img(img_file):
'''
将图片设置为224*224的尺寸大小
返回图片,返回变化倍数,shape
'''
img = tf.read_file(imgs_dir + '/' + img_file)
img = tf.image.decode_jpeg(img)
shape_old = sess.run(img).shape
resized_img = tf.image.resize_images(img, [224, 224], method=0)
resized_img = sess.run(resized_img)
resized_img = np.asarray(resized_img, dtype='uint8')
resized_img_str = resized_img.tostring()
shape_new = resized_img.shape
# print(shape_new)
# print(shape_old)
# print('shape_old的长是width是维度1,height是维度0')
w_scale = shape_new[0] / shape_old[1]
h_scale = shape_new[1] / shape_old[0]
return resized_img_str, w_scale, h_scale, shape_new
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
for file in dirs:
i = i + 1
# if i > 1000:
# break
with open(dir_path + '/' + file) as xml_txt:
doc = xmltodict.parse(xml_txt.read())
img_file_name = file.split('.')[0]
resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')
img_obtain_classes = []
y_mins = []
x_mins = []
y_maxes = []
x_maxes = []
if type(doc['annotation']["object"]).__name__ == 'OrderedDict':
if doc['annotation']["object"]['name'] in classes:
img_obtain_classes.append(classes.index(doc['annotation']["object"]['name']))
y_mins.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymin'])))
x_mins.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmin'])))
y_maxes.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymax'])))
x_maxes.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmax'])))
else:
for one_object in doc['annotation']["object"]:
# ['annotation']["object"][0]["name"]
if one_object['name'] in classes:
img_obtain_classes.append(classes.index(one_object['name']))
y_mins.append(float(h_scale * int(one_object['bndbox']['ymin'])))
x_mins.append(float(w_scale * int(one_object['bndbox']['xmin'])))
y_maxes.append(float(h_scale * int(one_object['bndbox']['ymax'])))
x_maxes.append(float(w_scale * int(one_object['bndbox']['xmax'])))
# example = tf.train.Example(features=tf.train.Features(feature={
# 'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
# }
# ))
img_file_name = bytes(img_file_name, encoding='utf8')
example = tf.train.Example(features=tf.train.Features(feature={
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
}))
writer.write(example.SerializeToString())
writer.close()
sess.close()
print('ok')
tfrecord读取
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# import sys
#
# sys.path.append("..")
classes = [
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
# 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
# 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
# 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
# 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
# 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
# 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
def _parse_record(example_proto):
features = {
'filename': tf.FixedLenFeature([], tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'classes': tf.VarLenFeature(tf.int64),
'y_mins': tf.VarLenFeature(tf.float32),
'x_mins': tf.VarLenFeature(tf.float32),
'y_maxes': tf.VarLenFeature(tf.float32),
'x_maxes': tf.VarLenFeature(tf.float32),
'encoded': tf.FixedLenFeature((), tf.string)
}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def read_test(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_initializable_iterator()
max_value = tf.placeholder(tf.int64, shape=[])
with tf.Session() as sess:
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(2):
features = sess.run(iterator.get_next())
name = features['filename']
name = name.decode()
shape = features['shape']
classes = features['classes']
y_mins = features['y_mins']
x_mins = features['x_mins']
y_maxes = features['y_maxes']
x_maxes = features['x_maxes']
# name = name.decode()
img_data = features['encoded']
print(len(img_data))
print('=======')
print("shape", shape)
print("name", name)
print("classes", classes.values)
print("y_mins", y_mins.values)
print("x_mins", x_mins.values)
print("y_maxes", y_maxes.values)
print("x_maxes", x_maxes.values)
img_data = np.fromstring(img_data, dtype=np.uint8)
image_data = np.reshape(img_data, shape)
print("img_data", image_data)
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
# img_data = np.fromstring(img_data, dtype=np.uint8)
# image_data = np.reshape(img_data, shape)
#
# plt.figure()
# # 显示图片
plt.imshow(image_data)
plt.show()
read_test('F:\数据存储\VOCdevkit\\voc2012.tfrecord')
尺寸不固定矩阵的存储和读取
import json
import jieba
import tensorflow as tf
with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
dic = json.loads(file.read())
all_words_word2id = dic["all_words_word2id"]
stop_words = []
with open('./stop_words.txt', encoding='utf-8') as f:
line = f.readline()
while line:
stop_words.append(line[:-1])
line = f.readline()
stop_words = set(stop_words)
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))
dir_path = 'F:\\数据存储\新闻语料\\news2016zh_train.json'
dir_path_test = 'F:\\数据存储\新闻语料\\news2016zh_valid.json'
out_path = 'F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord'
def getCutSequnce(line):
# 使用jieba 进行中文分词
raw_words = list(jieba.cut(line, cut_all=False))
# 存储一句话的分词结果
raw_word_list = []
# 去除停用词
for word in raw_words:
if word not in stop_words and word not in ['www', 'com', 'http']:
raw_word_list.append(word)
return raw_word_list
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
with open(dir_path, encoding='utf-8') as txt:
one_dic = txt.readline()
while one_dic:
i = i + 1
if i > 10000:
break
if (i % 1000) == 0:
print(i)
one_dic_json = json.loads(one_dic)
title = one_dic_json['title']
content = one_dic_json['content']
if len(content) > 3000:
one_dic = txt.readline()
continue
one_dic = txt.readline()
if len(title) == 0 or len(content) == 0:
continue
title_list = getCutSequnce(title)
content_list = getCutSequnce(content)
title_list_index = []
for one in title_list:
try:
title_list_index.append(all_words_word2id[one])
except:
pass
content_list_index = []
for one_word in content_list:
try:
content_list_index.append(all_words_word2id[one_word])
except:
pass
example = tf.train.Example(features=tf.train.Features(feature={
'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title_list_index)),
'content': tf.train.Feature(int64_list=tf.train.Int64List(value=content_list_index))
}))
writer.write(example.SerializeToString())
import tensorflow as tf
import numpy as np
def _parse_record(example_proto):
features = {
'title': tf.VarLenFeature(tf.int64),
'content': tf.VarLenFeature(dtype=tf.int64)
}
parsed_features = tf.parse_single_example(example_proto, features=features)
return parsed_features
def read_test(input_file):
# 用 dataset 读取 tfrecord 文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
sess.run(iterator.initializer)
for i in range(5):
features = sess.run(iterator.get_next())
name = features['title']
content = features['content']
print("xx", content)
print("xx", np.array(content).shape)
# 从 bytes 数组中加载图片原始数据,并重新 reshape.它的结果是 ndarray 数组
read_test('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
统计数据条数
import tensorflow as tf
def total_sample(file_name):
sample_nums = 0
for record in tf.python_io.tf_record_iterator(file_name):
sample_nums += 1
return sample_nums
result = total_sample('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
print(result)
TFRecord 使用的更多相关文章
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...
- 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读
1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...
- 深度学习原理与框架-Tfrecord数据集的制作 1.tf.train.Examples(数据转换为二进制) 3.tf.image.encode_jpeg(解码图片加码成jpeg) 4.tf.train.Coordinator(构建多线程通道) 5.threading.Thread(建立单线程) 6.tf.python_io.TFR(TFR读入器)
1. 配套使用: tf.train.Examples将数据转换为二进制,提升IO效率和方便管理 对于int类型 : tf.train.Examples(features=tf.train.Featur ...
- 3. Tensorflow生成TFRecord
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- TFRecord文件的读写
前言在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而 ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- tfrecord
制作自己的TFRecord数据集,读取,显示及代码详解 http://blog.csdn.net/miaomiaoyuan/article/details/56865361
- 3 TFRecord样例程序实战
将图片数据写入Record文件 # 定义函数转化变量类型. def _int64_feature(value): return tf.train.Feature(int64_list=tf.train ...
随机推荐
- WebSocket 转
即时通信常用手段 1.第三方平台 谷歌.腾讯 环信等多如牛毛,其中谷歌即时通信是免费的,但免费就是免费的并不好用.其他的一些第三方一般收费的,使用要则限流(1s/限制x条消息)要么则限制用户数. 但稳 ...
- openssl jia adress
???????????????????????????????????????????openssl证IP 首先创建openssl.cnf, 内容如下. 其中organizationalUnitNam ...
- 什么是MBR
MBR的定义 MBR(Main Boot Record)主引导记录是位于磁盘最前边的一段引导代码,由磁盘操作系统(DOS)在对磁盘初始化时产生,负责磁盘操作系统(DOS)对磁盘进行读写时磁盘分区合法性 ...
- 【kubernetes】通过rancher2部署k8s
1. K8S相关介绍 十分钟带你理解Kubernetes核心概念 2. 部署rancher # 更新操作系统软件包 yum update -y # 删除历史容器及数据 docker rm -f $(d ...
- 【转载】C#中ToArray方法将List集合转换为对应的数组
在C#的List集合操作中,可以使用List集合自带的ToArray方法来将List集合转换为对应的Array数组元素.ToArray方法的签名为T[] ToArray(),存在于命名空间System ...
- Node笔记(新手入门必看)
. 初识Node.js 1.1 Node.js是什么 Node.js® is a JavaScript runtime built on Chrome's V8 JavaScript engine. ...
- group by 和 order by 的区别 + 理解过程
order by 和 group by 的区别order by 和 group by 的区别:1,order by 从英文里理解就是行的排序方式,默认的为升序. order by 后面必须列出排序的字 ...
- android RecyclerView的Grid布局案例
1.先创建activity_grid.xml 和 activity_grid_item.xml <?xml version="1.0" encoding="utf- ...
- Air for ANE:一星期的调试笔记
来源:http://blog.csdn.net/hero82748274/article/details/8656674 第一次尝试ANE的东西,让我感觉到很折腾人.adobe 出的这个方案虽然可以解 ...
- iptables详细介绍
iptables简介 netfilter/iptables(简称为iptables)组成Linux平台下的包过滤防火墙,与大多数的Linux软件一样,这个包过滤防火墙是免费的,它可以代替昂贵的商业防火 ...