基于 keras-js 快速实现浏览器内的 CNN 手写数字识别
https://zhuanlan.zhihu.com/p/33313340
在这篇文章中,我会快速地介绍如何使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。
DEMO 地址:https://starkwang.github.io/keras-js-demo/dist/
零、准备工作
首先需要给你的电脑安装 keras,具体安装的步骤请参考 keras 官方文档
一、快速入门
首先十分推荐阅读 tensorflow 官方文档中的 MNIST For ML Beginners,这里是极客学院的中文翻译
MNIST 是一个很流行的入门级机器学习/计算机视觉数据集,它包含 0 - 9 的各种手写数字图片:

每张图片的尺寸均为 28 * 28,用一个 28 * 28 的二维数组来表示,换句话说,每张图片都是由 784 个像素点组成,每个像素点的值在 0 - 255 之间。
比如下面就是一个 "3" 的数据:
(知乎web移动端代码强制换行,简直有毒)
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
使用 keras,可以很方便地导入 MNIST 数据集:
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
总体来说,我们的想要得到的网络模型,是有一个固定的输入输出的:
- 输入为一个 28 * 28 的二维整数数组
- 输出是一个长度为 10 的数组,依次表示 0-9 的可能性(例如如果有一张图片 80% 概率为 1, 20% 概率为 7的话,那么这个数组就是
[0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0])
二、使用 keras 训练网络
我们想要训练的模型,由以下几层网络组成:
- 32 个 3x3 卷积核的卷积层
- 64 个 3x3 卷积核的卷积层
- 采样因子为 (2, 2) 的池化层
- Dropout 层
- Flatten 层
- ReLu 全连接层
- Dropout 层
- Softmax 全连接层
用 keras 训练一个识别 MNIST 的 CNN 网络非常方便,下面是一个官方给出的例子(源码在此):
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
batch_size = 128
num_classes = 10
epochs = 12
# input image dimensions
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Save model
model.save('myMnistCNN.h5')
如果已经安装好了 keras,直接运行即可:
python mnist_cnn.py
三、转换输出模型
获得训练好的 .h5 文件之后,模型还不能直接使用,因为我们需要对它进行转编码,keras-js 提供了一个 python 脚本来自动执行:
python ./python/encoder.py -q myMnistCNN.h5
这个脚本会把 .h5 文件转编码为 keras-js 可读的格式,里面包含了训练好的神经网络的所有模型和参数。
四、使用 keras-js 导入模型
首先需要引入 keras-js,可以通过 script 标签直接引入:
<script src="https://unpkg.com/keras-js"></script>
也可以通过 npm 安装后使用 webpack 构建引入,参考这里
接下来就可以直接创建一个 Model,keras-js 会自动加载对应的 bin 文件:
const model = new KerasJS.Model({
filepath: '/path/to/mnist_cnn.bin',
gpu: true,
transferLayerOutputs: true
})
初始化完毕之后,就可以用于 MNIST 识别了,输入是一个长度为 784 的数组(包含 28*28 各个像素点的灰度值),输出是一个长度为 10 的数组(0-9的概率):
(可以使用上文中给的那个 "3" 的数据范例)
model
.ready()
.then(() => {
// data 是一个长度为 784 的数组,每一项都介于 0 - 255 之间
// 这里我们需要把数组转换为 Float32 类型
const inputData = new Float32Array(data)
// 识别
return model.predict(inputData)
})
.then(outputData => {
// 输出为 0-9 的概率,例如:
// { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
})
.catch(err => {
// ...
})
五、Canvas 实现一个手写板
最后一步就是实现一个手写板,具体的代码就不放上来了,主要就是通过 mousedown、mousemove、mouseup 事件来绘制图形。
绘制完毕之后,调用 ctx.getImageData,就可以得到 canvas 内的像素数据,每个像素对应四个数值,依次是每个点的 rgba 值,处理之后就可以得到长度为 784 的灰度数组了。然后使用上文提到的 model.predict 即可。

基于 keras-js 快速实现浏览器内的 CNN 手写数字识别的更多相关文章
- Keras cnn 手写数字识别示例
#基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...
- keras框架的CNN手写数字识别MNIST
参考:林大贵.TensorFlow+Keras深度学习人工智能实践应用[M].北京:清华大学出版社,2018. 首先在命令行中写入 activate tensorflow和jupyter notebo ...
- 手写数字识别——利用keras高层API快速搭建并优化网络模型
在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- Keras mlp 手写数字识别示例
#基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...
- 手写数字识别——基于LeNet-5卷积网络模型
在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
通过: 手写数字识别 ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别 ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...
- 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...
随机推荐
- 把数据库表的信息添加到list集合里面
把数据库表里面的信息添加到集合里面并且打印出来: 数据库表的内容: java代码逻辑处理: 1 public static void main(String[] args) { 2 3 Connec ...
- AWS学习笔记之Lambda执行权限
最近在网上看到一道关于AWS Lambda的题,十分有意思: A developer has an application that uses an AWS Lambda function to up ...
- 未能加载文件或程序集“System.Runtime.WindowsRuntime, Version=4.0.14.0, Culture=neutral, PublicKeyToken=b77a5c561934e089”或它的某一个依赖项。不应出于执行的目的加载引用程序集。只能在仅限反射的加载程序上下文中加载引用程序集。 (异常来自 HRESULT:0x80131058)
VS项目编译时报错: 未能加载文件或程序集"System.Runtime.WindowsRuntime, Version=4.0.14.0, Culture=neutral, PublicK ...
- [车载以太网] SOME/IP 参数和数据结构的序列化
概述:SOME/IP 参数和数据结构的序列化 大小端/字节序 每个参数(parameter)的字节顺序由接口定义进行规定. 所有的 SOME/IP Header 字段,应该以网络字节序(大端)编码. ...
- Centos下多种PHP拓展安装方法
http://my.oschina.net/u/2400083/blog/518195
- MongoDB入门实战教程(3)
上一篇我们了解了MongoDB的复制集概念和复制集的搭建,本篇我们来了解一下如何实现数据恢复 和 提升安全性的一些实践. 1 Mongo Tools实现数据恢复 MongoDB 4.4之后,备份与恢复 ...
- C# 数字(阿拉伯数字)金额转汉字金额 人民币操作类 :转换人民币大小金额。
/// <summary> /// 转换为人民币大写金额形式 /// </summary> /// <param name="Money">金额 ...
- java compareTo 与 equals 区别
简介 要实现compareTo函数需要实现接口Comparable这个接口 然后这个接口中只有compareTo函数实现一下就可以用Collections.sort等方法. equals 如果不重写, ...
- 虚继承 private virtual class
简介 看到一个代码觉得奇怪,顺便看了一下相关的资料. 简而言之,虚继承是对于C++之中的多重继承相关的,消除多重集成共同的父类的变量的奇异性. 参考资料 https://www.cnblogs.com ...
- ETL数据集成丨为什么没有做好ETL的BI工具最终都会失败?
随着数字化转型,企业越来越重视数据的价值和利用.商业智能(Business Intelligence,BI)作为一种数据分析和决策支持的重要工具,被广泛应用于各行各业.然而,对于BI项目的成功实施,E ...