TensorFlow——MNIST手写数据集
MNIST数据集介绍
MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。
MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。
在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。
标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。
下面定义了一个简单的网络对数据集进行训练,代码如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) learning_rate = 0.01 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 25
batch_size = 100 display_step = 1 save_path = 'model/' saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels})) save = saver.save(sess, save_path=save_path+'mnist.cpkt') print(" starting 2nd session ...... ") with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval) im = batch_xs[0]
im = im.reshape(-1, 28) plt.imshow(im, cmap='gray')
plt.show() im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()
TensorFlow——MNIST手写数据集的更多相关文章
- TensorFlow实战第五课(MNIST手写数据集识别)
Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- matlab练习程序(神经网络识别mnist手写数据集)
记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...
- TensorFlow系列专题(六):实战项目Mnist手写数据集识别
欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...
- TensorFlow——MNIST手写数字识别
MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...
- 用Kersa搭建神经网络【MNIST手写数据集】
MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...
- MNIST手写数据集在运行中出现问题解决方案
今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...
- 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...
- 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
随机推荐
- Python--day63--图书管理系统表结构设计
pycharm中运行manage.py的工具: 1,将变更翻译成SQL语句,去数据库中执行(ORM操作过数据库后都要执行这两条命令) 2,ForeignKey会自动在属性后面加_id 3,ORM封装p ...
- Laravel 5.* 执行seeder命令出现错误的解决方法
最近在使用Laravel开发一个项目,测试中需要增加数据库基础数据动作,当第一次执行完`php artisan db:seed` 后,增加新的seeder文件时执行会报错.错误信息如下`[Reflec ...
- React MVC框架 <某某后台商品管理开源项目> 完成项目总结
**百货后台商品信息开源项目 1.利用React app脚手架 2.封装打包 buid 3.更偏向于后台程序员开发思维 4.利用的 react -redux react-router-dom ...
- java 面试题之交通灯管理系统
需求: 交通灯管理系统的项目需求 Ø 异步随机生成按照各个路线行驶的车辆. 例如: 由南向而来去往北向的车辆 ---- 直行车辆 由西向而来去往南向的车辆 ---- 右转车辆 由东向而来去往南向的车辆 ...
- egg-socket在egg中的使用
WebSocket 的产生源于 Web 开发中日益增长的实时通信需求,对比基于 http 的轮询方式,它大大节省了网络带宽,同时也降低了服务器的性能消耗: socket.io 支持 websocket ...
- tikz 常用命令总结
使用斜线填充区域,并绘制边界 \fill[pattern color=red, pattern=north west lines, opacity=0.4] (0,0) -- (0,1) -- (1, ...
- 字符串和While循环
字符串是以单引号或双引号括起来的任意文本 创建字符串 str1 = "shaoge is a good man!" 字符串运算 字符串连接 str6 = "shaoge ...
- [vue/no-parsing-error] Parsing error: x-invalid-end-tag.eslint-plugin-vue
[vue/no-parsing-error] Parsing error: x-invalid-end-tag.eslint-plugin-vue 解决方案:vscode里面选择设置->搜索 ...
- 0003 HTML常用标签(含base、锚点)、路径
学习目标 理解: 相对路径三种形式 应用 排版标签 文本格式化标签 图像标签 链接 相对路径,绝对路径的使用 1. HTML常用标签 首先 HTML和CSS是两种完全不同的语言,我们学的是结构,就只写 ...
- ajax异步发送时遇到的问题
问题原因是:controller中方法名与url中的名字不一样造成的 解决办法:找到错误的方法名,将其与url中的方法名统一: