1 MNIST数据集和CNN网络配置

关于MNIST数据集的说明及配置见使用TensorFlow实现MNIST数据集分类

CNN网络参数配置如下:

  • 原始数据:输入为[28,28],输出为[1,10]
  • 卷积核1:[5,5],32个特征 -->282832
  • 池化核1:[2,2],最大池化 -->141432
  • 卷积核2:[5,5],2个特征 -->141464
  • 池化核2:[2,2],最大池化 -->7764
  • 全连接1:[7764,1024]
  • 全连接2:[1024,10]

2 实验

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True) #每批次的大小
batch_size = 100
#总批次数
batch_num = mnist.train.num_examples//batch_size #初始化权值函数
def weight_variable(shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial) #初始化偏置值函数
def bias_vairable(shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial) #卷积层函数
def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME') #池化层函数
def max_pool(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') #定义三个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32) x_image = tf.reshape(x,[-1,28,28,1]) #5*5的卷积核,1个平面->32个平面(每个平面抽取32个特征)
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_vairable([32])
#第一次卷积之后变为 28*28*32
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
#第一次池化之后变为 14*14*32
h_pool1 = max_pool(h_conv1) #5*5的卷积核,32个平面->64个平面(每个平面抽取2个特征)
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_vairable([64])
#第二次卷积之后变为 14*14*64
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2) + b_conv2)
#第二次池化之后变为 7*7*64
h_pool2 = max_pool(h_conv2)
#7*7*64的图像变成1维向量
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) #第一个全连接层
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_vairable([1024])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #第二个全连接层
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_vairable([10])
h_fc2 = tf.matmul(h_fc1_drop,w_fc2) + b_fc2
#prediction = tf.nn.sigmoid(h_fc2) #交叉熵损失函数
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1), logits=h_fc2))
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=h_fc2))
train = tf.train.AdamOptimizer(0.001).minimize(loss) correct_prediction = (tf.equal(tf.argmax(h_fc2,1), tf.argmax(y,1)))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #初始化变量
init=tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
test_feed={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0}
for epoch in range(6):
for batch in range(batch_num):
x_,y_=mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:x_,y:y_,keep_prob:0.7})
acc=sess.run(accuracy,feed_dict=test_feed)
print("epoch:",epoch,"accuracy:",acc)

​ 声明:本文转自使用CNN实现MNIST数据集分类

使用CNN实现MNIST数据集分类的更多相关文章

  1. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  2. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  3. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  4. 卷积神经网络CNN识别MNIST数据集

    这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容. 程序的开头是导入TensorFlow: impor ...

  5. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  6. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  7. 神经网络MNIST数据集分类tensorboard

    今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...

  8. python,tensorflow,CNN实现mnist数据集的训练与验证正确率

    1.工程目录 2.导入data和input_data.py 链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 提取码:4nnl 3.CNN.py i ...

  9. MNIST数据集分类简单版本

      import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...

  10. 6.MNIST数据集分类简单版本

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...

随机推荐

  1. -- spi flash 擦除接口调用HAL库不同函数的区别

    [描述] 在使用STM32F429操作W25Q128时,为验证flash工作正常,做简单的读写数据校验,在擦除接口中使用 HAL_SPI_Transmit 方法一直工作异常,使用 HAL_SPI_Tr ...

  2. python环境 anaconda安装

    官网: https://www.anaconda.com/distribution/#macos 国内镜像: https://mirrors.tuna.tsinghua.edu.cn/anaconda ...

  3. .net core 3.0 获取 IServiceProvider 实例

    .net core 3.0后,获取IServiceProvider需要绕点弯路 首先,新建一个类: public class MyServiceProviderFactory : IServicePr ...

  4. GoMusic-歌单迁移工具网站 一键迁移网易云/QQ音乐歌单至 Apple/Youtube/Spotify Music

    GoMusic是什么: GoMusic是一个在线歌单迁移工具网站,帮助用户一键迁移网易云/QQ音乐歌单至 Apple/Youtube/Spotify Music,直接输入歌单链接,复制查询结果,打开 ...

  5. [转帖]java -D参数设置系统属性无效问题及解决

    https://www.jb51.net/article/271236.htm   这篇文章主要介绍了java -D参数设置系统属性无效问题及解决方案,具有很好的参考价值,希望对大家有所帮助.如有错误 ...

  6. [转帖]Java和Scala的前世今生

    第一部分:Java 计算机语言介绍 第一代语言:机器语言.指令以二进制代码形式存在 第二代语言:汇编语言.使用助记符表示一条机器指令 第三代语言:高级语言 C.Pascal.Fortran面向过程的语 ...

  7. Nginx拆分配置文件的办法

    Nginx拆分配置文件的办法 摘要 最近公司使用Nginx进行微服务的路由处理 但是发现随着业务发展, 配置文件越来越复杂. 修改起来也很容易出现错误. 基于此. 想通过拆分配置文件的方式来提高修改效 ...

  8. SQLSERVER 通过分离附加的方式迁移文件存储的位置.

    有时候 SQLSERVER的数据库所在磁盘满了,或者是性能变的很差, 需要转换一下磁盘的位置. 这个时候最简单的办法是通过分离附加的方式进行处理. 0. 准备工作 0.1 备份-备份-备份 没有备份 ...

  9. postman中monitor的使用

    monitor就是一个摸鱼的功能,我们把写好的接口部署到postman的web服务器中, 绑定自己的邮箱,运行结果会发送到自己的邮箱中,不用实时监控,是个非常方便 的功能(不安全) 1.crete a ...

  10. React数据通信父传子和子传父的使用

    组件中的props 在react中,props的特点是: 1.可以给组件传递任意类型的数据 2.props是只读的对象,只能够读取属性的值,无法修改对象 如过我们强行修改数据,会报错,告诉我们该属性是 ...