MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) learning_rate = 0.01 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 25
batch_size = 100 display_step = 1 save_path = 'model/' saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels})) save = saver.save(sess, save_path=save_path+'mnist.cpkt') print(" starting 2nd session ...... ") with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval) im = batch_xs[0]
im = im.reshape(-1, 28) plt.imshow(im, cmap='gray')
plt.show() im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()

TensorFlow——MNIST手写数据集的更多相关文章

  1. TensorFlow实战第五课(MNIST手写数据集识别)

    Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

  2. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  3. matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...

  4. TensorFlow系列专题(六):实战项目Mnist手写数据集识别

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

  5. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  6. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  7. MNIST手写数据集在运行中出现问题解决方案

    今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

随机推荐

  1. HDU 1850 Nim-Sum思想总结、

    算法介绍: Nim游戏是指两个对手在m个堆中轮流随意从某一个堆中拿出n个元素,假定两个对手都是足够聪明,直至最后一次取的人将所有元素取出,此人取得胜利.与之相反的是Misere游戏,相同的游戏规则,但 ...

  2. Springboot 2.x下多数据源配置

    本文同样适用于2.x版本下Mybatis的多数据源配置 项目中经常会遇到一个项目需要访问多个数据源的情况,多数情况下可以参考这个教程进行配置. 不过该教程适合springboot1.x版本,由于2.x ...

  3. Koa2 遇到Method Not Allowed 获取不到返回值

    https://q.cnblogs.com/q/114462/          都来找我  Haisen‘s blogs 求求各位大神了,2点多了没解决睡不着啊,我按照网上用的koa2-cors,g ...

  4. 高可用之keepalived的配置文件详解

    ! Configuration File for keepalived global_defs { notification_email { acassen@firewall.loc failover ...

  5. 查看laravel版本

    方法1: 使用php artisan --version ,只要能看懂这个命令的人一定已经具有初步的Laravel知识.再介绍一种不需要命令,直接去文件中去查看的方法. 方法2: 在项目文件中找ven ...

  6. js基础——对象和数组

    1.Object类型 1)使用new运算符    var box = new Object();===>等同于 var box = Object();(省略new关键字)    box.name ...

  7. H3C命令调试debugging--用户视图

    <H3C>terminal debugging     //使用debugging必须使用的命令--打开调试信 息的屏幕输出开关 <H3C>display debugging  ...

  8. 文本框(代替input)输入长度限制、提示

    <div class="inform_content_text"> <textarea name="name" placeholder=&qu ...

  9. CSS兼容性问题的解决方式(更新中···)

    1.清除浮动的兼容性(低版本的浏览器不兼容问题) .clearfix:after{ content:""; clear:both; display:block; visibilit ...

  10. centos 利用mailx发送邮件

    这里就已163或者126邮箱为例!阿里云的25号端口好像发送不了,用465端口可以发送成功! 安装:yum install -y mailx 然后就是修改配置文件 set ssl-verify=ign ...