Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现
SkySeraph 2018
Email:skyseraph00#163.com
更多精彩请直接访问SkySeraph个人站点:www.skyseraph.com
Overview
本文系“SkySeraph AI 实践到理论系列”第一篇,咱以AI界的HelloWord 经典MNIST数据集为基础,在Android平台,基于TensorFlow,实现CNN的手写数字识别。
Code here~
Practice
Environment
- TensorFlow: 1.2.0
- Python: 3.6
- Python IDE: PyCharm 2017.2
- Android IDE: Android Studio 3.0
Train & Evaluate(Python+TensorFlow)
训练和评估部分主要目的是生成用于测试用的pb文件,其保存了利用TensorFlow python API构建训练后的网络拓扑结构和参数信息,实现方式有很多种,除了cnn外还可以使用rnn,fcnn等。
其中基于cnn的函数也有两套,分别为tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d作为后端处理,参数上filters是整数,filter是4维张量。原型如下:
convolutional.py文件
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)
gen_nn_ops.py 文件
def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)
官方Demo实例中使用的是layers module,结构如下:
- Convolutional Layer #1:32个5×5的filter,使用ReLU激活函数
- Pooling Layer #1:2×2的filter做max pooling,步长为2
- Convolutional Layer #2:64个5×5的filter,使用ReLU激活函数
- Pooling Layer #2:2×2的filter做max pooling,步长为2
- Dense Layer #1:1024个神经元,使用ReLU激活函数,dropout率0.4 (为了避免过拟合,在训练的时候,40%的神经元会被随机去掉)
- Dense Layer #2 (Logits Layer):10个神经元,每个神经元对应一个类别(0-9)
核心代码在cnn_model_fn(features, labels, mode)函数中,完成卷积结构的完整定义,核心代码如下.
也可以采用传统的tf.nn.conv2d函数, 核心代码如下。
Test(Android+TensorFlow)
- 核心是使用API接口: TensorFlowInferenceInterface.java
- 配置gradle 或者 自编译TensorFlow源码导入jar和so
compile ‘org.tensorflow:tensorflow-android:1.2.0’ 导入pb文件.pb文件放assets目录,然后读取
String actualFilename = labelFilename.split(“file:///android_asset/“)[1];
Log.i(TAG, “Reading labels from: “ + actualFilename);
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
c.labels.add(line);
}
br.close();TensorFlow接口使用

- 最终效果:


Theory
MNIST
MNIST,最经典的机器学习模型之一,包含0~9的数字,28*28大小的单色灰度手写数字图片数据库,其中共60,000 training examples和10,000 test examples。
文件目录如下,主要包括4个二进制文件,分别为训练和测试图片及Label。
如下为训练图片的二进制结构,在真实数据前(pixel),有部分描述字段(魔数,图片个数,图片行数和列数),真实数据的存储采用大端规则。
(大端规则,就是数据的高字节保存在低内存地址中,低字节保存在高内存地址中)
在具体实验使用,需要提取真实数据,可采用专门用于处理字节的库struct中的unpack_from方法,核心方法如下:
struct.unpack_from(self._fourBytes2, buf, index)
MNIST作为AI的Hello World入门实例数据,TensorFlow封装对其封装好了函数,可直接使用
mnist = input_data.read_data_sets(‘MNIST’, one_hot=True)
CNN(Convolutional Neural Network)
CNN Keys
- CNN,Convolutional Neural Network,中文全称卷积神经网络,即所谓的卷积网(ConvNets)。
- 卷积(Convolution)可谓是现代深度学习中最最重要的概念了,它是一种数学运算,读者可以从下面链接[23]中卷积相关数学机理,包括分别从傅里叶变换和狄拉克δ函数中推到卷积定义,我们可以从字面上宏观粗鲁的理解成将因子翻转相乘卷起来。
- 卷积动画。演示如下图[26],更多动画演示可参考[27]

神经网络。一个由大量神经元(neurons)组成的系统,如下图所示[21]

其中x表示输入向量,w为权重,b为偏值bias,f为激活函数。Activation Function 激活函数: 常用的非线性激活函数有Sigmoid、tanh、ReLU等等,公式如下如所示。
- Sigmoid缺点
- 函数饱和使梯度消失(神经元在值为 0 或 1 的时候接近饱和,这些区域,梯度几乎为 0)
- sigmoid 函数不是关于原点中心对称的(无0中心化)
- tanh: 存在饱和问题,但它的输出是零中心的,因此实际中 tanh 比 sigmoid 更受欢迎。
- ReLU
- 优点1:ReLU 对于 SGD 的收敛有巨大的加速作用
- 优点2:只需要一个阈值就可以得到激活值,而不用去算一大堆复杂的(指数)运算
- 缺点:需要合理设置学习率(learning rate),防止训练时dead,还可以使用Leaky ReLU/PReLU/Maxout等代替

- Sigmoid缺点
- Pooling池化。一般分为平均池化mean pooling和最大池化max pooling,如下图所示[21]为max pooling,除此之外,还有重叠池化(OverlappingPooling)[24],空金字塔池化(Spatial Pyramid Pooling)[25]
- 平均池化:计算图像区域的平均值作为该区域池化后的值。
- 最大池化:选图像区域的最大值作为该区域池化后的值。


CNN Architecture
- 三层神经网络。分别为输入层(Input layer),输出层(Output layer),隐藏层(Hidden layer),如下图所示[21]

- CNN层级结构。 斯坦福cs231n中阐述了一种[INPUT-CONV-RELU-POOL-FC],如下图所示[21],分别为输入层,卷积层,激励层,池化层,全连接层。
- CNN通用架构分为如下三层结构:
- Convolutional layers 卷积层
- Pooling layers 汇聚层
- Dense (fully connected) layers 全连接层

- 动画演示。参考[22]。
Regression + Softmax
机器学习有监督学习(supervised learning)中两大算法分别是分类算法和回归算法,分类算法用于离散型分布预测,回归算法用于连续型分布预测。
回归的目的就是建立一个回归方程用来预测目标值,回归的求解就是求这个回归方程的回归系数。
其中回归(Regression)算法包括Linear Regression,Logistic Regression等, Softmax Regression是其中一种用于解决多分类(multi-class classification)问题的Logistic回归算法的推广,经典实例就是在MNIST手写数字分类上的应用。
Linear Regression
Linear Regression是机器学习中最基础的模型,其目标是用预测结果尽可能地拟合目标label
- 多元线性回归模型定义

- 多元线性回归求解

- Mean Square Error (MSE)
- Gradient Descent(梯度下降法)
- Normal Equation(普通最小二乘法)
- 局部加权线性回归(LocallyWeightedLinearRegression, LWLR ):针对线性回归中模型欠拟合现象,在估计中引入一些偏差以便降低预测的均方误差。
- 岭回归(ridge regression)和缩减方法
- 选择: Normal Equation相比Gradient Descent,计算量大(需计算X的转置与逆矩阵),只适用于特征个数小于100000时使用;当特征数量大于100000时使用梯度法。当X不可逆时可替代方法为岭回归算法。LWLR方法增加了计算量,因为它对每个点做预测时都必须使用整个数据集,而不是计算出回归系数得到回归方程后代入计算即可,一般不选择。
- 调优: 平衡预测偏差和模型方差(高偏差就是欠拟合,高方差就是过拟合)
- 获取更多的训练样本 - 解决高方差
- 尝试使用更少的特征的集合 - 解决高方差
- 尝试获得其他特征 - 解决高偏差
- 尝试添加多项组合特征 - 解决高偏差
- 尝试减小 λ - 解决高偏差
- 尝试增加 λ -解决高方差
Softmax Regression
- Softmax Regression估值函数(hypothesis)

- Softmax Regression代价函数(cost function)

- 理解:

- Softmax Regression & Logistic Regression:
- 多分类 & 二分类。Logistic Regression为K=2时的Softmax Regression
- 针对K类问题,当类别之间互斥时可采用Softmax Regression,当非斥时,可采用K个独立的Logistic Regression
- 总结: Softmax Regression适用于类别数量大于2的分类,本例中用于判断每张图属于每个数字的概率。
References & Recommends
MNIST
- [01]Mnist官网
- [02]Visualizing MNIST: An Exploration of Dimensionality Reduction
- [03]TensorFlow Mnist官方实例
- [04]Sample code for “Tensorflow and deep learning, without a PhD”
Softmax
CNN
- [21]Stanford University’s Convolutional Neural Networks for Visual Recognition course materials 翻译
- [22]July CNN笔记:通俗理解卷积神经网络
- [23]理解卷积Convolution
- [24]Imagenet classification with deep convolutional neural networks
- [25]Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
- [26]Convolutional Neural Networks-Basics
- [27]A technical report on convolution arithmetic in the context of deep learning
TensorFlow+CNN / TensorFlow+Android
- [31]Google官方Demo
- [32]Google官方Codelab
- [33]deep-learning-cnns-in-tensorflow Github
- [34]tensorflow-classifier-android
- [35]creating-custom-model-for-android-using-tensorflow
- [36]TF-NN Mnist实例
By SkySeraph-2018
本文首发于skyseraph.com:“Android+TensorFlow+CNN+MNIST 手写数字识别实现”
Android+TensorFlow+CNN+MNIST 手写数字识别实现的更多相关文章
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 基于TensorFlow的MNIST手写数字识别-初级
一:MNIST数据集 下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- Tensorflow之MNIST手写数字识别:分类问题(2)
整体代码: #数据读取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorfl ...
- 基于TensorFlow的MNIST手写数字识别-深入
构建多层卷积神经网络时需要多组W和偏移项b,我们封装2个方法来产生W和b 初级MNIST中用0初始化W和b,这里用噪声初始化进行对称打破,防止产生梯度0,同时用一个小的正值来初始化b避免dead ne ...
- keras—神经网络CNN—MNIST手写数字识别
from keras.datasets import mnist from keras.utils import np_utils from plot_image_1 import plot_imag ...
- TensorFlow——MNIST手写数字识别
MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
随机推荐
- webpack学习之路
当自己在学习webpack的时候,在网上发现中文的很详细的教程很少,于是便想将自己学习webpack的笔记记录整理下来,便有了这篇文章,希望对大家有所帮助,如果有错误,欢迎大家指出. 在我们开始之前 ...
- Redux 介绍
本文主要是对 Redux 官方文档 的梳理以及自身对 Redux 的理解. 单页面应用的痛点 对于复杂的单页面应用,状态(state)管理非常重要.state 可能包括:服务端的响应数据.本地对响应数 ...
- CTF---Web入门第六题 因缺思汀的绕过
因缺思汀的绕过分值:20 来源: pcat 难度:中 参与人数:6479人 Get Flag:2002人 答题人数:2197人 解题通过率:91% 访问解题链接去访问题目,可以进行答题.根据web题一 ...
- [51nod1310]Chandrima and XOR
有这样一个小到大排列的无穷序列S:1, 2, 4, 5, 8......,其中任何一个数转为2进制不包括2个连续的1.给出一个长度为N的正整数数组A,A1, A2......An记录的是下标(下标从1 ...
- C语言缓冲区(缓存)详解
缓冲区又称为缓存,它是内存空间的一部分.也就是说,在内存空间中预留了一定的存储空间,这些存储空间用来缓冲输入或输出的数据,这部分预留的空间就叫做缓冲区.缓冲区根据其对应的是输入设备还是输出设备,分为输 ...
- Sql Server——数据增删改
所谓数据的增删改就是在创建好数据库和表后向表中添加数据.删除表中的数据.更改表中的一些数据. 新增数据: 语法一: insert into 表名 values (数据内容) --这里需要 ...
- Unity 小笔记
1,Time.deltatime放在Update和fixedupdate中得到的值是不一样的.还以为是通过两个值来获取. 2,VR中绘制射线可以使用LineRender. 3,Unity中判断一个东西 ...
- 98、vue.js简单入门
本篇导航: 介绍与安装 vue常用指令 一.介绍与安装 vue是一套构建用户界面的JAVASCRIPT框架.与其它大型框架不同的是,Vue 被设计为可以自底向上逐层应用.Vue 的核心库只关注视图层, ...
- debian 9 双显卡安装NVIDIA显卡驱动
最近用debian,给debian装n卡驱动折腾了好几天了,主要还是网络不好,官方wiki的方法下载经常卡死..摸索了几天感觉已经摸到了头绪,决定写下来供大家参考参考 先提供单显卡NVIDIA驱动的安 ...
- PHP中put和post区别
1. 使用支持和范围的区别: PHP提供了对PUT方法的支持,在Http定义的与服务器的交互方法中,PUT是把消息本体中的消息发送到一个URL,形式上跟POST类似; PHP 提供对诸如 Netsca ...