1 MNIST数据集

MNIST数据集由70000张28x28像素的黑白图片组成,每一张图片都写有0~9中的一个数字,每个像素点的灰度值在0 ~ 255(0是黑色,255是白色)之间。



MINST数据集是由Yann LeCun教授提供的手写数字数据库文件,其官方下载地址THE MNIST DATABASE of handwritten digits



下载好MNIST数据集后,将其放在Spyder工作目录下(若使用Jupyter编程,则放在Jupyter工作目录下),如图:



G:\Anaconda\Spyder为笔者Spyder工作目录,MNIST_data为新建文件夹,读者也可以自行命名。

2 实验

为方便设计神经网络输入层,将每张28x28像素图片的像素值按行排成一行,故输入层设计28x28=784个神经元,隐藏层设计600个神经元,输出层设计10个神经元。使用read_data_sets()函数载入数据集,并返回一个类,这个类将MNIST数据集划分为train、validation、test 3个数据集,对应图片数分别为55000、5000、10000。本文采用交叉熵损失函数,并且为防止过拟合问题产生,引入正则化方法。

mnist.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True) #每批次的大小
batch_size=100
#总批次数
batch_num=mnist.train.num_examples//batch_size
#训练轮数
training_step = tf.Variable(0,trainable=False) #定义两个placeholder
x=tf.placeholder(tf.float32, [None,784])
y=tf.placeholder(tf.float32, [None,10]) #神经网络layer_1
w1=tf.Variable(tf.random_normal([784,600]))
b1=tf.Variable(tf.constant(0.1,shape=[600]))
z1=tf.matmul(x,w1)+b1
a1=tf.nn.tanh(z1) #神经网络layer_2
w2=tf.Variable(tf.random_normal([600,10]))
b2=tf.Variable(tf.constant(0.1,shape=[10]))
z2=tf.matmul(a1,w2)+b2 #交叉熵代价函数
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1),logits=z2)
#cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=z2)
#L2正则化函数
regularizer=tf.contrib.layers.l2_regularizer(0.0001)
#总损失
loss=tf.reduce_mean(cross_entropy)+regularizer(w1)+regularizer(w2)
#学习率(指数衰减法)
laerning_rate = tf.train.exponential_decay(0.8,training_step,batch_num,0.999)
#梯度下降法优化器
train=tf.train.GradientDescentOptimizer(laerning_rate).minimize(loss,global_step=training_step) #预测精度
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(z2,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #初始化变量
init=tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
test_feed={x:mnist.test.images,y:mnist.test.labels}
for epoch in range(51):
for batch in range(batch_num):
x_,y_=mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:x_,y:y_})
acc=sess.run(accuracy,feed_dict=test_feed)
if epoch%10==0:
print("epoch:",epoch,"accuracy:",acc)



迭代50次后,精度达到97.68%。

​ 声明:本文转自使用TensorFlow实现MNIST数据集分类

使用TensorFlow实现MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  3. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  4. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  5. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  6. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

  7. 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案

    最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...

  8. TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络

    1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...

  9. 基于 tensorflow 的 mnist 数据集预测

    1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...

  10. TensorFlow 下 mnist 数据集的操作及可视化

    from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...

随机推荐

  1. 【TouchGFX 】使用 CubeMX 创建 TouchGFX 工程时 LCD 死活不显示

    生成的代码死活无法让LCD显示,经两个晚上的分析验证是LTDC_CLK引脚速度设置为低速导致,经测试中速.高速.超高速都正常,真是冤,聊以此以示纪念

  2. Oracle process/session/cursor/tx/tm的简单学习

    Oracle process/session/cursor/tx/tm的简单学习 Oracle的部署模式 Oracle安装时有专用模式和共享模式的区别 共享模式(Shared mode): 在共享模式 ...

  3. [转帖]Jmeter脚本录制:Jmeter5.0脚本录制

    第一部分进行jmeter设置 第一步:在JMeter中添加线程组 第二步:在线程组下添加HTTP请求默认值 添加->配置元件->HTTP请求默认值,设置服务器IP和端口号 第三步:在线程组 ...

  4. [转帖]备份VCSA内置Postgresql数据库

    首先命令行远程登录到VCSA服务器,然后执行如下命令停掉VCSA的核心服务vmware-vpxd: vCenterServerAppliance:~ # service vmware-vpxd sto ...

  5. [转帖]INTEL MLC(Memory Latency Checker)介绍

    https://zhuanlan.zhihu.com/p/359823092 在定位机器性能问题的时候,有时会觉得机器莫名其妙地跑的慢,怎么也看不出来问题.CPU频率也正常,程序热点也没问题,可就是慢 ...

  6. [转帖]Redis如何绑定CPU

    文章系转载,便于分类和归纳,源文地址:https://www.yisu.com/zixun/672271.html 绑定 CPU Redis 6.0 开始支持绑定 CPU,可以有效减少线程上下文切换. ...

  7. [转帖]使用GCC编译器实测兆芯KX-U6780A的SPEC CPU2006成绩

      https://baijiahao.baidu.com/s?id=1722775453962904303 兆芯KX-U6780A是一款8核2.7GHz的使用x86/AMD64指令集(架构)的国产C ...

  8. ELK运维文档

    Logstash 目录 Logstash Monitoring API Node Info API Plugins Info API Node Stats API Hot Threads API lo ...

  9. ClickHouse(08)ClickHouse表引擎概况

    目录 合并树家族 日志引擎系列 集成的表引擎 其他特殊的引擎 资料分享 参考文章 目前ClickHouse的表引擎主要有下面四个系列,合并树家族.日志引擎系列.集成的表引擎和其他特殊的引擎. 合并树家 ...

  10. 13.4 DirectX内部劫持绘制

    相对于外部绘图技术的不稳定性,内部绘制则显得更加流程与稳定,在Dx9环境中,函数EndScene是在绘制3D场景后,用于完成将最终的图像渲染到屏幕的一系列操作的函数.它会将缓冲区中的图像清空,设置视口 ...