动手实现CNN卷积神经网络
数据集采用的是手写数据集(http://yann.lecun.com/exdb/mnist/):
本文构建的CNN网络图如下:


像素点:28*28 = 784,55000张手写数字图片。


# -*- coding: UTF-8 -*- import numpy as np
import tensorflow as tf # 下载并载入 MNIST 手写数字库(55000 * 28 * 28)55000 张训练图像
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('mnist_data', one_hot=True)#将数据保存在mnist_data下 # one_hot 独热码的编码(encoding)形式
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 的十位数字
# 0 : 1000000000
# 1 : 0100000000
# 2 : 0010000000
# 3 : 0001000000
# 4 : 0000100000
# 5 : 0000010000
# 6 : 0000001000
# 7 : 0000000100
# 8 : 0000000010
# 9 : 0000000001 # None 表示张量(Tensor)的第一个维度可以是任何长度
# 除以 255 是为了做 归一化(Normalization),把灰度值从 [0, 255] 变成 [0, 1] 区间
# 归一话可以让之后的优化器(optimizer)更快更好地找到误差最小值
input_x = tf.placeholder(tf.float32, [None, 28 * 28]) / 255. # 输入 output_y = tf.placeholder(tf.int32, [None, 10]) # 输出:10个数字的标签 # -1 表示自动推导维度大小。让计算机根据其他维度的值
# 和总的元素大小来推导出 -1 的地方的维度应该是多少
input_x_images = tf.reshape(input_x, [-1, 28, 28, 1]) # 改变形状之后的输入 # 从 Test(测试)数据集里选取 3000 个手写数字的图片和对应标签
test_x = mnist.test.images[:3000] # 图片
test_y = mnist.test.labels[:3000] # 标签 # 构建我们的卷积神经网络:
# 第 1 层卷积
conv1 = tf.layers.conv2d( #conv2d指的是2维卷积
inputs=input_x_images, # 形状 [28, 28, 1]
filters=32, # 32 个过滤器,输出的深度(depth)是32
kernel_size=[5, 5], # 过滤器在二维的大小是 (5 * 5)
strides=1, # 步长是 1
padding='same', # same 表示输出的大小不变,因此需要在外围补零 2 圈
activation=tf.nn.relu # 激活函数是 Relu
) # 经过第一层卷积后输出的形状为 [28, 28, 32] # 第 1 层池化(亚采样)
pool1 = tf.layers.max_pooling2d(
inputs=conv1, # 形状 [28, 28, 32]
pool_size=[2, 2], # 过滤器在二维的大小是(2 * 2)
strides=2 # 步长是 2
) # 经过第 1 层池化后输出的形状 [14, 14, 32] # 第 2 层卷积
conv2 = tf.layers.conv2d(
inputs=pool1, # 形状 [14, 14, 32]
filters=64, # 64 个过滤器,输出的深度(depth)是64
kernel_size=[5, 5], # 过滤器在二维的大小是 (5 * 5)
strides=1, # 步长是 1
padding='same', # same 表示输出的大小不变,因此需要在外围补零 2 圈
activation=tf.nn.relu # 激活函数是 Relu
) # 经过第二层卷积后输出的形状为 [14, 14, 64] # 第 2 层池化(亚采样)
pool2 = tf.layers.max_pooling2d(
inputs=conv2, # 形状 [14, 14, 64]
pool_size=[2, 2], # 过滤器在二维的大小是(2 * 2)
strides=2 # 步长是 2
) # 形状 [7, 7, 64] # 平坦化(flat)。降维
flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) # 形状 [7 * 7 * 64, ] # 1024 个神经元的全连接层
dense = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.relu) # Dropout : 丢弃 50%(rate=0.5)
dropout = tf.layers.dropout(inputs=dense, rate=0.5) # 10 个神经元的全连接层,这里不用激活函数来做非线性化了
logits = tf.layers.dense(inputs=dropout, units=10) # 输出。形状 [1, 1, 10] # 计算误差(先用 Softmax 计算百分比概率,
# 再用 Cross entropy(交叉熵)来计算百分比概率和对应的独热码之间的误差)
loss = tf.losses.softmax_cross_entropy(onehot_labels=output_y, logits=logits)
#onehot_labels指的是实际的标签值,logits指的是卷积神经网络的预测输出 # Adam 优化器来最小化误差,学习率 0.001
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) # 精度。计算 预测值 和 实际标签 的匹配程度
# 返回 (accuracy, update_op), 会创建两个局部变量
accuracy = tf.metrics.accuracy(
labels=tf.argmax(output_y, axis=1),#第一个参数labels为真实标签 注:tf.argmax返回的是最大值的下标
predictions=tf.argmax(logits, axis=1),)[1]#第二个参数predictions为预测标签 # 创建会话
sess = tf.Session()
# 初始化变量:全局和局部
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init) # 训练 5000 步。这个步数可以调节
for i in range(5000):
batch = mnist.train.next_batch(50) # 从 Train(训练)数据集里取 “下一个” 50 个样本
train_loss, train_op_ = sess.run([loss, train_op], {input_x: batch[0], output_y: batch[1]})
if i % 100 == 0:
test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})
print("第 {} 步的 训练损失={:.4f}, 测试精度={:.2f}".format(i, train_loss, test_accuracy)) # 测试:打印 20 个预测值 和 真实值
test_output = sess.run(logits, {input_x: test_x[:20]})
inferred_y = np.argmax(test_output, 1)
print(inferred_y, '推测的数字') # 推测的数字
print(np.argmax(test_y[:20], 1), '真实的数字') # 真实的数字 # 关闭会话
sess.close()
结果:
动手实现CNN卷积神经网络的更多相关文章
- CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的内部网络结构有什么区别?
https://www.zhihu.com/question/34681168 CNN(卷积神经网络).RNN(循环神经网络).DNN(深度神经网络)的内部网络结构有什么区别?修改 CNN(卷积神经网 ...
- Deep Learning模型之:CNN卷积神经网络(一)深度解析CNN
http://m.blog.csdn.net/blog/wu010555688/24487301 本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep le ...
- [转]Theano下用CNN(卷积神经网络)做车牌中文字符OCR
Theano下用CNN(卷积神经网络)做车牌中文字符OCR 原文地址:http://m.blog.csdn.net/article/details?id=50989742 之前时间一直在看 Micha ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- CNN(卷积神经网络)、RNN(循环神经网络)、DNN,LSTM
http://cs231n.github.io/neural-networks-1 https://arxiv.org/pdf/1603.07285.pdf https://adeshpande3.g ...
- day-16 CNN卷积神经网络算法之Max pooling池化操作学习
利用CNN卷积神经网络进行训练时,进行完卷积运算,还需要接着进行Max pooling池化操作,目的是在尽量不丢失图像特征前期下,对图像进行downsampling. 首先看下max pooling的 ...
- cnn(卷积神经网络)比较系统的讲解
本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep learning简介 [2]Deep Learning训练过程 [3]Deep Learning模型之 ...
- Keras(四)CNN 卷积神经网络 RNN 循环神经网络 原理及实例
CNN 卷积神经网络 卷积 池化 https://www.cnblogs.com/peng8098/p/nlp_16.html 中有介绍 以数据集MNIST构建一个卷积神经网路 from keras. ...
- TensorFlow——CNN卷积神经网络处理Mnist数据集
CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...
随机推荐
- Dubbo -- 关于 api接口调用不使用强依赖
首先,我们都知道 Dubbo 调用api 需要提供暴露 接口, 消费端才通过 ZK 可以调用 通常我们都会使用 提供 api jar包 的方式 使用 这样既方便又快捷 简单 只需要在spr ...
- web手工项目03-登录功能测试用例及缺陷编写-流程图画法-前后台下单及发货流程图-流程图设计测试用例方法-功能测试涉及到的四种数据库场景
回顾 注册功能测试(步骤,需求分析(输入分析,处理分析,输出分析),数据构造(有效等价类,无效等价类,有效数据,无效数据),编写用例,执行用例,缺陷报告) 轮播图功能测试(步骤,需求分析拆分测试点,测 ...
- laravel构建联合查询
参考:http://laravelacademy.org/post/126.html DB门面可以指定不同的数据库连接(通过connection方法) /** * @param $login_uid ...
- 【Leetcode_easy】748. Shortest Completing Word
problem 748. Shortest Completing Word 题意: solution1: class Solution { public: string shortestComplet ...
- iOS面试考察点
)自我介绍.项目经历.专业知识.自由提问 (2)准备简历.投发简历.笔试(电话面试.).面试.复试.终面试.试用.转正.发展.跳槽(加薪升职) 1闲聊 a)自我介绍:自我认识能力 b)评价上一家公司: ...
- 《Django企业开发实战 高效Python Web框架指南》胡阳
链接:https://pan.baidu.com/s/1NmN_IT5RvevCMt9bZCW1-g提取码:2ki9
- 锚点/JQ:点击导航跳到网页中的指定位置
今天做了一个简单的功能,页面往下滚动到一定位置,顶部出现一个浮动的导航栏,点击导航栏标签,下面页面跳转到相应的区域.回到顶部,导航栏隐藏. 因为顶部有一个浮动的导航栏,所以跳转到下面页面的时候,总是盖 ...
- Egret入门学习日记 --- 第十三篇(书中 5.2~5.3节 内容)
第十三篇(书中 5.2~5.3节 内容) 写日记已经十天多了,我发现越到后面,我书写的方式越来越程序化. 感觉渐渐失去了人类所谓的感情似的. 不过,没想到的是,书中的内容,很少出现了错误,我一路过来到 ...
- VMware虚拟机中CentOS7的硬盘空间扩容
查看centos7系统挂载点信息 扩展VMWare-centos7硬盘空间 对新增加的硬盘进行分区.格式化 添加新LVM到已有的LVM组,实现扩容 1.查看centos7系统挂载点信息 df -h查看 ...
- PHP判断是不是爬虫的方法
PHP判断是不是爬虫的方法这个一般用于防止爬虫 和 seo优化(因为爬虫都是按照第一次打开显示的页面 有些ajax 等需要点击才能显示的就爬不到啦)<pre><?php// 判断是否 ...