手动设计神经网络进行MNIST分类
前言:
用手工设计的两层神经网络,经过200个epoch,最后得到0.9599,约0.96的精度
正文
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data\MNIST_data",one_hot=True) #每个批次的大小
batch_size = 32
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
y_ = tf.cast(tf.argmax(y,axis=1),tf.int32)
#创建一个简单的神经网络
W_1 = tf.Variable(tf.random_normal([784,120],dtype=tf.float32))
b_1 = tf.Variable(tf.zeros([120]))
h_1 = tf.nn.relu(tf.matmul(x,W_1)+b_1) W_2 = tf.Variable(tf.random_normal([120,10],dtype=tf.float32))
b_2 = tf.Variable(tf.zeros([10]))
prediction = tf.matmul(h_1,W_2)+b_2
prediction_ = tf.nn.softmax(tf.matmul(h_1,W_2)+b_2) # #二次代价函数
#loss = tf.reduce_mean(tf.square(y-prediction))
#交叉熵损失函数
#loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_, logits=prediction_)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=prediction)
#loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(prediction),reduction_indices=[1]))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction_,1)) #返回最大值所在位置,1表示行的维度
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(200):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images, y:mnist.test.labels})
print("Iter"+str(epoch)+',Testing Accuracy'+str(acc))
其中要注意的地方应该有:
loss函数的计算,用了tf.losses.sparse_softmax_cross_entropy这个交叉熵损失函数,其中:
labels_的输入是样本是真实标签,类似于[1,2,3,4,5,1,1,2....]这种,
所以,MNIST的样本标签是one-hot形式的,要先用tf.argmax转换成上述形式;
logits的输入类似于[1.22,4.23,2.45,...]这种,由于该函数会先进行logits-->softmax的计算,所以不用先把logits转换为softmax形式;
手动设计神经网络进行MNIST分类的更多相关文章
- Pytorch搭建卷积神经网络用于MNIST分类
import torch from torch.utils.data import DataLoader from torchvision import datasets from torchvisi ...
- matlab练习程序(神经网络识别mnist手写数据集)
记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...
- deep_learning_LSTM长短期记忆神经网络处理Mnist数据集
1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...
- TensorFlow——LSTM长短期记忆神经网络处理Mnist数据集
1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...
- L26 使用卷积及循环神经网络进行文本分类
文本情感分类 文本分类是自然语言处理的一个常见任务,它把一段不定长的文本序列变换为文本的类别.本节关注它的一个子问题:使用文本情感分类来分析文本作者的情绪.这个问题也叫情感分析,并有着广泛的应用. 同 ...
- 使用pytorch快速搭建神经网络实现二分类任务(包含示例)
使用pytorch快速搭建神经网络实现二分类任务(包含示例) Introduce 上一篇学习笔记介绍了不使用pytorch包装好的神经网络框架实现logistic回归模型,并且根据autograd实现 ...
- CVPR2022 | A ConvNet for the 2020s & 如何设计神经网络总结
前言 本文深入探讨了如何设计神经网络.如何使得训练神经网络具有更加优异的效果,以及思考网络设计的物理意义. 欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结.最新技术跟踪.经典论文解读.CV招聘 ...
- 芝麻HTTP:TensorFlow LSTM MNIST分类
本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间. 初始化 首先我们可以先初始化一些变量,如学习率.节点单元数 ...
- 用标准3层神经网络实现MNIST识别
一.MINIST数据集下载 1.https://pjreddie.com/projects/mnist-in-csv/ 此网站提供了mnist_train.csv和mnist_test.cs ...
随机推荐
- Permute Digits
You are given two positive integer numbers a and b. Permute (change order) of the digits of a to con ...
- MVC-MVP-MVVM框架模式分析
MVC(Model-View-Controller) MVC 架构模式图(经典版) 注:实际上,Model和View永远不能相互通信,只能通过Controller传递:上图只是MVC模式的经典图. M ...
- LB_GETCURSEL和LB_GETTEXT的使用
case IDC_LIST1: { switch (HIWORD(wParam)) { case LBN_DBLCLK: { HWND hwndList = GetDlgItem(hDlg, IDC_ ...
- DevExpress WPF v19.1新版亮点:Gantt/Map控件新功能
行业领先的.NET界面控件DevExpress 日前正式发布v19.1版本,本站将以连载的形式介绍各版本新增内容.在本系列文章中将为大家介绍DevExpress WPFv19.1中新增的一些控件及部分 ...
- Perf Event :Linux下的系统性能调优工具
Perf Event :Linux下的系统性能调优工具 2011-05-27 10:35 刘 明 IBMDW 字号:T | T Perf Event 是一款随 Linux 内核代码一同发布和维护的性能 ...
- MySQL 创建唯一索引忽略对已经重复数据的检查
MySQL 创建唯一索引忽略对已经重复数据的检查 在创建唯一索引的基础上加上关键字"IGNORE "即可.(注意,经测试,在5.7版本已经不再支持该参数) # 重复数据 mysql ...
- Java面试之基础篇(5)
41.a.hashCode() 有什么用?与 a.equals(b) 有什么关系? hashCode() 方法对应对象整型的 hash 值.它常用于基于 hash 的集合类,如 Hash ...
- kvm:双网卡做bond+桥接
一,KVM基础 kvm是一种技术,云计算是一种模式,虚拟化是利用相应的技术方法在一台物理机器上将其按照不同的需求划分成多个相同或者不同的虚拟操作系统,并且各个虚拟系统可以同时运行,互不干扰,其中任何一 ...
- javaweb上传大文件的问题
总结一下大文件分片上传和断点续传的问题.因为文件过大(比如1G以上),必须要考虑上传过程网络中断的情况.http的网络请求中本身就已经具备了分片上传功能,当传输的文件比较大时,http协议自动会将文件 ...
- BZOJ1491 Red is good
题目链接:Click here Solution: 考虑设\(f(i,j)\)表示当前还有\(i\)张红牌,\(j\)张黑牌时的期望收益 易得状态转移方程:\(f(i,j)=\frac{i}{i+j} ...