第一个CNN代码,暂时对于CNN的BP还不熟悉。但是通过这个代码对于tensorflow的运行机制有了初步的理解

 '''
softmax classifier for mnist created on 2019.9.28
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score def weight_bais_variable(shape):
init = tf.random.truncated_normal(shape = shape, stddev = 0.01);
return tf.Variable(init); def bais_variable(shape):
init = tf.constant(0.1, shape=shape);
return tf.Variable(init); def conv2d(x, w):
return tf.nn.conv2d(x, w, [1, 1, 1, 1], padding = "SAME"); def max_pool_2x2(x):
return tf.nn.max_pool2d(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME"); def cnn(x, rate):
with tf.name_scope('reshape'):
x_image = tf.reshape(x, [-1, 28, 28, 1]); #first layer, conv & pool
with tf.name_scope('conv1'):
w_conv1 = weight_bais_variable([5, 5, 1, 32]);
b_conv1 = bais_variable([32]);
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1); #28 * 28 * 32
with tf.name_scope('pool1'):
h_pool1 = max_pool_2x2(h_conv1); #14 * 14 * 32 #second layer, conv & pool
with tf.name_scope('conv2'):
w_conv2 = weight_bais_variable([5, 5, 32, 64]);
b_conv2 = bais_variable([64]);
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2); #14 * 14 * 64
with tf.name_scope('pool2'):
h_pool2 = max_pool_2x2(h_conv2); #7 * 7 * 64 #first full connect layer, feature graph -> feature vector
with tf.name_scope('fc1'):
w_fc1 = weight_bais_variable([7 * 7 * 64, 1024]);
b_fc1 = bais_variable([1024]);
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]);
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1);
with tf.name_scope("dropout1"):
h_fc1_drop = tf.nn.dropout(h_fc1, rate); #second full connect layer,
with tf.name_scope('fc2'):
w_fc2 = weight_bais_variable([1024, 10]);
b_fc2 = bais_variable([10]);
#h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2;
h_fc2 = tf.matmul(h_fc1, w_fc2) + b_fc2;
return h_fc2; def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S'); mnist = read_data_sets('../data/MNIST',one_hot=True) # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签 x = tf.placeholder(tf.float32, [None, 784]);
y_real = tf.placeholder(tf.float32, [None, 10]);
rate = tf.placeholder(tf.float32); y_pre = cnn(x, rate); sess = tf.InteractiveSession();
sess.run(tf.global_variables_initializer()); loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = y_pre, labels = y_real));
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss); correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1));
prediction_op= tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
for _ in range(300):
batch_xs, batch_ys = mnist.train.next_batch(128);
sess.run(train_op, feed_dict = {x : batch_xs, y_real : batch_ys, rate: 0.5});
if _ % 10 == 0:
accuracy = sess.run(prediction_op, feed_dict = {x : mnist.test.images, y_real : mnist.test.labels, rate: 0.0 });
logging.info("%s : %s" % (_, accuracy)); if __name__ == "__main__":
main();

使用tensorflow实现cnn进行mnist识别的更多相关文章

  1. 使用tensorflow的softmax进行mnist识别

    tensorflow真是方便,看来深度学习需要怎么使用框架.如何建模- ''' softmax classifier for mnist created on 2019.9.28 author: vi ...

  2. Tensorflow搭建CNN实现验证码识别

    完整代码:GitHub 我的简书:Awesome_Tang的简书 整个项目代码分为三部分: Generrate_Captcha: 生成验证码图片(训练集,验证集和测试集): 读取图片数据和标签(标签即 ...

  3. 机器学习: Tensor Flow with CNN 做表情识别

    我们利用 TensorFlow 构造 CNN 做表情识别,我们用的是FER-2013 这个数据库, 这个数据库一共有 35887 张人脸图像,这里只是做一个简单到仿真实验,为了计算方便,我们用其中到 ...

  4. Tensorflow实践:CNN实现MNIST手写识别模型

    前言 本文假设大家对CNN.softmax原理已经比较熟悉,着重点在于使用Tensorflow对CNN的简单实践上.所以不会对算法进行详细介绍,主要针对代码中所使用的一些函数定义与用法进行解释,并给出 ...

  5. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  6. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  7. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  8. TensorFlow 入门之手写识别CNN 三

    TensorFlow 入门之手写识别CNN 三 MNIST 卷积神经网络 Fly 多层卷积网络 多层卷积网络的基本理论 构建一个多层卷积网络 权值初始化 卷积和池化 第一层卷积 第二层卷积 密集层连接 ...

  9. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

随机推荐

  1. 最通俗易懂的 Java 11 新特性讲解

    大多数开发者还是沉浸在 Java 8 中,而 Java 14 将要在 2020 年 3 月 17 日发布了,而我还在写着 Java 11 的新特性.Java 11 是 Java 8 之后的第一个 LT ...

  2. 一文了解各大图数据库查询语言(Gremlin vs Cypher vs nGQL)| 操作入门篇

    文章的开头我们先来看下什么是图数据库,根据维基百科的定义:图数据库是使用图结构进行语义查询的数据库,它使用节点.边和属性来表示和存储数据. 虽然和关系型数据库存储的结构不同(关系型数据库为表结构,图数 ...

  3. Java大浮点数精度

    BigDecimal 精度问题 BigDecimal舍入模式 ROUND_DOWN 向零舍入. 即1.55 变为 1.5 , -1.55 变为-1.5 ROUND_UP 向远离0的方向舍入 即 1.5 ...

  4. Could not find a valid gem 'redis' (= 0)

    Could not find a valid gem 'redis' (= 0) 报错详情如下: ERROR: Could not find a valid gem 'redis' (>= 0) ...

  5. 026.掌握Service-外部访问

    一 集群外部访问 由于Pod和Service都是Kubernetes集群范围内的虚拟概念,所以集群外的客户端默认情况,无法通过Pod的IP地址或者Service的虚拟IP地址:虚拟端口号进行访问.通常 ...

  6. Oracle学习笔记--Oracle启动过程归纳整理

    Oracle 启动过程分为nomount状态mount状态open状态 每个状态下Oracle都会进行不同的操作:1.nomount状态 在$ORACLE_HOME/dbs目录下寻找参数文件 参数文件 ...

  7. 记录一个引用文件所有js文件的方法

    在项目api声明的时候,避免每次添加新的js都要对应去处理 首先我在项目api文件下新建一个files的文件夹,然后再api文件夹下的index.js这样写: var api = {}; const  ...

  8. java批量处理

    最近用到Java批量处理,一次性处理多个文件夹下的多个文件,在此记录一下. 我的思路:首先要保证文件夹和文件夹下的文件的命名是有规律的,利用for循环,每次自增变量,再拼接字符串,从而得到各个文件的路 ...

  9. Spring Boot 结合 Redis 序列化配置的一些问题

    前言 最近在学习Spring Boot结合Redis时看了一些网上的教程,发现这些教程要么比较老,要么不知道从哪抄得,运行起来有问题.这里分享一下我最新学到的写法 默认情况下,Spring 为我们提供 ...

  10. 聊一聊React中虚拟DOM

    1. 什么是虚拟 DOM 在 React 中实际上是 render 函数中return 的内容会生成 DOM,return 中的内容由两部分组成,一部分是 JSX ,另一部分就是 state 中的数据 ...