MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) learning_rate = 0.01 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 25
batch_size = 100 display_step = 1 save_path = 'model/' saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels})) save = saver.save(sess, save_path=save_path+'mnist.cpkt') print(" starting 2nd session ...... ") with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval) im = batch_xs[0]
im = im.reshape(-1, 28) plt.imshow(im, cmap='gray')
plt.show() im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()

TensorFlow——MNIST手写数据集的更多相关文章

  1. TensorFlow实战第五课(MNIST手写数据集识别)

    Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

  2. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  3. matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...

  4. TensorFlow系列专题(六):实战项目Mnist手写数据集识别

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

  5. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  6. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  7. MNIST手写数据集在运行中出现问题解决方案

    今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

随机推荐

  1. Group_concat介绍与例子

    进公司做的第一个项目就是做一个订单追踪查询,里里外外连接了十一个表,作为公司菜鸡的我麻了爪. 其中有一个需求就是对于多行的数据在一行显示,原谅我才疏学浅 无奈下找到了项目组长  在那学来了这个利器 ( ...

  2. Python--day36--操作系统的作用;多道技术;

  3. 改善Azure App Service托管应用程序性能的几个技巧

    本文介绍了几个技巧,这些技巧可以改善Azure App Service托管应用程序的性能.其中一些技巧是你现在就可以进行的配置变更, 而其他技巧则可能需要对应用程序进行一些重新设计和重构. 开发者都希 ...

  4. tomcat下的work目录和temp目录

    1. tomcat下的work目录 1    用tomcat作web服务器的时候,部署的程序在webApps下,这些程序都是编译后的程序(发布到tomcat的项目里含的类,会被编译成.class后才发 ...

  5. H3C STP监控与维护

  6. thinkjs解决跨域

    this.header("Access-Control-Allow-Origin", "*"); 将上面的代码在请求发送之前执行即可 如果不知道放在哪里 可以参 ...

  7. javascript基础之数组一

    <script type="text/javascript"> //求数组中最大的数 var arr=[123,456,789,657,432,564]; var ar ...

  8. 如何用python“优雅的”调用有道翻译?

    前言 其实在以前就盯上有道翻译了的,但是由于时间问题一直没有研究(我的骚操作还在后面,记得关注),本文主要讲解如何用python调用有道翻译,讲解这个爬虫与有道翻译的js“斗争”的过程! 当然,本文仅 ...

  9. 复盘:错误理解zuul路径匹配,无法使用zuul

    场景: 项目中用到zuul时,配置url总是有问题,无法路由到对应微服务. 配置如下: zuul: routes: m2-member: path: /member/* serviceId: m2-m ...

  10. python监控模块

    pip install psutil 获取内存信息: >>> import psutil >>> mem = psutil.virtual_memory() #获取 ...