TFRecord读写简介+Demo 基于Ubuntu18.04+Tensorflow1.12 无WARNING
简介
- TFRecord是TensorFlow官方推荐使用的数据格式化存储工具。
- 它规范了数据的读写方式。
- 只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。
将图片转换成TFRecord
本例,将fashion-MNIST数据转换成TFRecord,需要先下载fashion数据集到当前目录下,参考:https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
import numpy as np
import tensorflow as tf
import gzip
import os fashion_mnist_directory = './data/fashion/' def load_mnist(path, kind='train'):
labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind) with gzip.open(labels_path, 'rb') as lbpath:
labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) with gzip.open(images_path, 'rb') as imgpath:
images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(-1, 784) print(labels_path, "shape =", labels.shape)
print(images_path, "shape =", images.shape) return images, labels def make_example(image, label):
return tf.train.Example(features=tf.train.Features(feature={
'image_raw' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
'label' : tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label) ])) })) def write_tfrecord(images, labels, filename):
writer = tf.python_io.TFRecordWriter(filename)
for image, label, k in zip(images, labels, range(labels.shape[0])):
exam = make_example(image, label)
writer.write(exam.SerializeToString())
if (k%100 == 0):
print("\rwriting", filename, "%6.2f%% complited." %(100.0*(k+1)/labels.shape[0]), end='') print("\rwriting", filename, "%6.2f%% complited." %(100.0))
writer.close() def main():
train_images, train_labels = load_mnist(fashion_mnist_directory, 'train')
test_images, test_labels = load_mnist(fashion_mnist_directory, 't10k') write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecords')
write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecords') if __name__ == '__main__':
main()
读取TFRecord数据来训练
以下代码读取TFRecord数据用于训练,改代码改编自官方例程:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data
原始代码运行时报错,已修复。
注意:在这个例子中,_, loss_value = sess.run([train_op, loss]),只执行一次Batch Input,无论[]中是什么,有多少个操作。
import argparse
import os.path
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import mnist FLAGS = None TRAIN_FILE = 'fashion_mnist_train.tfrecords'
VALIDATION_FILE = 'fashion_mnist_test.tfrecords' def decode(serialized_example):
features = tf.parse_single_example(serialized_example,
features={'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape((mnist.IMAGE_PIXELS))
label = tf.cast(features['label'], tf.int32)
return image, label def augment(image, label):
"""Placeholder for data augmentation."""
# OPTIONAL: Could reshape into a 28x28 image and apply distortions here.
return image, label def normalize(image, label):
"""Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
return image, label def inputs(train, batch_size, num_epochs):
"""Reads input data"""
if not num_epochs:
num_epochs = None
filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE) with tf.name_scope('input'):
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
dataset = dataset.shuffle(1000 + 3 * batch_size)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
return iterator.get_next() def run_training():
with tf.Graph().as_default():
image_batch, label_batch = inputs(train=True,
batch_size=FLAGS.batch_size,
num_epochs=FLAGS.num_epochs)
logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
loss = mnist.loss(logits, label_batch)
train_op = mnist.training(loss, FLAGS.learning_rate) init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer()) with tf.Session() as sess:
sess.run(init_op)
try:
step = 0
while True: # Train until OutOfRangeError
start_time = time.time()
_, loss_value = sess.run([train_op, loss])
duration = time.time() - start_time
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
step += 1
except tf.errors.OutOfRangeError:
print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) def main(_):
run_training() if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('--num_epochs', type=int, default=2, help='Number of epochs to run trainer.')
parser.add_argument('--hidden1', type=int, default=128, help='Number of units in hidden layer 1.')
parser.add_argument('--hidden2', type=int, default=32, help='Number of units in hidden layer 2.')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size.')
parser.add_argument('--train_dir', type=str, default='./', help='Directory with the training data.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
参考了:
- https://blog.csdn.net/gg_18826075157/article/details/78449104
- https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
TFRecord读写简介+Demo 基于Ubuntu18.04+Tensorflow1.12 无WARNING的更多相关文章
- kubeadm部署1.17.3[基于Ubuntu18.04]
基于 Ubuntu18.04 使用 kubeadm 部署Kubernetes 1.17.3 高可用集群 环境 所有节点初始化 # cat <<EOF>> /etc/hosts ...
- 基于Ubuntu18.04一站式部署(python-mysql-redis-nginx)
基于Ubuntu18.04一站式部署 Python3.6.8的安装 1. 安装依赖 ~$ sudo apt install openssl* zlib* 2. 安装python3.6.8(个人建议从官 ...
- ubuntu18.04系统下无外部显示问题解决
记录一下自己作死过程. 由于学习的需要,在windows10下装了ubuntu18.04系统,第一次装这个系统时,也出现了无外部显示,那时候是老师帮忙搞好的,当时没太在意,只是走马关花的看了老师操作了 ...
- Kubernetes 基于 ubuntu18.04 手工部署 (k8s)
由于工作的需要, 手工部署一个 Kubernetes 环境(k8s).(以前都是云上搞定,拿来用) 习惯把这种工作记录下来,自己备查也和别人分享 网上相关文章很多, 我也参考了很多,这里推荐一个 链接 ...
- TensorFlow从入门到理解(一):搭建开发环境【基于Ubuntu18.04】
*注:教程及本文章皆使用Python3+语言,执行.py文件都是用终端(如果使用Python2+和IDE都会和本文描述有点不符) 一.安装,测试,卸载 TensorFlow官网介绍得很全面,很完美了, ...
- 腾讯云服务器ubuntu18.04部署禅道系统
踩了不少坑,记录一下. 基于ubuntu18.04 一开始按照网上的攻略下载安装包 ZenTaoPMS.9.8.3.zbox_64.tar.gz,通过FileZilla传到linux的/opt下面,解 ...
- 【Tool】---ubuntu18.04配置oh-my-zsh工具
作为Linux忠实用户,应该没有人不知道bash shell工具了吧,其实除了bash还有许多其他的工具,zsh就是一款很好得选择,基于zsh shell得基础之上,oh-my-zsh工具更是超级利器 ...
- ubuntu18.04下搭建深度学习环境anaconda2+ cuda9.0+cudnn7.0.5+tensorflow1.7【原创】【学习笔记】
PC:ubuntu18.04.i5.七彩虹GTX1060显卡.固态硬盘.机械硬盘 作者:庄泽彬(欢迎转载,请注明作者) 说明:记录在ubuntu18.04环境下搭建深度学习的环境,之前安装了cuda9 ...
- tensorflow/pytorch/mxnet的pip安装,非源代码编译,基于cuda10/cudnn7.4.1/ubuntu18.04.md
os安装 目前对tensorflow和cuda支持最好的是ubuntu的18.04 ,16.04这种lts,推荐使用18.04版本.非lts的版本一般不推荐. Windows倒是也能用来装深度GPU环 ...
随机推荐
- 流暢的python學習-3
一.文件操作 #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Apr 23 20:59 ...
- python errno库与socket.connect_ex()方法的结合使用
前言:一般socket链接会首选connect方法,该方法会一直尝试链接.那么今天展示下connect_ex()方法,该方法如果链接成功会返回0,失败会返回errno库中的errorcode中的key ...
- ts 学习笔记-基础篇
目录 基础 原始数据类型 布尔值 数字 字符串 空值 Null 和 Undefined 任意值 类型推论 联合类型 接口 数组 函数 类型断言 申明文件 什么是申明文件 三斜线指令 第三方声明文件 内 ...
- 解决iOS上网页滑动不流畅问题
body { overflow:auto; /* 用于 android4+,或其他设备 */ -webkit-overflow-scrolling:touch; /* 用于 ios5+ */ }说明: ...
- 模式识别课程大作业 Shopee 商品图像检索
大作业项目简介 在如今的信息科技时代, 带有拍照功能的移动设备如手机.相机等得到了极大的普及和流行, 各种各样的图片和视频可以随时随地获得, 并借助互联网快速传播, 这种趋势使得网络上的数字图片和视频 ...
- 实战 | Hive 数据倾斜问题定位排查及解决
Hive 数据倾斜怎么发现,怎么定位,怎么解决 多数介绍数据倾斜的文章都是以大篇幅的理论为主,并没有给出具体的数据倾斜案例.当工作中遇到了倾斜问题,这些理论很难直接应用,导致我们面对倾斜时还是不知所措 ...
- 自学linux——2.认识目录及常用指(命)令
认识目录及常用指(命)令 1.备份: 快照(还原精灵):短期备份 频繁备份 可关可开.可能会影响系统的操作. 备份时:虚拟机--快照 还原时:虚拟机--快照--快照管理器--相应位置--转到 克隆 ...
- [洛谷P3376题解]网络流(最大流)的实现算法讲解与代码
[洛谷P3376题解]网络流(最大流)的实现算法讲解与代码 更坏的阅读体验 定义 对于给定的一个网络,有向图中每个的边权表示可以通过的最大流量.假设出发点S水流无限大,求水流到终点T后的最大流量. 起 ...
- MSF使用OpenSSL流量加密
MSF使用OpenSSL流量加密 前言 之前在博客里使用了Openssl对流量进行加密,这次我们来复现暗月师傅红队指南中的一篇文章,尝试用OpenSSL对Metasploit的流量进行加密,以此来躲避 ...
- 跟我一起写 Makefile(六)
书写命令 ---- 每条规则中的命令和操作系统Shell的命令行是一致的.make会一按顺序一条一条的执行命令,每条命令的开头必须以[Tab]键开头,除非,命令是紧跟在依赖规则后面的分号后的.在命令行 ...