【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容:
1.基于CNN的mnist手写数字识别(详细代码注释)
2.该实现中的函数总结
平台:
1.windows 10 64位
2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个Python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)
3.TensorFlow1.1.0
CNN的介绍可以看:
https://en.wikipedia.org/wiki/Convolutional_neural_network
http://cs231n.github.io/convolutional-networks/
这里用的CNN结构是: 输入层-C1-P1-C2-P2-FC1-Dropout-FC2-softmax(输出层)
代码:
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 16:36:43 2017
@author: ASUS
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot = True) # mnist是一个tensorflow内部的变量
sess = tf.InteractiveSession() # 创建 一个会话
# 权值初始化函数,用截断的正态分布,两倍标准差之外的被截断
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev = 0.1)
return tf.Variable(initial)
# 偏置初始化函数,偏置初始为0.1
def bias_variable(shape):
initial = tf.constant(0.1, shape = shape)
return tf.Variable(initial)
# 定义卷积方式,步长是1111,padding的SAME是使得特征图与输入图大小一致
def conv2d(x,W):
return tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding ='SAME')
# 定义池化方式,采用最大池化
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1],padding='SAME')
# 定义占位符
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# 1D向量(1,784)转2D(28,28)
x_image = tf.reshape(x, [-1,28,28,1]) # -1 表示样本数量不固定
#---------------第1/4步:定义算法公式-------------------
# 定义 卷积层 conv1
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
# 定义 卷积层 conv2
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
#定义 全连接层 fc1
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) # 将tensor拉成向量
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# 定义Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# 定义 Softmax层
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
#---------------第2/4步:定义loss和优化器-------------------
# 定义loss 和 参数优化器
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices = [1])) # -sigma y_ * log(y)
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# 准确率验证
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#---------------第3/4步:训练步骤-------------------
# 训练
tf.global_variables_initializer().run()
for i in range(2000):
batch = mnist.train.next_batch(100)
if i%100 ==0:
train_accuracy = accuracy.eval(feed_dict= {x: batch[0], y_: batch[1], keep_prob:1.0})
print('step %d, training accuracy %g' %(i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5})
#---------------第4/4步:测试集上评估模型-------------------
# 在验证阶段可能出先一个问题就是GPU内存不够的问题,这里是整个test输入,进行计算
# GPU内存不够大的话,就会出错(我的GTX 960m,2G)顶不住啊! 所以要分batch的进行
# 这里是输入整个test集的
# print('test accuracy %g ' % accuracy.eval(feed_dict = {
# x: mnist.test.images,
# y_:mnist.test.labels, keep_prob:1.0}))
# 这里是分batch验证的
accuracy_sum = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
good = 0
total = 0
for i in range(2):
testSet = mnist.test.next_batch(100)
if i ==1 : print(testSet[0].shape[0])
good += accuracy_sum.eval(feed_dict = { x: testSet[0], y_: testSet[1], keep_prob: 1.0})
total += testSet[0].shape[0] # testSet[0].shape[0] 是本batch有的样本数量
print("test accuracy %g"%(good/total))
这里面出了个小问题就是在测试阶段,书上直接把整个test集放进去了,而我的GPU内存不够大,导致出错。所以这里采用了分batch的方法进行测试,大家可以试一下整个test集放进去测试会出现什么情况。
**
函数总结(续上篇)
**:
1. sess = tf.InteractiveSession() 将sess注册为默认的session
2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)
3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
4. tf.matmul() 矩阵相乘函数
5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和
6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别
7. tf.cast() 用于数据类型转换
————————————–我是分割线(一)———————————–
tf.random_uniform 生成均匀分布的随机数
tf.train.AdamOptimizer() 创建优化器,优化方法为Adam(adaptive moment estimation,Adam优化方法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率)
tf.placeholder “占位符”,只要是对网络的输入,都需要用这个函数这个进行“初始化”
tf.random_normal 生成正态分布
tf.add 和 tf.matmul 数据的相加 、相乘
tf.reduce_sum 缩减维度的求和
tf.pow 求幂函数
tf.subtract 数据的相减
tf.global_variables_initializer 定义全局参数初始化
tf.Session 创建会话.
tf.Variable 创建变量,是用来存储模型参数的变量。是有别于模型的输入数据的
tf.train.AdamOptimizer (learning_rate = 0.001) 采用Adam进行优化,学习率为 0.001
————————————–我是分割线(二)———————————–
1. hidden1_drop = tf.nn.dropout(hidden1, keep_prob) 给 hindden1层增加Droput,返回新的层hidden1_drop,keep_prob是 Droput的比例
2. mnist.train.next_batch() 来详细讲讲 这个函数。一句话概括就是,打乱样本顺序,然后按顺序读取batch_size 个样本 进行返回。
具体看代码及其注释,首先要找到函数定义,在tensorflow\contrib\learn\python\learn\datasets 下的mnist.py
————————————–我是分割线(三)———————————–
1. tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding =’SAME’)对于这个函数主要理解 strides和padding,首先明确,x是输入,W是卷积核,并且它们的维数都是4(发现strides里有4个元素没,没错!就是一一对应的)
先说一下卷积核W也是一个四维张量,各维度表示的信息是:[filter_height, filter_width, in_channels, out_channels]
输入x,x是一个四维张量 ,各维度表示的信息是:[batch, in_height, in_width, in_channels]
strides里的每个元素就是对应输入x的四个维度的步长,因为第2,3维是图像的长和宽,所以平时用的strides就在这里设置,而第1,4维一般不用到,所以是1
padding只有两种取值方式,一个是 padding=[‘VALID’] 一个是padding=[‘SAME’]
valid:采用丢弃的方式,只要移动一步时,最右边有超出,则这一步不移动,并且剩余的进行丢弃。如下图,图片长13,卷积核长6,步长是5,当移动一步之后,已经卷积核6-11,再移动一步,已经没有足够的像素点了,所以就不能移动,因此 12,13被丢弃。
same:顾名思义,就是保持输入的大小不变,方法是在图像边缘处填充全0的像素
【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)的更多相关文章
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...
- 基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
随机推荐
- [SDOI2011] 消防 (树的直径,尺取法)
题目链接 Solution 同 \(NOIP2007\) 树网的核 . 令 \(dist_u\) 为以 \(u\) 为根节点的子树中与 \(u\) 的最大距离. \(~~~~dis_u\) 为 \(u ...
- 表单编码 appliation/x-www-form-urlencoded 与 multipart/form-data 的区别
当表单使用POST方法时,表单数据提交到服务器端之前有两种编码类型可供选择.默认编码类型为 application/x-www-form-urlencoded,此时所有非字母数字类型的字符都需要转换为 ...
- nginx 变量 + lua
nginx变量使用方法详解(8) nil.null与ngx.null 发现一个nginx LUA开发Web App的框架 nginx是个好东西, nginx的openrtsy发行版本更是个好东西. 今 ...
- javaweb学习总结(九)—— 通过Servlet生成验证码图片(转)
(每天都会更新至少一篇以上,有兴趣的可以关注)转载自孤傲苍狼 一.BufferedImage类介绍 生成验证码图片主要用到了一个BufferedImage类,如下:
- C++拷贝(复制)构造函数详解
原文:http://blog.csdn.net/lwbeyond/article/details/6202256/[侵删] 一. 什么是拷贝构造函数 首先对于普通类型的对象来说,它们之间的复制是很简单 ...
- Lua中闭包详解 来自RingOfTheC[ring.of.the.c@gmail.com]
这些东西是平时遇到的, 觉得有一定的价值, 所以记录下来, 以后遇到类似的问题可以查阅, 同时分享出来也能方便需要的人, 转载请注明来自RingOfTheC[ring.of.the.c@gmail.c ...
- LeetCode OJ--Insert Interval **
https://oj.leetcode.com/problems/insert-interval/ 给出有序的区间来,再插入进去一个,也就是区间合并. 刚开始确立了几个思路,看要插入的区间的start ...
- 美图秀秀web开发文档
Xiuxiu 组件 import React, { Component } from 'react'; class XiuXiu extends Component { componentDidMou ...
- PythonWeb开发教程(二),搭建第一个django项目
这篇写怎么创建django项目,以及把django项目运行起来. 1.创建django项目 a.使用命令创建,安装完django之后就有django-admin命令了,执行命令创建即可,命令如下: ...
- linux sed 替换(整行替换,部分替换)、删除delete、新增add、选取
sed命令行格式为: sed [-nefri] ‘command’ 输入文本 常用选项: -n∶使用安静(silent)模式.在一般 sed 的用法中,所有来自 STDI ...