使用TensorFlow实现MNIST数据集分类
1 MNIST数据集
MNIST数据集由70000张28x28像素的黑白图片组成,每一张图片都写有0~9中的一个数字,每个像素点的灰度值在0 ~ 255(0是黑色,255是白色)之间。

MINST数据集是由Yann LeCun教授提供的手写数字数据库文件,其官方下载地址THE MNIST DATABASE of handwritten digits

下载好MNIST数据集后,将其放在Spyder工作目录下(若使用Jupyter编程,则放在Jupyter工作目录下),如图:

G:\Anaconda\Spyder为笔者Spyder工作目录,MNIST_data为新建文件夹,读者也可以自行命名。
2 实验
为方便设计神经网络输入层,将每张28x28像素图片的像素值按行排成一行,故输入层设计28x28=784个神经元,隐藏层设计600个神经元,输出层设计10个神经元。使用read_data_sets()函数载入数据集,并返回一个类,这个类将MNIST数据集划分为train、validation、test 3个数据集,对应图片数分别为55000、5000、10000。本文采用交叉熵损失函数,并且为防止过拟合问题产生,引入正则化方法。
mnist.py
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
#每批次的大小
batch_size=100
#总批次数
batch_num=mnist.train.num_examples//batch_size
#训练轮数
training_step = tf.Variable(0,trainable=False)
#定义两个placeholder
x=tf.placeholder(tf.float32, [None,784])
y=tf.placeholder(tf.float32, [None,10])
#神经网络layer_1
w1=tf.Variable(tf.random_normal([784,600]))
b1=tf.Variable(tf.constant(0.1,shape=[600]))
z1=tf.matmul(x,w1)+b1
a1=tf.nn.tanh(z1)
#神经网络layer_2
w2=tf.Variable(tf.random_normal([600,10]))
b2=tf.Variable(tf.constant(0.1,shape=[10]))
z2=tf.matmul(a1,w2)+b2
#交叉熵代价函数
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1),logits=z2)
#cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=z2)
#L2正则化函数
regularizer=tf.contrib.layers.l2_regularizer(0.0001)
#总损失
loss=tf.reduce_mean(cross_entropy)+regularizer(w1)+regularizer(w2)
#学习率(指数衰减法)
laerning_rate = tf.train.exponential_decay(0.8,training_step,batch_num,0.999)
#梯度下降法优化器
train=tf.train.GradientDescentOptimizer(laerning_rate).minimize(loss,global_step=training_step)
#预测精度
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(z2,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#初始化变量
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    test_feed={x:mnist.test.images,y:mnist.test.labels}
    for epoch in range(51):
        for batch in range(batch_num):
            x_,y_=mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict={x:x_,y:y_})
        acc=sess.run(accuracy,feed_dict=test_feed)
        if epoch%10==0:
            print("epoch:",epoch,"accuracy:",acc)

迭代50次后,精度达到97.68%。
 声明:本文转自使用TensorFlow实现MNIST数据集分类
使用TensorFlow实现MNIST数据集分类的更多相关文章
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
		
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
 - 3.keras-简单实现Mnist数据集分类
		
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
 - 6.keras-基于CNN网络的Mnist数据集分类
		
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
 - 一个简单的TensorFlow可视化MNIST数据集识别程序
		
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...
 - 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归  1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
		
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
 - TensorFlow 训练MNIST数据集(2)—— 多层神经网络
		
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
 - 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案
		
最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...
 - TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络
		
1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...
 - 基于 tensorflow 的 mnist 数据集预测
		
1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...
 - TensorFlow 下 mnist 数据集的操作及可视化
		
from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...
 
随机推荐
- 【TouchGFX】使用v4.18.1版本创建预制电路板工程的正确方法
			
选择要使用的电路板 实现自己的程序 Designer运行仿真没问题并生成代码 我习惯使用IAR工具,发现直接编译有错误 上述错误是因为Designer默认生成的工具链是CubeIDE,所以需要使用Cu ...
 - [转帖]【性能】大页内存 (HugePages)在通用程序优化中的应用
			
目录 1. 背景 2. 基于指纹的音乐检索简介 3. 原理 4. 小页的困境 5. 大页内存的配置和使用 6. 大页内存的优化效果 7. 大页内存的使用场景 8. 总结 LD_PRELOAD用法 原文 ...
 - DellEMC 服务器安装ESXi的简单步骤
			
DellEMC 服务器安装ESXi的简单步骤 背景 ESXi的镜像其实分为多种. 官方会发布一个版本的ISO. 然后会不定期进行升级, 解决安全,性能以及功能bug等. 7.0 为例的话 就有ESXi ...
 - iftop的学习与使用
			
iftop的学习与使用 背景 前段时间一直进行netperf 等网络性能验证工具的学习与使用. 监控很多时候采用了 node-exporter + prometheus + grafana来进行观察 ...
 - 【转帖】mysql一个索引块有多少指针_深刻理解MySQL系列之索引
			
索引 查找一条数据的过程 先看下InnoDB的逻辑存储结构:node 表空间:能够看作是InnoDB存储引擎逻辑结构的最高层,全部的数据都存放在表空间中.默认有个共享表空间ibdata1.若是启用in ...
 - Oracle数据库无法启动的简单处理
			
1. 最近一台测试机器上面的Oracle数据库启动不起来了. 提示信息是UNDOTBS2的表空间找不到. 2. 然后可以使用 startup mount 简单开起来 但是发现还是无法使用. 3.本来想 ...
 - CentOS7升级Glibc到超过2.17版本无法启动的解决办法
			
CentOS7升级Glibc到超过2.17版本无法启动的解决办法 背景 今天有同事告知服务器宕机无法启动. 提示信息为: [sda] Assuming drive cache: write throu ...
 - React中css的module
			
处理css全局作用 现在有这样一个场景: A页面和B页面都有一个相同的类名 我们在A页面中有引入css. B页面没有css 在我们切换A和B页面的时候. A页面的css也作用在了B页面. 我们只希望A ...
 - 【代码分享】使用 avx512 + 查表法,优化凯撒加密
			
作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢! cnblogs博客 zhihu Github 公众号:一本正经的瞎扯 关于凯撒加密,具体请看:https://en.wikipe ...
 - linux服务器cup100%问题排查
			
一.出现问题在发现公司门禁服务无法开门的第一时间,去线上服务器上查看了一下进程的运行情况,具体运行如下: 第一次在查看的时候发现并没有我需要的服务entranceguard进程(图片是后续截图的) 二 ...