import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 将数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(label)),
'image_raw': _bytes_feature(image_raw)
}))
return example # 读取mnist训练数据。
mnist = input_data.read_data_sets("E:\\MNIST_data\\",dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples # 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output.tfrecords") as writer:
for index in range(num_examples):
example = _make_example(pixels, labels[index], images[index])
writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。") # 读取mnist测试数据。
images_test = mnist.test.images
labels_test = mnist.test.labels
pixels_test = images_test.shape[1]
num_examples_test = mnist.test.num_examples # 输出包含测试数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output_test.tfrecords") as writer:
for index in range(num_examples_test):
example = _make_example(pixels_test, labels_test[index], images_test[index])
writer.write(example.SerializeToString())
print("TFRecord测试文件已保存。")

# 读取文件。
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["E:\\MNIST_data\\output.tfrecords"])
_,serialized_example = reader.read(filename_queue) # 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}) images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32) sess = tf.Session() # 启动多线程处理输入数据。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])

吴裕雄 python 神经网络——TensorFlow TFRecord样例程序的更多相关文章

  1. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  2. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  3. 吴裕雄 python 神经网络——TensorFlow 输入文件队列

    import tensorflow as tf def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64 ...

  4. 吴裕雄 python 神经网络——TensorFlow 图像预处理完整样例

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def distort_color(image, ...

  5. 吴裕雄 python 神经网络——TensorFlow 完整神经网络样例程序

    import tensorflow as tf from numpy.random import RandomState batch_size = 8 w1= tf.Variable(tf.rando ...

  6. 吴裕雄 python 神经网络——TensorFlow variables_to_restore函数的使用样例

    import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") ema = tf.train.Expo ...

  7. 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例

    import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...

  8. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  9. 吴裕雄 python 神经网络——TensorFlow 数据集基本使用方法

    import tempfile import tensorflow as tf input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_ ...

随机推荐

  1. mybatis 查询list,内容为null,但list的size 为1

    List<Integer> cityList = resourcePartnerService.selectCityList(userId); 需要在SQL里where语句加上 字段不为n ...

  2. Hadoop中的java基本类型的序列化封装类

    Hadoop将很多Writable类归入org.apache.hadoop.io包中,在这些类中,比较重要的有Java基本类.Text.Writable集合.ObjectWritable等,重点介绍J ...

  3. POJ 3177 Redundant Paths (tarjan边双连通分量)

    题目连接:http://poj.org/problem?id=3177 题目大意是给定一些牧场,牧场和牧场之间可能存在道路相连,要求从一个牧场到另一个牧场要有至少两条以上不同的路径,且路径的每条pat ...

  4. 【网易官方】极客战记(codecombat)攻略-地牢-囚犯

    关卡连接: https://codecombat.163.com/play/level/the-prisoner 解放囚犯,你会得到盟友. 简介 敬请期待! 默认代码 # 释放囚犯,击败守卫并夺取宝石 ...

  5. Linux05——用户操作

    用户操作 1.新增用户(useradd 新用户名): 2.设置密码(passwd 用户名): 3.用户是否存在(id  用户名): 4.切换用户(su - 切换用户名) **—— **       s ...

  6. 搭建 Kubernetes 高可用集群

    使用 3 台阿里云服务器(k8s-master0, k8s-master1, k8s-master2)作为 master 节点搭建高可用集群,负载均衡用的是阿里云 SLB ,需要注意的是由于阿里云负载 ...

  7. win7 安装asp.net v4.0

    错误信息影响: HTTP 错误 403.14 - Forbidden Web 服务器被配置为不列出此目录的内容.返回的错误表明IIS缺少针对无后缀的MVC请求的映射,ASP.NET处理程序无法接收到请 ...

  8. 09day 命令提示符优化及yum优化

    export PS1='\[\e[32;1m\][\u@\h \W]\$ \[\e[0m\]' 设置颜色 内容 结束 export PS1='\[\e[30;1m\][\u@\h \W]\$ \[\e ...

  9. Go性能调优

    文章引用自   Go性能调优 在计算机性能调试领域里,profiling 是指对应用程序的画像,画像就是应用程序使用 CPU 和内存的情况. Go语言是一个对性能特别看重的语言,因此语言中自带了 pr ...

  10. 每天进步一点点------入门视频采集与处理(BT656简介)

    凡是做模拟信号采集的,很少不涉及BT.656标准的,因为常见的模拟视频信号采集芯片都支持输出BT.656的数字信号,那么,BT.656到底是何种格式呢?      本文将主要介绍 标准的 8bit B ...