import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("F:\TensorflowProject\MNIST_data",one_hot=True) #每个批次大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples //batch_size #初始化权值
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1) #初始化一个截断的正态分布
return tf.Variable(initial) #初始化偏值
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial) #卷积层
def conv2d(x,W):
#x input tensor of shape '[batch,in_height,in_width,in_channels]'
#W filter/kernel tensor of shape [filter_height,filter_width,in_channels,out_channels]
#strides[0] = strides[3] = 1, strides[1]代表x方向的步长,strides[2]代表y方向的步长
#padding:A string from :SAME 或者 VALID
return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') #池化层
def max_pool_2x2(x):
#ksize[1,x,y,1]
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784]) #28*28
y = tf.placeholder(tf.float32,[None,10]) #设置x的格式为4D向量 [batch,in_height,in_width,in_chanels]
x_image = tf.reshape(x,[-1,28,28,1]) #初始化第一个卷积层的权值和偏值
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) #把x_image和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1) #max-pooling,经过池化计算得到一个结果 #初始化第二个卷积层的权值和偏置值
W_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64]) #把h_pool1和权值向量进行卷积,再加上偏置值,然后应用于relu激活函数
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2) #max-pooling #28*28的图片第一次卷积后还是28*28,第一次池化后为14*14
#第二次卷积后是14*14,第二次池化后为7*7
#上面步骤完成以后得到64张7*7的平面 #初始化第一个全连接层的权值
W_fc1 = weight_variable([7*7*64,1024]) #上一步有 7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_variable([1024]) #1024个节点 #把池化层2的输出扁平化为1维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1) + b_fc1) #keep_prob标识神经元输出概率
keep_prob =tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob) #初始化第二个全连接层
W_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10]) #计算输出
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2) + b_fc2) #交叉熵代价函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用AdamOptimizer进行优化
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #用布尔列表存放结果
correct_prediction = tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(21):
    for batch in range(n_batch):
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7})     test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
    print("Iter "+str(epoch)+" ,Testing Accuracy = "+str(test_acc))

##############运行结果

Iter  0  ,Testing Accuracy =  0.9552
Iter 1 ,Testing Accuracy = 0.9743
Iter 2 ,Testing Accuracy = 0.9796
Iter 3 ,Testing Accuracy = 0.9807
Iter 4 ,Testing Accuracy = 0.9849
Iter 5 ,Testing Accuracy = 0.9863
Iter 6 ,Testing Accuracy = 0.9859
Iter 7 ,Testing Accuracy = 0.9885
Iter 8 ,Testing Accuracy = 0.9887
Iter 9 ,Testing Accuracy = 0.9894
Iter 10 ,Testing Accuracy = 0.9907
Iter 11 ,Testing Accuracy = 0.991
Iter 12 ,Testing Accuracy = 0.9903
Iter 13 ,Testing Accuracy = 0.992
Iter 14 ,Testing Accuracy = 0.9904
Iter 15 ,Testing Accuracy = 0.9915
Iter 16 ,Testing Accuracy = 0.9903
Iter 17 ,Testing Accuracy = 0.9912
Iter 18 ,Testing Accuracy = 0.9917
Iter 19 ,Testing Accuracy = 0.9912
Iter 20 ,Testing Accuracy = 0.992

卷积神经网络应用于tensorflow手写数字识别(第三版)的更多相关文章

  1. MINST手写数字识别(三)—— 使用antirectifier替换ReLU激活函数

    这是一个来自官网的示例:https://github.com/keras-team/keras/blob/master/examples/antirectifier.py 与之前的MINST手写数字识 ...

  2. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

  3. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  4. tensorflow应用于手写数字识别(第二版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data#载入数据集 mnist = inp ...

  5. tensorflow 手写数字识别

    https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...

  6. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  7. tensorflow手写数字识别(有注释)

    import tensorflow as tf import numpy as np # const = tf.constant(2.0, name='const') # b = tf.placeho ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

随机推荐

  1. MySQL存储过程例子

    -- 索引 INDEXCREATE INDEX idx_sname ON student( sname(4)); ALTER TABLE teacher add index idx_tname(tna ...

  2. SciPy 线性代数

    章节 SciPy 介绍 SciPy 安装 SciPy 基础功能 SciPy 特殊函数 SciPy k均值聚类 SciPy 常量 SciPy fftpack(傅里叶变换) SciPy 积分 SciPy ...

  3. aforge视频录像,对界面进行重绘

    由于项目需要,需要录像的时候在界面加多一个圆圈,并且一起录制下来. 只需要在NewFrame增加以下代码 private void videoSourcePlayer1_NewFrame(object ...

  4. pta 拯救007(Floyd)

    7-9 拯救007(25 分) 在老电影“007之生死关头”(Live and Let Die)中有一个情节,007被毒贩抓到一个鳄鱼池中心的小岛上,他用了一种极为大胆的方法逃脱 —— 直接踩着池子里 ...

  5. Windows 10长脸了!

    Windows 10一直毁誉参半,但是在微软的持续升级完善之下,同时随着时间的流逝,已经顺利成为全球第一大桌面操作系统,并开始逐渐甩开Windows 7,全球设备安装量已经达到约8亿部. 根据最新的S ...

  6. [题解] LuoguP2257 YY的GCD

    传送门 给\(n,m\),让你求 \[ \sum\limits_{i=1}^n \sum\limits_{j=1}^m [\gcd(i,j) \in prime] \] 有\(T\)组询问\((T \ ...

  7. Codeforces 556A:Case of the Zeros and Ones

    A. Case of the Zeros and Ones time limit per test 1 second memory limit per test 256 megabytes input ...

  8. XV6操作系统代码阅读心得(五):文件系统

    Unix文件系统 当今的Unix文件系统(Unix File System, UFS)起源于Berkeley Fast File System.和所有的文件系统一样,Unix文件系统是以块(Block ...

  9. 修改序列(Sequence)的初始值(START WITH)

    1 执行:Alter Sequence SeqTest2010_S Increment By 1007; 2 执行:Select SeqTest2010_S.NextVal From Dual; 3 ...

  10. SpringMVC原理及流程解析

    前言 春节期间宅在家里闲来无事,对SpringMVC进行了比较深入的了解,将之前模糊不清的地方基本摸索清楚了,特此撰文总结记录一下. 正文 一.一个请求为什么会调用到SpringMVC框架里? 首先问 ...