import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 将数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(label)),
'image_raw': _bytes_feature(image_raw)
}))
return example # 读取mnist训练数据。
mnist = input_data.read_data_sets("E:\\MNIST_data\\",dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples # 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output.tfrecords") as writer:
for index in range(num_examples):
example = _make_example(pixels, labels[index], images[index])
writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。") # 读取mnist测试数据。
images_test = mnist.test.images
labels_test = mnist.test.labels
pixels_test = images_test.shape[1]
num_examples_test = mnist.test.num_examples # 输出包含测试数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output_test.tfrecords") as writer:
for index in range(num_examples_test):
example = _make_example(pixels_test, labels_test[index], images_test[index])
writer.write(example.SerializeToString())
print("TFRecord测试文件已保存。")

# 读取文件。
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["E:\\MNIST_data\\output.tfrecords"])
_,serialized_example = reader.read(filename_queue) # 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}) images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32) sess = tf.Session() # 启动多线程处理输入数据。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])

吴裕雄 python 神经网络——TensorFlow TFRecord样例程序的更多相关文章

  1. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  2. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  3. 吴裕雄 python 神经网络——TensorFlow 输入文件队列

    import tensorflow as tf def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64 ...

  4. 吴裕雄 python 神经网络——TensorFlow 图像预处理完整样例

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def distort_color(image, ...

  5. 吴裕雄 python 神经网络——TensorFlow 完整神经网络样例程序

    import tensorflow as tf from numpy.random import RandomState batch_size = 8 w1= tf.Variable(tf.rando ...

  6. 吴裕雄 python 神经网络——TensorFlow variables_to_restore函数的使用样例

    import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") ema = tf.train.Expo ...

  7. 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例

    import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...

  8. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  9. 吴裕雄 python 神经网络——TensorFlow 数据集基本使用方法

    import tempfile import tensorflow as tf input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_ ...

随机推荐

  1. NOIP做题练习(day3)

    A - 军队 问题描述 给定一个有 \(n\) 个队伍的人组成的序列,第 \(i\) 个队伍 \(i\) 有 \(s[i]\)个人组成,一个 \(l\) 到 \(r\)的子序列是合法的,当且仅当\(( ...

  2. 题解【SP1716】GSS3 - Can you answer these queries III

    题目描述 You are given a sequence \(A\) of \(N (N <= 50000)\) integers between \(-10000\) and \(10000 ...

  3. 题解【CJOJ1371】[IOI2002]任务安排

    P1371 - [IOI2002]任务安排 Description N个任务排成一个序列在一台机器上等待完成(顺序不得改变),这N个任务被分成若干批,每批包含相邻的若干任务.从时刻0开始,这些任务被分 ...

  4. Django数据迁移时(或者新建模型时)报错:Did you install mysqlclient,解决后又报错:mysqlclient 1.3.13 or newer is required;you have 0.9.3

    报错信息如下: 解决方法一: 给项目根目录下mysite应用下的__init__.py文件加入如下代码: 运行又报错: 报错信息是:  mysqlclient版本太低 点击上图框中的链接进入到pyth ...

  5. 使用git pull同步github代码到服务器

    我直接用git pull的时候遇到这个错误: error: Your local changes to the following files would be overwritten by merg ...

  6. eclipse 添加主題

    在使用Eclipse过程中可能想更换下界面主题,此处介绍的是一款主题插件 Eclipse Color Theme 打开Eclipse,Help --> Eclipse Marketplace 在 ...

  7. HDU 3823 Prime Friend(线性欧拉筛+打表)

    Besides the ordinary Boy Friend and Girl Friend, here we define a more academic kind of friend: Prim ...

  8. 微信小程序UDP通信,注意点 接收 onMessage 收到的message是ArrayBuffer缓冲,不能直接输出,要另转String处理

    1.WXML 页面代码 <!--index.wxml--> <view class="container"> <view class="us ...

  9. TCP和UDP的一些注意事项

    TCP的一些注意事项 1. tcp服务器一般情况下都需要绑定,否则客户端找不到这个服务器,更无法链接到服务器 2. tcp客户端一般不绑定,因为是主动链接服务器,所以只要确定好服务器的ip.port等 ...

  10. strtok() and strtod()

    1.strtok(参数1,参数2)按指定的分割符将字符串分割开来 参数1:表示要被分割的字符串的地址: 参数2:表示指定的分割符的地址: 例如:按空格分割“Hello World” buffer[] ...