TFRecord generation

import os
import xmltodict
import tensorflow as tf
import numpy as np

dir_path = 'F:\数据存储\VOCdevkit\VOC2012\Annotations'
dirs = os.listdir(dir_path)
imgs_dir = "F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = 'F:\数据存储\VOCdevkit\\voc2012.tfrecord'

classes = [
"background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
sess = tf.Session()

def get_and_resize_img(img_file):
    '''
    Resize the image to 224*224.
    Returns the resized image as a byte string, the width/height scale factors, and the new shape.
    '''
    img = tf.read_file(imgs_dir + '/' + img_file)
    img = tf.image.decode_jpeg(img)
    shape_old = sess.run(img).shape
    # method=0 is bilinear interpolation
    resized_img = tf.image.resize_images(img, [224, 224], method=0)
    resized_img = sess.run(resized_img)
    resized_img = np.asarray(resized_img, dtype='uint8')
    resized_img_str = resized_img.tostring()
    shape_new = resized_img.shape
    # In both shapes, dimension 0 is height and dimension 1 is width
    w_scale = shape_new[1] / shape_old[1]
    h_scale = shape_new[0] / shape_old[0]
    return resized_img_str, w_scale, h_scale, shape_new

writer = tf.python_io.TFRecordWriter(out_path)
i = 0
for file in dirs:
    i = i + 1
    # if i > 1000:
    #     break
    with open(dir_path + '/' + file) as xml_txt:
        doc = xmltodict.parse(xml_txt.read())
    img_file_name = file.split('.')[0]
    resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')
    img_obtain_classes = []
    y_mins = []
    x_mins = []
    y_maxes = []
    x_maxes = []
    # A single <object> is parsed by xmltodict as an OrderedDict, multiple objects as a list
    if type(doc['annotation']["object"]).__name__ == 'OrderedDict':
        if doc['annotation']["object"]['name'] in classes:
            img_obtain_classes.append(classes.index(doc['annotation']["object"]['name']))
            y_mins.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymin'])))
            x_mins.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmin'])))
            y_maxes.append(float(h_scale * int(doc['annotation']["object"]['bndbox']['ymax'])))
            x_maxes.append(float(w_scale * int(doc['annotation']["object"]['bndbox']['xmax'])))
    else:
        for one_object in doc['annotation']["object"]:
            if one_object['name'] in classes:
                img_obtain_classes.append(classes.index(one_object['name']))
                y_mins.append(float(h_scale * int(one_object['bndbox']['ymin'])))
                x_mins.append(float(w_scale * int(one_object['bndbox']['xmin'])))
                y_maxes.append(float(h_scale * int(one_object['bndbox']['ymax'])))
                x_maxes.append(float(w_scale * int(one_object['bndbox']['xmax'])))
    img_file_name = bytes(img_file_name, encoding='utf8')
    example = tf.train.Example(features=tf.train.Features(feature={
        'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
        'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
        'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
        'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),  # ymin of each object
        'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
        'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
        'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
        'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
    }))
    writer.write(example.SerializeToString())
writer.close()
sess.close()
print('ok')
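The repeated tf.train.Feature(...) wrappers above are verbose. They can be collapsed into small helper functions; this is only a convenience sketch (assuming tensorflow is imported as tf, as in the script above), and the helper names _bytes_feature, _int64_feature and _float_feature are mine, not part of the original script:

def _bytes_feature(values):
    # Wrap a list of byte strings in a BytesList feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

def _int64_feature(values):
    # Wrap a list of integers in an Int64List feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _float_feature(values):
    # Wrap a list of floats in a FloatList feature
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

# With these, the Example above shortens to, e.g.:
# 'filename': _bytes_feature([img_file_name]),
# 'classes': _int64_feature(img_obtain_classes),
# 'y_mins': _float_feature(y_mins), ...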

TFRecord reading

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
classes = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

# The feature keys below must match what was written above:
# filename, shape, classes, y_mins, x_mins, y_maxes, x_maxes, encoded
def _parse_record(example_proto):
    features = {
        'filename': tf.FixedLenFeature([], tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'classes': tf.VarLenFeature(tf.int64),
        'y_mins': tf.VarLenFeature(tf.float32),
        'x_mins': tf.VarLenFeature(tf.float32),
        'y_maxes': tf.VarLenFeature(tf.float32),
        'x_maxes': tf.VarLenFeature(tf.float32),
        'encoded': tf.FixedLenFeature((), tf.string)
    }
    parsed_features = tf.parse_single_example(example_proto, features=features)
    return parsed_features

def read_test(input_file):
    # Read the TFRecord file with the Dataset API
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for i in range(2):
            features = sess.run(iterator.get_next())
            name = features['filename'].decode()
            shape = features['shape']
            img_classes = features['classes']
            y_mins = features['y_mins']
            x_mins = features['x_mins']
            y_maxes = features['y_maxes']
            x_maxes = features['x_maxes']
            img_data = features['encoded']
            print(len(img_data))
            print('=======')
            print("shape", shape)
            print("name", name)
            # VarLenFeature comes back as a SparseTensorValue; .values holds the data
            print("classes", img_classes.values)
            print("y_mins", y_mins.values)
            print("x_mins", x_mins.values)
            print("y_maxes", y_maxes.values)
            print("x_maxes", x_maxes.values)
            # Load the raw image bytes back into an ndarray and reshape it
            img_data = np.fromstring(img_data, dtype=np.uint8)
            image_data = np.reshape(img_data, shape)
            print("img_data", image_data)
            # Display the image
            plt.imshow(image_data)
            plt.show()

read_test('F:\数据存储\VOCdevkit\\voc2012.tfrecord')
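Pulling every field out with sess.run and decoding in Python is fine for inspection, but for training it is more usual to keep the decoding inside the graph and batch the dataset. Below is a minimal sketch under the same TF 1.x API; _parse_and_decode, the batch size of 4, and the padding value 0 are illustrative choices, not from the original code. It relies on every image having been resized to 224*224*3 before being written.

def _parse_and_decode(example_proto):
    # Hypothetical variant of _parse_record that keeps everything in the graph
    parsed = tf.parse_single_example(example_proto, features={
        'classes': tf.VarLenFeature(tf.int64),
        'encoded': tf.FixedLenFeature((), tf.string)
    })
    # decode_raw reverses tostring(); every image was written as 224*224*3 uint8
    image = tf.decode_raw(parsed['encoded'], tf.uint8)
    image = tf.reshape(image, [224, 224, 3])
    # VarLenFeature yields a SparseTensor; make it dense so it can be padded and batched
    labels = tf.sparse_tensor_to_dense(parsed['classes'], default_value=0)
    return image, labels

dataset = tf.data.TFRecordDataset('F:\数据存储\VOCdevkit\\voc2012.tfrecord')
dataset = dataset.map(_parse_and_decode)
# Images have a fixed shape; the per-image label list does not, so pad the labels
dataset = dataset.padded_batch(4, padded_shapes=([224, 224, 3], [None]))
images, labels = dataset.make_one_shot_iterator().get_next()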

Storing and reading variable-size matrices

import json
import jieba
import tensorflow as tf

with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
    dic = json.loads(file.read())
    all_words_word2id = dic["all_words_word2id"]

stop_words = []
with open('./stop_words.txt', encoding='utf-8') as f:
    line = f.readline()
    while line:
        stop_words.append(line[:-1])
        line = f.readline()
stop_words = set(stop_words)
print('Finished loading stop words, {n} words in total'.format(n=len(stop_words)))

dir_path = 'F:\\数据存储\新闻语料\\news2016zh_train.json'
dir_path_test = 'F:\\数据存储\新闻语料\\news2016zh_valid.json'
out_path = 'F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord'

def getCutSequnce(line):
    # Segment the Chinese text with jieba
    raw_words = list(jieba.cut(line, cut_all=False))
    # Holds the segmentation result of one sentence
    raw_word_list = []
    # Drop stop words and a few web tokens
    for word in raw_words:
        if word not in stop_words and word not in ['www', 'com', 'http']:
            raw_word_list.append(word)
    return raw_word_list

writer = tf.python_io.TFRecordWriter(out_path)
i = 0
with open(dir_path, encoding='utf-8') as txt:
    one_dic = txt.readline()
    while one_dic:
        i = i + 1
        if i > 10000:
            break
        if (i % 1000) == 0:
            print(i)
        one_dic_json = json.loads(one_dic)
        title = one_dic_json['title']
        content = one_dic_json['content']
        if len(content) > 3000:
            one_dic = txt.readline()
            continue
        one_dic = txt.readline()
        if len(title) == 0 or len(content) == 0:
            continue
        title_list = getCutSequnce(title)
        content_list = getCutSequnce(content)
        # Map words to ids, silently skipping words that are not in the vocabulary
        title_list_index = []
        for one in title_list:
            try:
                title_list_index.append(all_words_word2id[one])
            except KeyError:
                pass
        content_list_index = []
        for one_word in content_list:
            try:
                content_list_index.append(all_words_word2id[one_word])
            except KeyError:
                pass
        # title and content have a different length for every record;
        # Int64List stores variable-length integer lists without padding
        example = tf.train.Example(features=tf.train.Features(feature={
            'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title_list_index)),
            'content': tf.train.Feature(int64_list=tf.train.Int64List(value=content_list_index))
        }))
        writer.write(example.SerializeToString())
writer.close()

Reading the variable-length records back:

import tensorflow as tf
import numpy as np

def _parse_record(example_proto):
    features = {
        'title': tf.VarLenFeature(tf.int64),
        'content': tf.VarLenFeature(dtype=tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features=features)
    return parsed_features

def read_test(input_file):
    # Read the TFRecord file with the Dataset API
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for i in range(5):
            features = sess.run(iterator.get_next())
            title = features['title']
            content = features['content']
            # VarLenFeature comes back as a SparseTensorValue; .values holds the word ids
            print("title", title.values)
            print("content", content.values)
            print("content length", len(content.values))

read_test('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
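Because every title and content has a different length, batching them requires densifying and padding. A minimal sketch with the same TF 1.x Dataset API; _parse_dense, the batch size of 8, and padding with 0 are illustrative assumptions, not part of the original code:

def _parse_dense(example_proto):
    parsed = tf.parse_single_example(example_proto, features={
        'title': tf.VarLenFeature(tf.int64),
        'content': tf.VarLenFeature(tf.int64)
    })
    # Turn the SparseTensors into dense 1-D tensors of word ids
    title = tf.sparse_tensor_to_dense(parsed['title'], default_value=0)
    content = tf.sparse_tensor_to_dense(parsed['content'], default_value=0)
    return title, content

dataset = tf.data.TFRecordDataset('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
dataset = dataset.map(_parse_dense)
# Pad each batch to the longest title/content it contains
dataset = dataset.padded_batch(8, padded_shapes=([None], [None]))
titles, contents = dataset.make_one_shot_iterator().get_next()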

Counting the number of records

import tensorflow as tf

def total_sample(file_name):
    sample_nums = 0
    for record in tf.python_io.tf_record_iterator(file_name):
        sample_nums += 1
    return sample_nums

result = total_sample('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
print(result)
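The same count can also be written as a one-line generator expression over the record iterator, equivalent to total_sample above:

num_records = sum(1 for _ in tf.python_io.tf_record_iterator(
    'F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord'))
print(num_records)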
