利用TensorFlow识别手写的数字---基于两层卷积网络
1 为什么使用卷积神经网络
Softmax回归是一个比较简单的模型,预测的准确率在91%左右,而使用卷积神经网络将预测的准确率提高到99%。
2 卷积网络的流程

3 代码展示
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#读入数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#x为训练图像的占位符,y_为训练图像标签的占位符
x = tf.placeholder(tf.float32,[None,784])
y_ = tf.placeholder(tf.float32,[None,10])
#将单张图片从784维向量重新还原为28*28的矩阵图片
x_image = tf.reshape(x,[-1,28,28,1]) #-1 表示任意的数,由实际输入的图像个数决定
# 定义卷积过程中用到的函数
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1) #产生正太分布
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial)
def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding="SAME")
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
# 第一层卷积
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
# 第二层卷积
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
# 第一层全连接层,输出1024维的向量
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
#使用Dropout ,keep_prob 是一个占位符,训练是0.5,测试时为1
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)
# 第二层全连接层,输出1024维的向量
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop,w_fc2)+b_fc2
# 不采用先softmax再计算交叉熵的办法
#采用tf.nn.softmax_cross_entropy_with_logits直接计算
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y_conv))
#定义train_step
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#定义准确率
correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
# 训练
# 创建Session,对变量初始化
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
#训练2000步
for i in range(2000):
batch = mnist.train.next_batch(50)
# 每一百步报告一次在验证集上的准确率
if i % 100 == 0 :
train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1})
print("step %d,training accuracy %g" % (i,train_accuracy))
train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
# 训练结束后报告在测试集上的准确率
print("test_accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))
4 补充
步长stride是一个一维的向量,长度为4。形式是[a,x,y,z],分别代表[batch滑动步长,水平滑动步长,垂直滑动步长,通道滑动步长]。在tensorflow中,stride的一般形式是[1,x,y,1]
- 第一个1表示:在batch维度上的滑动步长为1,即不跳过任何一个样本
- x表示:卷积核的水平滑动步长
- y表示:卷积核的垂直滑动步长
- 最后一个1表示:在通道维度上的滑动步长为1,即不跳过任何一个颜色通道
利用TensorFlow识别手写的数字---基于两层卷积网络的更多相关文章
- 利用TensorFlow识别手写的数字---基于Softmax回归
1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- 【转】机器学习教程 十四-利用tensorflow做手写数字识别
模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...
- 07 训练Tensorflow识别手写数字
打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...
- 利用Tensorflow实现手写字符识别
模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...
- TensorFlow下利用MNIST训练模型并识别自己手写的数字
最近一直在学习李宏毅老师的机器学习视频教程,学到和神经网络那一块知识的时候,我觉得单纯的学习理论知识过于枯燥,就想着自己动手实现一些简单的Demo,毕竟实践是检验真理的唯一标准!!!但是网上很多的与t ...
- OpenCV+TensorFlow图片手写数字识别(附源码)
初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...
- 3 TensorFlow入门之识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
随机推荐
- day24 面向对象设计part1
#!/usr/bin/env python # -*- coding:utf-8 -*- # ----------------------------------------------------- ...
- Django2.0+小程序技术打造微信小程序助手✍✍✍
Django2.0+小程序技术打造微信小程序助手 整个课程都看完了,这个课程的分享可以往下看,下面有链接,之前做java开发也做了一些年头,也分享下自己看这个视频的感受,单论单个知识点课程本身没问题 ...
- VMware Workstation 10 简体中文安装教程
分享到 一键分享 QQ空间 新浪微博 百度云收藏 人人网 腾讯微博 百度相册 开心网 腾讯朋友 百度贴吧 豆瓣网 搜狐微博 百度新首页 QQ好友 和讯微博 更多... 百度分享 分享到 一键分享 QQ ...
- springcloud系列14 bus的使用
首先springcloud_bus原理: (1)完整流程:发送端(endpoint)构造事件event,将其publish到context上下文中(spring cloud bus有一个父上下文,bo ...
- 通过ID获取元素 注:获取的元素是一个对象,如想对元素进行操作,我们要通过它的属性或方法。
通过ID获取元素 学过HTML/CSS样式,都知道,网页由标签将信息组织起来,而标签的id属性值是唯一的,就像是每人有一个身份证号一样,只要通过身份证号就可以找到相对应的人.那么在网页中,我们通过id ...
- RPM包安装MySQL 5.7.18
系统: CentOS 7 RPM包: mysql-community-client-5.7.18-1.el7.x86_64.rpm mysql-community-server-5.7.18-1.el ...
- springboot中activeMQ消息队列的引入与使用(发送短信)
1.引入pom依赖 <!--activemq--><dependency> <groupId>org.springframework.boot</groupI ...
- HTML - head标签相关
<html> <!-- head标签中主要配置浏览器的配置信息 --> <head> <!-- 网页标题标签, 用来指定网页的标题 --> <ti ...
- 廖雪峰Java13网络编程-2Email编程-1发送email
1.邮件发送 1.1传统邮件发送: 传统的邮件是通过邮局投递,从一个邮局到另一个邮局,最终到达用户的邮箱. 1.2电子邮件发送: 与传统邮件类似,它是从用户电脑的邮件软件(如outlook)发送到邮件 ...
- 移动端自定义输入框的vue组件 ----input
<style scoped lang="less"> .keyboard { font-family: -apple-system, BlinkMacSystemFon ...