1 MNIST数据集和CNN网络配置

关于MNIST数据集的说明及配置见使用TensorFlow实现MNIST数据集分类

CNN网络参数配置如下:

  • 原始数据:输入为[28,28],输出为[1,10]
  • 卷积核1:[5,5],32个特征 -->282832
  • 池化核1:[2,2],最大池化 -->141432
  • 卷积核2:[5,5],2个特征 -->141464
  • 池化核2:[2,2],最大池化 -->7764
  • 全连接1:[7764,1024]
  • 全连接2:[1024,10]

2 实验

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True) #每批次的大小
batch_size = 100
#总批次数
batch_num = mnist.train.num_examples//batch_size #初始化权值函数
def weight_variable(shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial) #初始化偏置值函数
def bias_vairable(shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial) #卷积层函数
def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME') #池化层函数
def max_pool(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') #定义三个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32) x_image = tf.reshape(x,[-1,28,28,1]) #5*5的卷积核,1个平面->32个平面(每个平面抽取32个特征)
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_vairable([32])
#第一次卷积之后变为 28*28*32
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
#第一次池化之后变为 14*14*32
h_pool1 = max_pool(h_conv1) #5*5的卷积核,32个平面->64个平面(每个平面抽取2个特征)
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_vairable([64])
#第二次卷积之后变为 14*14*64
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2) + b_conv2)
#第二次池化之后变为 7*7*64
h_pool2 = max_pool(h_conv2)
#7*7*64的图像变成1维向量
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) #第一个全连接层
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_vairable([1024])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #第二个全连接层
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_vairable([10])
h_fc2 = tf.matmul(h_fc1_drop,w_fc2) + b_fc2
#prediction = tf.nn.sigmoid(h_fc2) #交叉熵损失函数
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1), logits=h_fc2))
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=h_fc2))
train = tf.train.AdamOptimizer(0.001).minimize(loss) correct_prediction = (tf.equal(tf.argmax(h_fc2,1), tf.argmax(y,1)))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #初始化变量
init=tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
test_feed={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0}
for epoch in range(6):
for batch in range(batch_num):
x_,y_=mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:x_,y:y_,keep_prob:0.7})
acc=sess.run(accuracy,feed_dict=test_feed)
print("epoch:",epoch,"accuracy:",acc)

​ 声明:本文转自使用CNN实现MNIST数据集分类

使用CNN实现MNIST数据集分类的更多相关文章

  1. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  2. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  3. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  4. 卷积神经网络CNN识别MNIST数据集

    这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容. 程序的开头是导入TensorFlow: impor ...

  5. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  6. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  7. 神经网络MNIST数据集分类tensorboard

    今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...

  8. python,tensorflow,CNN实现mnist数据集的训练与验证正确率

    1.工程目录 2.导入data和input_data.py 链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA 提取码:4nnl 3.CNN.py i ...

  9. MNIST数据集分类简单版本

      import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...

  10. 6.MNIST数据集分类简单版本

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...

随机推荐

  1. P5733 【深基6.例1】自动修正

    1.题目介绍 2. 题解 2.1 字符串大小写转换 思路 str[i] -= 'a' -'A'; 注意这里转换方式,即减去偏移量(ASCII码表中,'a'在'A'前面,如果记不得偏移量,就直接写'a' ...

  2. Redis之入门概括与指令

    Redis特点(AP模型,优先保证可用,不会管数据丢失): 快的原因: 基于内存操作,操作不需要跟磁盘交互 k-v结构,类似与hashMap,所以查询速度非常快,接近O(1). 底层数据结构是有如:跳 ...

  3. [转帖]深入JVM - Code Cache内存池

    深入JVM - Code Cache内存池 1. 本文内容 本文简要介绍JVM的 Code Cache(本地代码缓存池). 2. Code Cache 简要介绍 简单来说,JVM会将字节码编译为本地机 ...

  4. Oceanbase开源版 数据库恢复MySQL数据库的过程

    # Oceanbase开源版 数据库恢复MySQL数据库的过程 背景 想进行一下Oceanbase数据库的兼容性验证. 想着用app create 数据库的方式周期比较长. 所以我想着换一套 备份恢复 ...

  5. [转帖]SQL SERVER DBCC命令详解

    https://developer.aliyun.com/article/867768   简介: SQL数据库开发 DBCC DROPCLEANBUFFERS:从缓冲池中删除所有缓存,清除缓冲区 在 ...

  6. [转帖]elasticsearch-create-enrollment-tokenedit

    https://www.elastic.co/guide/en/elasticsearch/reference/current/create-enrollment-token.html The ela ...

  7. [转帖]通过配置优化KingbaseES服务器性能

    目录 1. 概述 2. 数据库应用类型 3. 服务器参数 3.1. max_connections 3.2. shared_buffers 3.3. effective_cache_size 3.4. ...

  8. [转帖]mysql百万级性能瓶颈-数据库选型

    项目中使用了mysql数据库,但数据量增长太快,不久到了百万级,很快又到表到了千万级,尝试了各种优化方式,最终效果仍难达到秒级响应,那么引发了我关于数据库选型到一些思考. 1.mysql的单表性能瓶颈 ...

  9. 玩一玩 golang 汇编(二)

    作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢! cnblogs博客 zhihu Github 公众号:一本正经的瞎扯 上次玩 golang 汇编是使用了一个 python 的 ...

  10. python.exe和pythonw.exe的区别(区分.py、.pyw、.pyc、.pyo文件)

    python和pythonw 在Windows系统搭建好Python的环境后,进入Python的安装目录,大家会发现目录中有python.exe和pythonw.exe两个程序.如下图所示: 它们到底 ...