1、MNIST数据集简介

  首先通过下面两行代码获取到TensorFlow内置的MNIST数据集:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/mnist', one_hot=True)

  MNIST数据集共有55000(mnist.train.num_examples)张用于训练的数据,对应的有55000个标签;共有10000(mnist.test.num_examples)张用于测试的图片的数据,同样的有10000个标签与之对应。为了方便访问,这些图片或标签的数据都是被格式化了的。

  MNIST数据集的训练数据集(mnist.train.images)是一个 55000 * 784 的矩阵,矩阵的每一行代表一张图片(28 * 28 * 1)的数据,图片的数据范围是 [0, 1],代表像素点灰度归一化后的值。

  训练集的标签(mnist.train.labels)是一个55000 * 10 的矩阵,每一行的10个数字分别代表对应的图片属于数字0到9的概率,范围是0或1。一个标签行只有一个是1,表示该图片的正确数字是对应的下标值, 其余是0。

  测试集与训练集的类似,只是数据量不同。

  以下代码显示部分MNIST训练图片的形状及标签:

import numpy as np
import matplotlib.pyplot as plot
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
trainImages = mnist.train.images
trainLabels = mnist.train.labels plot.figure(1, figsize=(4, 3))
for i in range(6):
curImage = np.reshape(trainImages[i, :], (28, 28))
curLabel = np.argmax(trainLabels[i, :])
ax = plot.subplot(int(str(23) + str(i+1)))
plot.imshow(curImage, cmap=plot.get_cmap('gray'))
plot.axis('off')
ax.set_title(curLabel) plot.suptitle('MNIST')
plot.show()

  上述代码输出的MNIST图片及其标签:

 2、通过单层神经网络进行训练

 def train(trainCycle=50000, debug=False):
inputSize = 784
outputSize = 10
batchSize = 64
inputs = tf.placeholder(tf.float32, shape=[None, inputSize]) # x * w = [64, 784] * [784, 10]
weights = tf.Variable(tf.random_normal([784, 10], 0, 0.1))
bias = tf.Variable(tf.random_normal([outputSize], 0, 0.1))
outputs = tf.add(tf.matmul(inputs, weights), bias)
outputs = tf.nn.softmax(outputs) labels = tf.placeholder(tf.float32, shape=[None, outputSize]) loss = tf.reduce_mean(tf.square(outputs - labels))
optimizer = tf.train.GradientDescentOptimizer(0.1)
trainer = optimizer.minimize(loss) sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(trainCycle):
batch = mnist.train.next_batch(batchSize)
sess.run([trainer, loss], feed_dict={inputs: batch[0], labels: batch[1]}) if debug and i % 1000 == 0:
corrected = tf.equal(tf.argmax(labels, 1), tf.argmax(outputs, 1))
accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
accuracyValue = sess.run(accuracy, feed_dict={inputs: batch[0], labels: batch[1]})
print(i, ' train set accuracy:', accuracyValue) # 测试
corrected = tf.equal(tf.argmax(labels, 1), tf.argmax(outputs, 1))
accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
accuracyValue = sess.run(accuracy, feed_dict={inputs: mnist.test.images, labels: mnist.test.labels})
print("accuracy on test set:", accuracyValue) sess.close()

3、训练结果

  上述模型的最终输出为:

由打印日志可以看出,前期收敛速度很快,后期开始波动。最后该模型在训练集上的正确率大概为90%,测试集上也差不多。精度还是比较低的,说明单层的神经网络在处理图片数据上存在着很大的缺陷,并不是一个很好的选择。

本文地址:https://www.cnblogs.com/laishenghao/p/9576806.html

TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络的更多相关文章

  1. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

  2. TensorFlow训练MNIST数据集(3) —— 卷积神经网络

    前面两篇随笔实现的单层神经网络 和多层神经网络, 在MNIST测试集上的正确率分别约为90%和96%.在换用多层神经网络后,正确率已有很大的提升.这次将采用卷积神经网络继续进行测试. 1.模型基本结构 ...

  3. 使用tensorflow实现mnist手写识别(单层神经网络实现)

    import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import n ...

  4. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  5. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  6. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  7. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  8. TensorFlow训练MNIST报错ResourceExhaustedError

    title: TensorFlow训练MNIST报错ResourceExhaustedError date: 2018-04-01 12:35:44 categories: deep learning ...

  9. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

随机推荐

  1. 移除jboss响应中的中间件信息

    JBoss 4.2 Suppressing the X-Powered-By header in JBoss 4.2.x can be done by modifying the web.xml fi ...

  2. 第五次作业 hql查询

    hql查询是基于对象的查询,不是基于表的查询. 1.hql的简单查询 @Test public void queryUsers() { //简单查询 SessionFactory sf = null; ...

  3. 中国将有可能在全球化的背景下收获新的人口红利:3星|《<财经>2019:预测与战略》

    <财经>2019 :预测与战略 <财经>杂志的年刊.内容是针对2019年的预测分析.我认为<财经>的调查报告比较有深度,分析则不是我爱看的类型. 总体评价3星,有参 ...

  4. 抓取js动态生成的数据分析案例

    需求:爬取https://www.xuexi.cn/f997e76a890b0e5a053c57b19f468436/018d244441062d8916dd472a4c6a0a0b.html页面中的 ...

  5. 命令行翻译 推荐一个linux系统中可用的终端小程序

    程序的github地址:https://github.com/fanbrightup/fanyi 使用起来非常简单,同时支持中英文互译甚至是整句. 步骤一:首先你需要安装node,参见我的node安装 ...

  6. 【洛谷】【单调栈】P4333 [COI2007] Patrik

    --接上一篇题解,[洛谷][单调栈]P1823音乐会的等待 关于题目大意在上一篇题解里已经说清楚了,这里不再多阐述 想看题目->戳这里 [算法分析:] 在对元素a进行判断时,如果它与栈顶元素相等 ...

  7. ap、map值计算

    ap:所有图片某一个类 map:所有图片所有类的ap的平均 以一个score为阈值,大于score的所有框是假定正确输出的所有预测,将这些框和gt匹配(iou大于某一个阈值认为匹配成功),得到当前sc ...

  8. MP实战系列(十一)之封装方法详解(续一)

    之前写的封装方法详解,比较简要. 今天我主要讲增加和删除及其修改.查的话得单独再详讲. 增删改查,无论是Java或者C#等等,凡是对数据库操作的都离不开这四个. 一.增加方法讲解 MyBatis Pl ...

  9. shiro实战系列(九)之Web

    一.Configuration(配置) 将 Shiro 集成到任何 Web 应用程序的最简单的方法是在 web.xml 中配置 ContextListener 和 Filter,理解如何读取 Shir ...

  10. redis集群搭建及连接(阿里云)

    阿里云上面装redis集群基本被虐死,主要问题就是私有IP和公有IP. 下面分享成功搭建的步骤: 两台测试服务器,分别为:127.0.0.1,127.0.0.2.每分服务器有3个节点. 1.127.0 ...