1 为什么使用卷积神经网络

Softmax回归是一个比较简单的模型,预测的准确率在91%左右,而使用卷积神经网络将预测的准确率提高到99%。

2 卷积网络的流程

3 代码展示

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #读入数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#x为训练图像的占位符,y_为训练图像标签的占位符
x = tf.placeholder(tf.float32,[None,784])
y_ = tf.placeholder(tf.float32,[None,10]) #将单张图片从784维向量重新还原为28*28的矩阵图片
x_image = tf.reshape(x,[-1,28,28,1]) #-1 表示任意的数,由实际输入的图像个数决定 # 定义卷积过程中用到的函数
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1) #产生正太分布
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial) def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding="SAME") def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") # 第一层卷积
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1 = max_pool_2x2(h_conv1) # 第二层卷积
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2 = max_pool_2x2(h_conv2) # 第一层全连接层,输出1024维的向量
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
#使用Dropout ,keep_prob 是一个占位符,训练是0.5,测试时为1
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob) # 第二层全连接层,输出1024维的向量
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop,w_fc2)+b_fc2 # 不采用先softmax再计算交叉熵的办法
#采用tf.nn.softmax_cross_entropy_with_logits直接计算
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y_conv))
#定义train_step
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #定义准确率
correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # 训练
# 创建Session,对变量初始化
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer()) #训练2000步
for i in range(2000):
batch = mnist.train.next_batch(50)
# 每一百步报告一次在验证集上的准确率
if i % 100 == 0 :
train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1})
print("step %d,training accuracy %g" % (i,train_accuracy))
train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5}) # 训练结束后报告在测试集上的准确率
print("test_accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

4 补充

步长stride是一个一维的向量,长度为4。形式是[a,x,y,z],分别代表[batch滑动步长,水平滑动步长,垂直滑动步长,通道滑动步长]。在tensorflow中,stride的一般形式是[1,x,y,1]

  • 第一个1表示:在batch维度上的滑动步长为1,即不跳过任何一个样本
  • x表示:卷积核的水平滑动步长
  • y表示:卷积核的垂直滑动步长
  • 最后一个1表示:在通道维度上的滑动步长为1,即不跳过任何一个颜色通道

利用TensorFlow识别手写的数字---基于两层卷积网络的更多相关文章

  1. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  2. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  3. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  4. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

  5. 利用Tensorflow实现手写字符识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. TensorFlow下利用MNIST训练模型并识别自己手写的数字

    最近一直在学习李宏毅老师的机器学习视频教程,学到和神经网络那一块知识的时候,我觉得单纯的学习理论知识过于枯燥,就想着自己动手实现一些简单的Demo,毕竟实践是检验真理的唯一标准!!!但是网上很多的与t ...

  7. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  8. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. day24 面向对象设计part1

    #!/usr/bin/env python # -*- coding:utf-8 -*- # ----------------------------------------------------- ...

  2. Django2.0+小程序技术打造微信小程序助手✍✍✍

    Django2.0+小程序技术打造微信小程序助手  整个课程都看完了,这个课程的分享可以往下看,下面有链接,之前做java开发也做了一些年头,也分享下自己看这个视频的感受,单论单个知识点课程本身没问题 ...

  3. VMware Workstation 10 简体中文安装教程

    分享到 一键分享 QQ空间 新浪微博 百度云收藏 人人网 腾讯微博 百度相册 开心网 腾讯朋友 百度贴吧 豆瓣网 搜狐微博 百度新首页 QQ好友 和讯微博 更多... 百度分享 分享到 一键分享 QQ ...

  4. springcloud系列14 bus的使用

    首先springcloud_bus原理: (1)完整流程:发送端(endpoint)构造事件event,将其publish到context上下文中(spring cloud bus有一个父上下文,bo ...

  5. 通过ID获取元素 注:获取的元素是一个对象,如想对元素进行操作,我们要通过它的属性或方法。

    通过ID获取元素 学过HTML/CSS样式,都知道,网页由标签将信息组织起来,而标签的id属性值是唯一的,就像是每人有一个身份证号一样,只要通过身份证号就可以找到相对应的人.那么在网页中,我们通过id ...

  6. RPM包安装MySQL 5.7.18

    系统: CentOS 7 RPM包: mysql-community-client-5.7.18-1.el7.x86_64.rpm mysql-community-server-5.7.18-1.el ...

  7. springboot中activeMQ消息队列的引入与使用(发送短信)

    1.引入pom依赖 <!--activemq--><dependency> <groupId>org.springframework.boot</groupI ...

  8. HTML - head标签相关

    <html> <!-- head标签中主要配置浏览器的配置信息 --> <head> <!-- 网页标题标签, 用来指定网页的标题 --> <ti ...

  9. 廖雪峰Java13网络编程-2Email编程-1发送email

    1.邮件发送 1.1传统邮件发送: 传统的邮件是通过邮局投递,从一个邮局到另一个邮局,最终到达用户的邮箱. 1.2电子邮件发送: 与传统邮件类似,它是从用户电脑的邮件软件(如outlook)发送到邮件 ...

  10. 移动端自定义输入框的vue组件 ----input

    <style scoped lang="less"> .keyboard { font-family: -apple-system, BlinkMacSystemFon ...