Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元
代码如下

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32, [None,784]) #输入图像
y = tf.placeholder(tf.float32, [None,10]) #输入标签 #创建一个简单的神经网络 784个像素点对应784个数 因此输入层是784个神经元 输出层是10个神经元 不含隐层
#最后准确率在92%左右
W = tf.Variable(tf.zeros([784,10])) #生成784行 10列的全0矩阵
b = tf.Variable(tf.zeros([1,10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在布尔型列表中
#argmax能给出某个tensor对象在某一维上的其数据最大值所在的索引值
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
for epoch in range(21): #21个epoch 把所有的图片训练21次
for batch in range(n_batch): #
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
acc = sess.run(accuracy,feed_dict={x:mnist.test.images, y:mnist.test.labels})
print ("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))

结果如下

Iter 0,Testing Accuracy 0.8304
Iter 1,Testing Accuracy 0.8704
Iter 2,Testing Accuracy 0.8821
Iter 3,Testing Accuracy 0.8876
Iter 4,Testing Accuracy 0.8932
Iter 5,Testing Accuracy 0.8968
Iter 6,Testing Accuracy 0.8995
Iter 7,Testing Accuracy 0.9019
Iter 8,Testing Accuracy 0.9033
Iter 9,Testing Accuracy 0.9048
Iter 10,Testing Accuracy 0.9065
Iter 11,Testing Accuracy 0.9074
Iter 12,Testing Accuracy 0.9084
Iter 13,Testing Accuracy 0.909
Iter 14,Testing Accuracy 0.9094
Iter 15,Testing Accuracy 0.9112
Iter 16,Testing Accuracy 0.9117
Iter 17,Testing Accuracy 0.9128
Iter 18,Testing Accuracy 0.9127
Iter 19,Testing Accuracy 0.9132
Iter 20,Testing Accuracy 0.9144

接下来是含一个隐层的神经网络,输入层是784个神经元,两个隐层都是100个神经元,输出层是10个神经元,迭代500次,最后准确率在88%左右,汗。。。。准确率反而降低了,慢慢调参吧

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
#每个批次的大小
batch_size = 50
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32, [None,784]) #输入图像
y = tf.placeholder(tf.float32, [None,10]) #输入标签 #定义神经网络中间层
Weights_L1 = tf.Variable(tf.random_normal([784,100]))
biase_L1 = tf.Variable(tf.zeros([1,100]))
Wx_plus_b_L1 = tf.matmul(x, Weights_L1)+biase_L1
L1 = tf.nn.tanh(Wx_plus_b_L1) #使用正切函数作为激活函数 Weights_L2 = tf.Variable(tf.random_normal([100,100]))
biase_L2 = tf.Variable(tf.zeros([1,100]))
Wx_plus_b_L2 = tf.matmul(L1, Weights_L2)+biase_L2
L2 = tf.nn.tanh(Wx_plus_b_L2) #使用正切函数作为激活函数 #定义神经网络输出层
Weights_L3 = tf.Variable(tf.random_normal([100,10]))
biase_L3 = tf.Variable(tf.zeros([1,10]))
Wx_plus_b_L3 = tf.matmul(L2,Weights_L3) + biase_L3
prediction = tf.nn.tanh(Wx_plus_b_L3) #二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在布尔型列表中
#argmax能给出某个tensor对象在某一维上的其数据最大值所在的索引值
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
sess.run(init)
for epoch in range(500):
for batch in range(n_batch): batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
acc = sess.run(accuracy,feed_dict={x:mnist.test.images, y:mnist.test.labels})
print ("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))


Iter 487,Testing Accuracy 0.8847
Iter 488,Testing Accuracy 0.8853
Iter 489,Testing Accuracy 0.878
Iter 490,Testing Accuracy 0.8861
Iter 491,Testing Accuracy 0.8863
Iter 492,Testing Accuracy 0.8784
Iter 493,Testing Accuracy 0.8855
Iter 494,Testing Accuracy 0.8787
Iter 495,Testing Accuracy 0.881
Iter 496,Testing Accuracy 0.8837
Iter 497,Testing Accuracy 0.8817
Iter 498,Testing Accuracy 0.8837
Iter 499,Testing Accuracy 0.8866
Tensorflow学习教程------普通神经网络对mnist数据集分类的更多相关文章
- TensorFlow——LSTM长短期记忆神经网络处理Mnist数据集
1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)
from __future__ import print_function from tensorflow.examples.tutorials.mnist import input_data #加载 ...
- Tensorflow学习教程------实现lenet并且进行二分类
#coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...
- TensorFlow——CNN卷积神经网络处理Mnist数据集
CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...
- Tensorflow学习教程------过拟合
Tensorflow学习教程------过拟合 回归:过拟合情况 / 分类过拟合 防止过拟合的方法有三种: 1 增加数据集 2 添加正则项 3 Dropout,意思就是训练的时候隐层神经元每次随机 ...
- deep_learning_LSTM长短期记忆神经网络处理Mnist数据集
1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...
- Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例
紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...
随机推荐
- pytorch max和clamp
torch.max() torch.max(a):数组a的最大值 torch.max(a, dim=1):多维数组沿维度1方向上的最大值,若a为二维数组,则为每行的最大值(此时是对每行的每列值比较取最 ...
- oozie的常见错误
1.变量或路径的英文字母写错,常常是大小写搞混,或者是字母顺序颠倒. 2.本地 oozie_works 工作目录下的文件,如job.properties,workflow.xml等,修改后,忘记上传到 ...
- 微信小程序循环中点击一个元素,其他的元素不发生变化,类似点击一个循环中的语音,其他的不发生点击事件
类似语音,因为都在一个数据内,所以点击第一个,所有的语音都变化,解决方法就是 把整个数据都获取下来,然后更改其中一个需要更改的值,然后再把整个数据都setdata回去,如果需要动画的话,wxml里面放 ...
- ubuntu18.04下载yarn
下载curl sudo apt-get update && sudo apt-get install curl 配置库 curl -sS https://dl.yarnpkg.com/ ...
- MongoDB_02简介
MongoDB简介 MongoDB是一个开源,高性能,无模式的文档型数据库. 它支持的数据结构非常松散,是一种类似于JSON的格式叫BSON,所以他既可以存储比较复杂的数据类型,又相当的灵活. Mon ...
- Linux学习《第四章脚本》20200222
- Golang函数-函数的基本概念
Golang函数-函数的基本概念 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.函数的概述 1>.函数定义语法格式 Go语言函数定义格式如下: func 函数名( 函数参 ...
- 洛谷 P2426 删数
题目传送门 解题思路: 区间DP,f[i][j]表示区间i~j可获得的最大值,因为本题的所有区间是可以直接一次性把自己全删掉的,所以所有区间初始化为被一次性删除的值,然后枚举断点,跑区间DP. AC代 ...
- C/C++贪心算法解决TSP问题
贪心算法解决旅行商问题 TSP问题(Traveling Salesman Problem,旅行商问题),由威廉哈密顿爵士和英国数学家克克曼T.P.Kirkman于19世纪初提出.问题描述如下: 有若干 ...
- 端口通不通 telnet wget ssh
如何测试端口通不通(四种方法) 投稿:mrr 一般情况下使用"telnet ip port"判断端口通不通.接下来通过本文给大家分享四种方法测试端口通不通,感兴趣的朋友一起学习吧 ...