https://zhuanlan.zhihu.com/p/33313340

在这篇文章中,我会快速地介绍如何使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。

DEMO 地址:https://starkwang.github.io/keras-js-demo/dist/

 

零、准备工作

首先需要给你的电脑安装 keras,具体安装的步骤请参考 keras 官方文档


一、快速入门

首先十分推荐阅读 tensorflow 官方文档中的 MNIST For ML Beginners,这里是极客学院的中文翻译

MNIST 是一个很流行的入门级机器学习/计算机视觉数据集,它包含 0 - 9 的各种手写数字图片:

每张图片的尺寸均为 28 * 28,用一个 28 * 28 的二维数组来表示,换句话说,每张图片都是由 784 个像素点组成,每个像素点的值在 0 - 255 之间。

比如下面就是一个 "3" 的数据:

(知乎web移动端代码强制换行,简直有毒)

000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000

使用 keras,可以很方便地导入 MNIST 数据集:

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

总体来说,我们的想要得到的网络模型,是有一个固定的输入输出的:

  • 输入为一个 28 * 28 的二维整数数组
  • 输出是一个长度为 10 的数组,依次表示 0-9 的可能性(例如如果有一张图片 80% 概率为 1, 20% 概率为 7的话,那么这个数组就是 [0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0]

二、使用 keras 训练网络

我们想要训练的模型,由以下几层网络组成:

  1. 32 个 3x3 卷积核的卷积层
  2. 64 个 3x3 卷积核的卷积层
  3. 采样因子为 (2, 2) 的池化层
  4. Dropout 层
  5. Flatten 层
  6. ReLu 全连接层
  7. Dropout 层
  8. Softmax 全连接层

用 keras 训练一个识别 MNIST 的 CNN 网络非常方便,下面是一个官方给出的例子(源码在此):

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K batch_size = 128
num_classes = 10
epochs = 12 # input image dimensions
img_rows, img_cols = 28, 28 # the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data() if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1) x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples') # convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes) model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax')) model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy']) model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1]) # Save model
model.save('myMnistCNN.h5')

如果已经安装好了 keras,直接运行即可:

python mnist_cnn.py 

三、转换输出模型

获得训练好的 .h5 文件之后,模型还不能直接使用,因为我们需要对它进行转编码,keras-js 提供了一个 python 脚本来自动执行:

python ./python/encoder.py -q myMnistCNN.h5 

这个脚本会把 .h5 文件转编码为 keras-js 可读的格式,里面包含了训练好的神经网络的所有模型和参数。


四、使用 keras-js 导入模型

首先需要引入 keras-js,可以通过 script 标签直接引入:

<script src="https://unpkg.com/keras-js"></script> 

也可以通过 npm 安装后使用 webpack 构建引入,参考这里

接下来就可以直接创建一个 Model,keras-js 会自动加载对应的 bin 文件:

const model = new KerasJS.Model({
filepath: '/path/to/mnist_cnn.bin',
gpu: true,
transferLayerOutputs: true
})

初始化完毕之后,就可以用于 MNIST 识别了,输入是一个长度为 784 的数组(包含 28*28 各个像素点的灰度值),输出是一个长度为 10 的数组(0-9的概率):

(可以使用上文中给的那个 "3" 的数据范例)

model
.ready()
.then(() => {
// data 是一个长度为 784 的数组,每一项都介于 0 - 255 之间
// 这里我们需要把数组转换为 Float32 类型
const inputData = new Float32Array(data)
// 识别
return model.predict(inputData)
})
.then(outputData => {
// 输出为 0-9 的概率,例如:
// { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
})
.catch(err => {
// ...
})

五、Canvas 实现一个手写板

最后一步就是实现一个手写板,具体的代码就不放上来了,主要就是通过 mousedownmousemovemouseup 事件来绘制图形。

绘制完毕之后,调用 ctx.getImageData,就可以得到 canvas 内的像素数据,每个像素对应四个数值,依次是每个点的 rgba 值,处理之后就可以得到长度为 784 的灰度数组了。然后使用上文提到的 model.predict 即可。

基于 keras-js 快速实现浏览器内的 CNN 手写数字识别的更多相关文章

  1. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  2. keras框架的CNN手写数字识别MNIST

    参考:林大贵.TensorFlow+Keras深度学习人工智能实践应用[M].北京:清华大学出版社,2018. 首先在命令行中写入 activate tensorflow和jupyter notebo ...

  3. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  4. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  5. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  6. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  7. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  8. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  9. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

  10. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

随机推荐

  1. MCP赋能,给Cursor插上“外挂翅膀”:实战操作数据库

    先给大家个例子, 展示如何用mcp如何带飞cursor的.  话不多说,  继续展示 1.建立项目 提示词如下: " 新建个java项目, 叫user-demo,  通过spring boo ...

  2. Golang解析yaml文件

    一.具体思路 将配置yaml文件内容解析为我们定义好的struct,这种比较简单,如果想获取对应的值,直接获取即可. 二.实现步骤 首先根据配置文件的内容定义一个结构体Config,结构体类型和yam ...

  3. 现在的AI工具还能写剧本杀了?

    本文由 ChatMoney团队出品 近年来,剧本杀作为一种新兴社交游戏,收到了越来越多人的喜爱,它不仅需要玩家们发挥自身演技,还需运用逻辑思维推理,分析所获得的线索,找出案件真凶.然而你是否想过,你在 ...

  4. opencv学习:学习如何对图像进行缩放、剪切、移位等处理

    又是每周一次的坑爹OPENCV!加油奥里给! 1.图像缩放--直接调用函数操作 Mat img = imread("E:/lena.jpg"); int img_cols = im ...

  5. 替换GitLab的方案之Gitea

    概述 官网:https://docs.gitea.com/zh-cn/ GitHub地址:https://github.com/go-gitea/gitea Gitea 是一个轻量级的 DevOps ...

  6. DRF之分页类源码分析

    DRF之分页类源码分析 [一]分页类介绍 Django REST framework(DRF)是一个用于构建Web API的强大工具,它提供了分页功能,使你能够控制API响应的数据量. 在DRF中,分 ...

  7. SQL Server数据库巡检

    查询所有表名 select name from sysobjects where xtype='u' select * from sys.tables 查询所有表名及对应架构 select t.[na ...

  8. window10本地搭建DeepSeek R1(一)

    本章介绍在window上部署 DeepSeek R1-8B + Open WebUI :需要安装的有:Ollama,python 3.11,DeepSeek ,Open WebUI. 一:环境:我的w ...

  9. 串口wifi模块、串口无线模块

    串口无线模块ZLSN7046T是上海卓岚推出的wifi转串口模块.它能够将wifi信号转化为串口信号,且支持多种功能,邮票孔封装,体积小巧可以外置天线或者内置天线.7046T支持一个UART TTL电 ...

  10. java--Hibernate对象状态、一级缓存、映射

    对象的状态 Hibernate中对象的状态: 临时/瞬时状态.持久化状态.游离状态.  临时状态 特点: 直接new出来的对象; 不处于session的管理; 数据库中没有对象的记录;  持久化状 ...