import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True) def compute_accuracy(v_xs,v_ys):
global prediction
y_pre=sess.run(prediction,feed_dict={xs:v_xs,keep_prob:})
correct_prediction=tf.equal(tf.argmax(y_pre,),tf.argmax(v_ys,))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
result=sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys})
return result def weight_varirable(shape):
inital=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(inital) def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) def conv2d(x,W):
return tf.nn.conv2d(x,W,strides=[,,,],padding='SAME') def max_poo_(x):
return tf.nn.max_pool(x, ksize=[,,,], strides=[,,,], padding='SAME') xs=tf.placeholder(tf.float32,[None,])
ys=tf.placeholder(tf.float32,[None,])
keep_prob=tf.placeholder(tf.float32) x_image=tf.reshape(xs,[-,,,]) W_conv1=weight_varirable([,,,])
b_conv1=bias_variable([])
h_conv1=tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
h_pool1=max_poo_(h_conv1) W_conv2=weight_varirable([,,,])
b_conv2=bias_variable([])
h_conv2=tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2=max_poo_(h_conv2) W_fc1=weight_varirable([**,])
b_fc1=bias_variable([]) h_pool2_flat=tf.reshape(h_pool2,[-,**])
h_fcl=tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
h_fc1_drop=tf.nn.dropout(h_fcl,keep_prob) W_fc2=weight_varirable([,])
b_fc2=bias_variable([]) prediction=tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2) cross_entropy=tf.reduce_mean(
-tf.reduce_sum(ys*tf.log(prediction),
reduction_indices=[])) train_step=tf.train.AdamOptimizer(1e-).minimize(cross_entropy) sess=tf.Session() sess.run(tf.global_variables_initializer()) for i in range():
batch_xs, batch_ys = mnist.train.next_batch()
sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys,keep_prob:0.5})
if i % == :
print(compute_accuracy(
mnist.test.images, mnist.test.labels))

如果有同学没有MINST数据,请到http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_download.html下载,或者QQ问我

Tensorflow通过CNN实现MINST数据分类的更多相关文章

  1. FaceRank-人脸打分基于 TensorFlow 的 CNN 模型

    FaceRank-人脸打分基于 TensorFlow 的 CNN 模型 隐私 因为隐私问题,训练图片集并不提供,稍微可能会放一些卡通图片. 数据集 130张 128*128 张网络图片,图片名: 1- ...

  2. Tensorflow简单CNN实现

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 少说废话多写代码~ """转换图像数据格式时需要将它们的颜色空间变为灰度空间,将图像尺寸修改为同一尺寸,并将标签依 ...

  3. Tensorflow的CNN教程解析

    之前的博客我们已经对RNN模型有了个粗略的了解.作为一个时序性模型,RNN的强大不需要我在这里重复了.今天,让我们来看看除了RNN外另一个特殊的,同时也是广为人知的强大的神经网络模型,即CNN模型.今 ...

  4. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  5. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  6. [Tensorflow] Cookbook - CNN

    Convolutional Neural Networks (CNNs) are responsible for the major breakthroughs in image recognitio ...

  7. 6 TensorFlow实现cnn识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  8. TensorFlow实现CNN

    TensorFlow是目前深度学习最流行的框架,很有学习的必要,下面我们就来实际动手,使用TensorFlow搭建一个简单的CNN,来对经典的mnist数据集进行数字识别. 如果对CNN还不是很熟悉的 ...

  9. tensorflow构建CNN模型时的常用接口函数

    (1)tf.nn.max_pool()函数 解释: tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=No ...

随机推荐

  1. 微信小程序 tabBar模板

    tabBar导航栏 小程序tabBar,我们可以通过app.json进行配置,可以放置于顶部或者底部,用于不同功能页面的切换,挺好的... 但,,,貌似不能让动态修改tabBar(需求:通过switc ...

  2. angularJS ng-repeat="item in XXX track by $index"问题记录

    参考:https://blog.csdn.net/lunhui1994_/article/details/80236315 问题:项目中对数据做了分页效果,理想是:当页数大于6时,隐藏>6的页数 ...

  3. Geoserver的跨域问题

    使用tomcat访问Geoserver服务的时候,只调服务没问题,但是查询要素属性的时候出现如下“XMLHttpRequest”.“not allowed by Access-Control-Allo ...

  4. springboot整合jpa和mybatis实现主从复制

    百度多方参考终于配出我自己的了,以下仅供参考 参考https://www.cnblogs.com/cjsblog/p/9712457.html 代码 首先数据源配置 spring.datasource ...

  5. 关于Windows10企业版的激活方法

    今天打开Excel在使用的时候,突然弹出弹窗,说我的激活即将过期什么的,让我转到设置进行激活. 第一个想到的办法就是更换产品密钥,在网上找了不少产品密钥,密钥有效,但是需要连接企业激活什么的,因为我是 ...

  6. QSerialPort类

    一.简介     QSerialPort类是Qt5封装的串口类,可以与串口进行通信.QSerialPortInfo是一个辅助类,提供串口的一些信息,如可用的串口名称,描述,制造商,序列号,串口16位产 ...

  7. MSSQLSERVER跨服务器连接(远程登录)的示例代码

    MSSQLSERVER跨服务器链接服务器创建方法如下 复制代码 代码如下: --声明变量 Declare @svrname varchar(255), @dbname varchar(255), @s ...

  8. csps模拟67神炎皇,降雷皇,幻魔皇题解

    题面:https://www.cnblogs.com/Juve/articles/11648975.html 神炎皇: 打表找规律?和$\phi$有关? 答案就是$\sum\limits_{i=2}^ ...

  9. C#一般处理程序设置和读取session(session报错“未将对象引用设置到对象的实例”解决)

    登陆模块时,用到了session和cookie.在一般处理程序中处理session,一直报错.最后找到问题原因是需要调用 irequiressessionstate接口. 在ashx文件中,设置ses ...

  10. C++ 函数模板&类模板详解

    在 C++ 中,模板分为函数模板和类模板两种.函数模板是用于生成函数的,类模板则是用于生成类的. 函数模板&模板函数     类模板&模板类  必须区分概念 函数模板是模板,模板函数时 ...