第一个CNN代码,暂时对于CNN的BP还不熟悉。但是通过这个代码对于tensorflow的运行机制有了初步的理解

 '''
softmax classifier for mnist created on 2019.9.28
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score def weight_bais_variable(shape):
init = tf.random.truncated_normal(shape = shape, stddev = 0.01);
return tf.Variable(init); def bais_variable(shape):
init = tf.constant(0.1, shape=shape);
return tf.Variable(init); def conv2d(x, w):
return tf.nn.conv2d(x, w, [1, 1, 1, 1], padding = "SAME"); def max_pool_2x2(x):
return tf.nn.max_pool2d(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME"); def cnn(x, rate):
with tf.name_scope('reshape'):
x_image = tf.reshape(x, [-1, 28, 28, 1]); #first layer, conv & pool
with tf.name_scope('conv1'):
w_conv1 = weight_bais_variable([5, 5, 1, 32]);
b_conv1 = bais_variable([32]);
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1); #28 * 28 * 32
with tf.name_scope('pool1'):
h_pool1 = max_pool_2x2(h_conv1); #14 * 14 * 32 #second layer, conv & pool
with tf.name_scope('conv2'):
w_conv2 = weight_bais_variable([5, 5, 32, 64]);
b_conv2 = bais_variable([64]);
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2); #14 * 14 * 64
with tf.name_scope('pool2'):
h_pool2 = max_pool_2x2(h_conv2); #7 * 7 * 64 #first full connect layer, feature graph -> feature vector
with tf.name_scope('fc1'):
w_fc1 = weight_bais_variable([7 * 7 * 64, 1024]);
b_fc1 = bais_variable([1024]);
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]);
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1);
with tf.name_scope("dropout1"):
h_fc1_drop = tf.nn.dropout(h_fc1, rate); #second full connect layer,
with tf.name_scope('fc2'):
w_fc2 = weight_bais_variable([1024, 10]);
b_fc2 = bais_variable([10]);
#h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2;
h_fc2 = tf.matmul(h_fc1, w_fc2) + b_fc2;
return h_fc2; def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S'); mnist = read_data_sets('../data/MNIST',one_hot=True) # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签 x = tf.placeholder(tf.float32, [None, 784]);
y_real = tf.placeholder(tf.float32, [None, 10]);
rate = tf.placeholder(tf.float32); y_pre = cnn(x, rate); sess = tf.InteractiveSession();
sess.run(tf.global_variables_initializer()); loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = y_pre, labels = y_real));
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss); correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1));
prediction_op= tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
for _ in range(300):
batch_xs, batch_ys = mnist.train.next_batch(128);
sess.run(train_op, feed_dict = {x : batch_xs, y_real : batch_ys, rate: 0.5});
if _ % 10 == 0:
accuracy = sess.run(prediction_op, feed_dict = {x : mnist.test.images, y_real : mnist.test.labels, rate: 0.0 });
logging.info("%s : %s" % (_, accuracy)); if __name__ == "__main__":
main();

使用tensorflow实现cnn进行mnist识别的更多相关文章

  1. 使用tensorflow的softmax进行mnist识别

    tensorflow真是方便,看来深度学习需要怎么使用框架.如何建模- ''' softmax classifier for mnist created on 2019.9.28 author: vi ...

  2. Tensorflow搭建CNN实现验证码识别

    完整代码:GitHub 我的简书:Awesome_Tang的简书 整个项目代码分为三部分: Generrate_Captcha: 生成验证码图片(训练集,验证集和测试集): 读取图片数据和标签(标签即 ...

  3. 机器学习: Tensor Flow with CNN 做表情识别

    我们利用 TensorFlow 构造 CNN 做表情识别,我们用的是FER-2013 这个数据库, 这个数据库一共有 35887 张人脸图像,这里只是做一个简单到仿真实验,为了计算方便,我们用其中到 ...

  4. Tensorflow实践:CNN实现MNIST手写识别模型

    前言 本文假设大家对CNN.softmax原理已经比较熟悉,着重点在于使用Tensorflow对CNN的简单实践上.所以不会对算法进行详细介绍,主要针对代码中所使用的一些函数定义与用法进行解释,并给出 ...

  5. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  6. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  7. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  8. TensorFlow 入门之手写识别CNN 三

    TensorFlow 入门之手写识别CNN 三 MNIST 卷积神经网络 Fly 多层卷积网络 多层卷积网络的基本理论 构建一个多层卷积网络 权值初始化 卷积和池化 第一层卷积 第二层卷积 密集层连接 ...

  9. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

随机推荐

  1. AE脚本:把SubRip/SRT/TXT/VTT字幕导入到AE

    脚本介绍 如果您需要在视频中嵌入字幕以进行网络或磁带传送,那么这个脚本则非常有用.可以将SubRip/SRT/TXT/VTT字幕格式文件通过 pt_ImportSubtitles脚本直接加载到AE软件 ...

  2. mysql 存储过程 执行存储过程修改了表中所有行的信息

    存储过程中的where条件语句,如果传入的参数和表字段名相同,存储过程就会把这个约束条件忽略.小结:存储过程中传递的参数名不要和字段名相同.特别是修改.删除等操作,可能会对整张表产生影响.后果会很严重 ...

  3. java网络编程——socket实现简单的CS会话

    还记得当年学计网课时用python实现一个简单的CS会话功能,这也是学习socket网络编程的第一步,现改用java重新实现,以此记录. 客户端 import java.io.*; import ja ...

  4. 零基础JavaScript编码(二)

    任务目的 在上一任务基础上继续JavaScript的体验 学习JavaScript中的if判断语法,for循环语法 学习JavaScript中的数组对象 学习如何读取.处理数据,并动态创建.修改DOM ...

  5. Spring Boot从入门到精通(七)集成Redis实现Session共享

    单点登录(SSO)是指在多个应用系统中,登录用户只需要登录验证一次就可以访问所有相互信任的应用系统,Redis Session共享是实现单点登录的一种方式.本文是通过Spring Boot框架集成Re ...

  6. PHP5.6.23+Apache2.4.20+Eclipse for PHP 4.5开发环境配置

    一.Apache配置(以httpd-2.4.20-x64-vc14.zip为例)(http://www.apachelounge.com/download/) 1.安装运行库vc11和vc14 2.解 ...

  7. 【Java必修课】判断String是否包含子串的四种方法及性能对比

    1 简介 判断一个字符串是否包含某个特定子串是常见的场景,比如判断一篇文章是否包含敏感词汇.判断日志是否有ERROR信息等.本文将介绍四种方法并进行性能测试. 2 四种方法 2.1 JDK原生方法St ...

  8. 必备技能五、router路由钩子

    在路由跳转的时候,我们需要一些权限判断或者其他操作.这个时候就需要使用路由的钩子函数. 定义:路由钩子主要是给使用者在路由发生变化时进行一些特殊的处理而定义的函数. 总体来讲vue里面提供了三大类钩子 ...

  9. ZXingObjC直接引用第三方工程使用方法

    1.下载ZXingObjc压缩包,解压缩. 2.将文件拷贝到项目工程目录下 3.到工程目录中ZXingObjc文件夹中将ZXing的执行文件拖拽到项目中. 4.点击项目targets ——>Bu ...

  10. TEA5676 + AT24C08 FM收音机 搜台 存台 mmap 实现读写

    硬件说明TEA5767 + AT24c08 要使用耳机收听,不加功放芯片,声音非常小. 这2个芯片都支持 3.3 或 5.0 电源支持连线比较简单,sda scl 接到 2440 对应的 排针上,找出 ...