使用CNN实现MNIST数据集分类
1 MNIST数据集和CNN网络配置
关于MNIST数据集的说明及配置见使用TensorFlow实现MNIST数据集分类
CNN网络参数配置如下:
- 原始数据:输入为[28,28],输出为[1,10]
- 卷积核1:[5,5],32个特征 -->282832
- 池化核1:[2,2],最大池化 -->141432
- 卷积核2:[5,5],2个特征 -->141464
- 池化核2:[2,2],最大池化 -->7764
- 全连接1:[7764,1024]
- 全连接2:[1024,10]
2 实验
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
#每批次的大小
batch_size = 100
#总批次数
batch_num = mnist.train.num_examples//batch_size
#初始化权值函数
def weight_variable(shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#初始化偏置值函数
def bias_vairable(shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#卷积层函数
def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
#池化层函数
def max_pool(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
#定义三个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(x,[-1,28,28,1])
#5*5的卷积核,1个平面->32个平面(每个平面抽取32个特征)
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_vairable([32])
#第一次卷积之后变为 28*28*32
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
#第一次池化之后变为 14*14*32
h_pool1 = max_pool(h_conv1)
#5*5的卷积核,32个平面->64个平面(每个平面抽取2个特征)
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_vairable([64])
#第二次卷积之后变为 14*14*64
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2) + b_conv2)
#第二次池化之后变为 7*7*64
h_pool2 = max_pool(h_conv2)
#7*7*64的图像变成1维向量
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#第一个全连接层
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_vairable([1024])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
#第二个全连接层
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_vairable([10])
h_fc2 = tf.matmul(h_fc1_drop,w_fc2) + b_fc2
#prediction = tf.nn.sigmoid(h_fc2)
#交叉熵损失函数
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1), logits=h_fc2))
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=h_fc2))
train = tf.train.AdamOptimizer(0.001).minimize(loss)
correct_prediction = (tf.equal(tf.argmax(h_fc2,1), tf.argmax(y,1)))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#初始化变量
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
test_feed={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0}
for epoch in range(6):
for batch in range(batch_num):
x_,y_=mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:x_,y:y_,keep_prob:0.7})
acc=sess.run(accuracy,feed_dict=test_feed)
print("epoch:",epoch,"accuracy:",acc)

声明:本文转自使用CNN实现MNIST数据集分类
使用CNN实现MNIST数据集分类的更多相关文章
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 卷积神经网络CNN识别MNIST数据集
这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容. 程序的开头是导入TensorFlow: impor ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- python,tensorflow,CNN实现mnist数据集的训练与验证正确率
1.工程目录 2.导入data和input_data.py 链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 提取码:4nnl 3.CNN.py i ...
- MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...
- 6.MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...
随机推荐
- HttpClient获取不到最新的系统代理
默认情况下,HttpClient是默认采用系统代理,但是,如果你在程序运行过程中,手动修改系统代理,对于HttpClient是无效的,它依然会用老的代理去访问. 解决方法 使用下面的代码,你可以自己实 ...
- 百度网盘(百度云)SVIP超级会员共享账号每日更新(2023.12.20)
一.百度网盘SVIP超级会员共享账号 可能很多人不懂这个共享账号是什么意思,小编在这里给大家做一下解答. 我们多知道百度网盘很大的用处就是类似U盘,不同的人把文件上传到百度网盘,别人可以直接下载,避免 ...
- Docker导出镜像的总结
Docker导出镜像的总结 安装Docker mkdir -p /etc/docker cat >/etc/docker/daemon.josn <<EOF { "bip& ...
- [转帖]三篇文章了解 TiDB 技术内幕 - 谈调度
返回全部 申砾产品技术解读2017-06-06 为什么要进行调度 先回忆一下 三篇文章了解 TiDB 技术内幕 - 说存储提到的一些信息,TiKV 集群是 TiDB 数据库的分布式 KV 存储引擎,数 ...
- [转帖]Jmeter之界面语言设置
https://developer.aliyun.com/article/1173114#:~:text=%E6%B0%B8%E4%B9%85%E6%80%A7%E8%AE%BE%E7%BD%AE%E ...
- [转帖]整理常用的 vim 命令
vim 是一款功能强大的文本编辑器,它是Linux下常用的编辑器之一,对于熟练掌握了 vim 的人来说,用它编辑文件,方便又快捷,能极大的提高工作效率 vim 功能强大,对应的命令也非常的多,对于初学 ...
- [转帖]/etc/profile和/etc/environment的区别
时间 2019-11-07 标签 profile environment 区别 繁體版 原文 https://my.oschina.net/u/2885925/blog/2989579 /etc ...
- [转帖]Oracle 通过 Exadata 云基础设施 X9M 提供卓越的数据库性能和规模
https://www.modb.pro/db/397202 32个节点的RAC 服务器 每个服务器 两个 64核心的AMD CPU 四个线程干管理 252个线程进行数据库处理 252*32=8064 ...
- Skia 编译及踩坑实践
本文要点 •了解并入门 Skia.OpenGL 和 Vulkan •了解 Skia 在后端渲染上的坑点 前言 Skia 是什么 Skia 是一个开源 2D 图形库,提供可跨各种硬件和软件平台工作的通用 ...
- 详解Promise.race()可以解决多个异步请求那个请求先返回
Promise.race([]);接受一个参数,由promise组成的一个数组: 它的返回结果是promise对象: 它的结果和状态由什么去决定呢? 由第一个改变Promise状态的对象去决定:若是返 ...