TensorFlow——MNIST手写数据集
MNIST数据集介绍
MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。
MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。
在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。
标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。
下面定义了一个简单的网络对数据集进行训练,代码如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) learning_rate = 0.01 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 25
batch_size = 100 display_step = 1 save_path = 'model/' saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels})) save = saver.save(sess, save_path=save_path+'mnist.cpkt') print(" starting 2nd session ...... ") with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval) im = batch_xs[0]
im = im.reshape(-1, 28) plt.imshow(im, cmap='gray')
plt.show() im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()
TensorFlow——MNIST手写数据集的更多相关文章
- TensorFlow实战第五课(MNIST手写数据集识别)
Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- matlab练习程序(神经网络识别mnist手写数据集)
记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...
- TensorFlow系列专题(六):实战项目Mnist手写数据集识别
欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...
- TensorFlow——MNIST手写数字识别
MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...
- 用Kersa搭建神经网络【MNIST手写数据集】
MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...
- MNIST手写数据集在运行中出现问题解决方案
今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...
- 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...
- 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
随机推荐
- H3C 配置帧中继交换
- Python--day24--多继承
如果本生没有func方法的话就调用距离自己最近的基类的方法 钻石继承: 查找方法的顺序:如下例的找func方法(广度优先) 例1: 例2: 漏斗继承: 小乌龟继承问题:(最顶端的节点F是最后查找的) ...
- 连接远程mysql(Linux环境)
保证三点: 1.打开/etc/my.cnf,找到[mysqld]项,在其后加入一句:skip-name-resolve,保存,重启mysql服务. service mysqld restart 或者 ...
- 【GYM102091】2018-2019 ACM-ICPC, Asia Nakhon Pathom Regional Contest
A-Evolution Game 题目大意:有$n$个不同的野兽,定义第$i$ 个野兽有 $i$ 个眼睛和 $h[i]$ 个角,你可以任意从中选择一个野兽进行进化,每次进化角数量必须增加,而且进化后要 ...
- JUnit 单元测试断言推荐 AssertJ
文章转自:http://sgq0085.iteye.com/blog/2030609 前言 由于JUnit的Assert是公认的烂API,所以不推荐使用,目前推荐使用的是AssertJ. Assert ...
- Linux 内核 回顾: ISA
设计上 ISA 总线非常老了, 并且是非常地低能, 但是它仍然持有一块挺大的控制设备的 市场. 如果速度不重要并且你想支持老式主板, 一个 ISA 实现要优于 PCI. 这个老标准 的另外一个好处是如 ...
- dll中全局变量在外部进行引用
在Windows中实际导出全局变量,您必须使用类似于export / import语法的语法,例如: #ifdef COMPILING_THE_DLL #define MY_DLL_EXPORT ex ...
- HDU - 3671 Boonie and Clyde (图的割点)
As two icons of the Great Depression, Bonnie and Clyde represent the ultimate criminal couple. Stori ...
- 【53.57%】【codeforces 610C】Harmony Analysis
time limit per test3 seconds memory limit per test256 megabytes inputstandard input outputstandard o ...
- codeforces 1136E 线段树
codeforces 1136E: 题意:给你一个长度为n的序列a和长度为n-1的序列k,序列a在任何时候都满足如下性质,a[i+1]>=ai+ki,如果更新后a[i+1]<ai+ki了, ...