图片训练:使用卷积神经网络(CNN)识别手写数字
这篇文章中,我们将使用CNN构建一个Tensorflow.js模型来分辨手写的数字。首先,我们通过使之“查看”数以千计的数字图片以及他们对应的标识来训练分辨器。然后我们再通过此模型从未“见到”过的测试数据评估这个分辨器的精确度。
一、运行代码
这篇文章的全部代码可以在仓库TensorFlow.js examples中的tfjs-examples/mnist 下找到,你可以通过下面的方式clone下来然后运行这个demo:
$ git clone https://github.com/tensorflow/tfjs-examples
$ cd tfjs-examples/mnist
$ yarn
$ yarn watch
上面的这个目录完全是独立的,所以完全可以copy下来然后创建你个人的项目。
二、数据相关
这篇文章中,我们将会使用 MNIST 的手写数据,这些我们将要去分辨的手写数据如下所示:
为了预处理这些数据,我们已经写了 data.js, 这个文件包含了Minsdata类,而这个类可以帮助我们从MNIST的数据集中获取到任意的一些列的MNIST。
而MnistData这个类将全部的数据分割成了训练数据和测试数据。我们训练模型的时候,分辨器就会只观察训练数据。而当我们评价模型时,我们就仅仅使用测试数据,而这些测试数据是模型还没有看见到的,这样就可以来观察模型预测全新的数据了。
这个MnistData有两个共有方法:
- nextTrainBatch(batchSize): 从训练数据中返回一批任意的图片以及他们的标识。
- nextTestBatch(batchSize): 从测试数据中返回一批图片以及他们的标识。
注意:当我们训练MNIST分辨器时,应当注意数据获取的任意性是非常重要的,这样模型预测才不会受到我们提供图片顺序的干扰。例如,如果我们每次给这个模型第一次都提供的是数字1,那么在训练期间,这个模型就会简单的预测第一个就是1(因为这样可以减小损失函数)。 而如果我们每次训练时都提供的是2,那么它也会简单切换为预测2并且永远不会预测1(同样的,也是因为这样可以减少损失函数)。如果每次都提供这样典型的、有代表性的数字,那么这个模型将永远也学不会做出一个精确的预测。
三、创建模型
在这一部分,我们将会创建一个卷积图片识别模型。为了这样做,我们使用了Sequential模型(模型中最为简单的一个类型),在这个模型中,张量(tensors)可以连续的从一层传递到下一层中。
首先,我们需要使用tf.sequential先初始化一个sequential模型:
const model = tf.sequential();
既然我们已经创建了一个模型,那么我们就可以添加层了。
四、添加第一层
我们要添加的第一层是一个2维的卷积层。卷积将过滤窗口掠过图片来学习空间上来说不会转变的变量(即图片中不同位置的模式或者物体将会被平等对待)。
我们可以通过tf.layers.conv2d来创建一个2维的卷积层,这个卷积层可以接受一个配置对象来定义层的结构,如下所示:
model.add(tf.layers.conv2d({
inputShape: [, , ],
kernelSize: ,
filters: ,
strides: ,
activation: 'relu',
kernelInitializer: 'VarianceScaling'
}));
让我们拆分对象中的每个参数吧:
- inputShape。这个数据的形状将回流入模型的第一层。在这个示例中,我们的MNIST例子是28 x 28像素的黑白图片,这个关于图片的特定的格式即[row, column, depth],所以我们想要配置一个[28, 28, 1]的形状,其中28行和28列是这个数字在每个维度上的像素数,且其深度为1,这是因为我们的图片只有1个颜色:
- kernelSize。划过卷积层过滤窗口的数量将会被应用到输入数据中去。这里,我们设置了kernalSize的值为5,也就是指定了一个5 x 5的卷积窗口。
- filters。这个kernelSize的过滤窗口的数量将会被应用到输入数据中,我们这里将8个过滤器应用到数据中。
- strides。 即滑动窗口每一步的步长。比如每当过滤器移动过图片时将会由多少像素的变化。这里,我们指定其步长为1,这意味着每一步都是1像素的移动。
- activation。这个activation函数将会在卷积完成之后被应用到数据上。在这个例子中,我们应用了relu函数,这个函数在机器学习中是一个非常常见的激活函数。
- kernelInitializer。这个方法对于训练动态的模型是非常重要的,他被用于任意地初始化模型的weights。我们这里将不会深入细节来讲,但是 VarianceScaling (即这里用的)真的是一个初始化非常好的选择。
五、添加第二层
让我们为这个模型添加第二层:一个最大的池化层(pooling layer),这个层中我们将通过 tf.layers.maxPooling2d 来创建。这一层将会通过在每个滑动窗口中计算最大值来降频取样得到结果。
model.add(tf.layers.maxPooling2d({
poolSize: [, ],
strides: [, ]
}));
- poolSize。这个滑动池窗口的数量将会被应用到输入的数据中。这里我们设置poolSize为[2, 2],所以这就意味着池化层将会对输入数据应用2x2的窗口。
- strides。 这个池化层的步长大小。比如,当每次挪开输入数据时窗口需要移动多少像素。这里我们指定strides为[2, 2],这就意味着过滤器将会以在水平方向和竖直方向上同时移动2个像素的方式来划过图片。
注意:因为poolSize和strides都是2x2,所以池化层空口将会完全不会重叠。这也就意味着池化层将会把激活的大小从上一层减少一半。
六、添加剩下的层
图片训练:使用卷积神经网络(CNN)识别手写数字的更多相关文章
- 如何用卷积神经网络CNN识别手写数字集?
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
- python手写神经网络实现识别手写数字
写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- matlab手写神经网络实现识别手写数字
实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络. 实验数据:500 ...
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...
- 6 TensorFlow实现cnn识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- keras—神经网络CNN—MNIST手写数字识别
from keras.datasets import mnist from keras.utils import np_utils from plot_image_1 import plot_imag ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- C#中调用Matlab人工神经网络算法实现手写数字识别
手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化 投影 矩阵 目标定位 Matlab 手写数字图像识别简介: 手写 ...
随机推荐
- 深度优先搜索DFS和广度优先搜索BFS
DFS简介 深度优先搜索,一般会设置一个数组visited记录每个顶点的访问状态,初始状态图中所有顶点均未被访问,从某个未被访问过的顶点开始按照某个原则一直往深处访问,访问的过程中随时更新数组visi ...
- MFC 消息框
窗口类能够使用messagebox int ret = MessageBox(_T("内容"), _T("标题"), MB_OKCANCLE| //MB_OB ...
- JAVA经典算法40+
现在是3月份,也是每年开年企业公司招聘的高峰期,同时有许多的朋友也出来找工作.现在的招聘他们有时会给你出一套面试题或者智力测试题,也有的直接让你上机操作,写一段程序.算法的计算不乏出现,基于这个原因我 ...
- vue2.x和vue1.0
1.首先挂载方式上 在vue2.0中,如果使用body或者html作为挂载点,则会报以下警告: Do not mount Vue to <html> or <body> - m ...
- 在IIS7.5下配置PHP环境
1.下载安装ZkeysPHP,路径随意 找到该程序集 D:\ZkeysSoft\Php\php5isapi.dll 2.在站点配置“处理程序映射”,添加php后缀映射由D:\ZkeysSoft\Php ...
- asp.net文件上传下载
泽优大文件上传产品测试 泽优大文件上传控件up6,基于php开发环境测试. 开发环境:HBuilder 服务器:wamp64 数据库:mysql 可视化数据库编辑工具:Navicat Premium ...
- Mysql之数据库操作
数据库操作: 链接数据库: mysql -uroot -p masql -uroot -pmysql 退出数据库: exit/quit/ctrl + d sql语句最后需要分号结尾: 查看时间: ...
- 从点到面,给Button的属性动画
属性动画是API 11加进来的一个新特性,其实在现在来说也没什么新的了.属性动画可以对任意view的属性做动画,实现动画的原理就是在给定的时间内把属性从一个值变为另一个值.因此可以说属性动画什么都可以 ...
- webservice之helloword(web)rs
spring整合webservice 1.pom.xml文件 <dependencies> <!-- cxf 进行rs开发 必须导入 --> <dependency> ...
- linux下禁用SELinux
http://chenzhou123520.iteye.com/blog/1313582 如何开启或关闭SELinux RedHat的 /etc/sysconfig/selinux 在新版本中的Red ...