我假设已经成功编译caffe,如果没有,请参考http://caffe.berkeleyvision.org/installation.html

在本教程中,我假设你的caffe安装目录是CAFFE_ROOT

一.数据准备

首先,你需要从MNIST网站下载mnist数据,并转换数据格式。可以通过执行以下命令来实现

cd $CAFFE_ROOT
./data/mnist/get_mnist.sh
./examples/mnist/create_mnist.sh

如果显示没有安装wget或者gunzip,那么你需要分别安装。运行以上脚本之后,
examples/mnist文件夹下应该有以下两个文件夹:mnist_lmdb 和 mnist_test_lmdb

至此,数据准备完毕。

二. LeNet: the MNIST 分类模型(classification Model)

在我运行训练程序之前,让我解释下发生了什么。我们使用LeNet网络,LeNet因为在数字分类任务

表现非常不错而受到关注。我们使用了和LeNet原始实现有轻微不同的版本。我们用ReLU激活函数

代替了神经元的sigmoid激活函数。

LeNet的设计包含了CNNs的特性,这些特性仍然被用在类似ImageNet的这样的大型模型中。实际上,

LeNet包含卷积层,卷积层后面跟随着池化层,然后另外一层卷积层跟着这一层池化层,然后跟着两个

全连接层,和传统的多成感知相识。我们在以下文件定义了这些网络层:

$CAFFE_ROOT/examples/mnist/lenet_train_test.prototxt.

三. 定义MNIST网络

这节讲述了 lenet_train_test.prototxt 的模型定义,用于手写数字分类(识别).

我们假设你熟悉Goople Protobuf,并认为你读过caffe使用的protobuf定义,你可以

在以下路径中找到:$CAFFE_ROOT/src/caffe/proto/caffe.proto.

具体而言,我们会写一个caffe::NetParameter(或者用python, caffe.proto.caffe_pb2.NetParameter) protobuf.

我们将会通过给一个网络名字开始:

name:"LetNet"

数据层

当前,我们将会从刚刚创建的lmdb中读取MNIST数据,从lmdb中读取数据在data layer中定义

layer{

name: "mnist"

type: "Data"

transform_param{

scale:0.00390625

}

data_param{

source: "mnist_train_lmdb"

backend: LMDB

batch_size:64

}

top: "data"

top: "label"

}

具体而言,这网络层的名字为mnist, 类型为data,这个网络层从给定的lmdb source读取数据。我们使的

batch_size为64,我们缩放输入的图像像素,这样可以让想素质的范围落在[0,1]之间,为什么是0。00390625呢?

因为1/256=0.00390625。最后,该网络层产生两个blobs,一个是data blob, 另外一个是label blob

卷积层

让我们开始定义第一层卷积层吧。

layer{

name: "conv1"

type: "Convolution"

param { lr_mult:1 }

param { lr_mult:2 }

convolution_param {

num_output: 20

kernel_size: 5

stride: 1

weight_filler {

type: "xavier"

}

bias_filler {

type: "constant"

}

}

bottom: "data"

top: "conv1"

}

这层(第一层卷积层)接收data blob(数据层产生的数据)然后生成conv1 layer.conv1 产生20个通道的输出,

卷积核大小为5x5,步长为1。

filler允许我们随机初始化权重和偏置的值,对于weight filler, 我们使用xavier算法,该算法基于输入输出神经元的数量

自动决定初始化的尺度。对于偏置,我们简单的将其初始化为constant, 默认为0。

lr_mults是对于层的可学习参数的学习率的调整。在这个例子中,在运行期间我们将会把

权重学习率设置成solver给的学习率相同,偏置学习率是solver给的学习率的两倍,因为

这样有利于收敛速率。

池化层

实际上池化层更好定义。

layer {

name: "pool1"

type: "Pooling"

pooling_param {

kernel_size: 2

stride: 2

pool: MAX

}

bottom: "conv1"

top: "pool1"

}

以上定义的意思是说我们会通过2x2的过滤器,和步长为2的方式执行最大池化

(所以相邻的池化区域不会产生重叠)

同样,你可以写第二层卷积层和池化层。详细内容查看:

$CAFFE_ROOT/examples/mnist/lenet_train_test.prototxt

全连接层

写全连接层同样简单

layer {

name: "ip1"

type: "InnerProduct"

param { lt_mult: 1}

param { lr_mult: 2}

inner_product_param {

num_output: 500

weigh_filler {

type: "xavier"

}

bias_filler {

type: "constant"

}

bottom: "pool2"

top: "ip1"

}

这个定义为全连接层(在caffe框架中,我们称InnerProduct layer)有500个输出。

ReLU层

ReLU层也一样简单

Layer {

name: "relu1"

type: "ReLU"

bottom: "ip1"

top: "ip1"

}

因为ReLU是元素层面的运算,我们可以do in-place运算来保存记忆。这是通过给bottom和top blobs

相同的名字来实现。当然,不要在其它层类型给重复的blob名字

ReLU层之后,我们会写另外一层的全连接层

layer {

name: "ip2"

type: "InnerProduct:

param {lr_mult: 1 }

param { lr_mult: 2 }

inner_product_param {

num_output: 10

weight_filler {

type: "xavier"

}

bias_filler {

type: "constant"

}

bottom: "ip1"

top: "ip2"

}

损失层

最后,我们写损失层

layer {

name: "loss"

type: "SoftmaxWidthLoss"

bottom: "ip2:

bottom: "label"

}

softmax_loss 层实现了softmax和多项后勤损失(multinomial logistic loss).

softmax_loss takes two blobs, 第一个是预测,第二个是给数据层提供标签。

它不产生任何输出。它所做的就是开始反向传播的时候计算损失函数的值,并report,

并依据ip2层初始化梯度。

额外的提示:写神经网络层的规则。

神经网络层的定义如下:

layer {

// ...layer definition

include: { phase: TRAIN }

}

这是一个规则,基于当前的网络状态,控制层包含在网络里面。关于层规则和模型原理的更多规则,

你可以参考:$CAFFE_ROOT/src/caffe/proto/caffe.proto

在上面的例子,这层只包括TRAIN phase。如果我们把TRAIN换成TEST, 那么这层会只包括test phase。

默认情况下,没有层规则,一层总是被包含在网络里面。因此,lenet_train_test.prototxt有两层数据层定义

(with different batch_size), 一层是用来训练,另一层是测试期间使用。同样,在TEST phase 包含精度层

(Accuracy layer), 用来每100次迭代汇报一次精度,在lenet_solver.prototxt定义。

四. 定义MNIST Solver

仔细检查每prototxt每一行的解释:$CAFFE_ROOT/examples/mnist/lenet_solver.prototxt:

# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: GPU 五. 训练和测试模型
你写了网络定义protobuf和solver protobuf 文件之后,训练和测试是非常简单的。
简单的执行train_lenet.sh, 或者执行以下命令:
cd $CAFFE_ROOT
./examples/mnist/train_lenet.sh train_lenet.sh是一个简单的脚本,但是这里有一个简单的解释:主要训练工具是caffe. 当你运行代码的时候,你会看到很多如下的信息:
I1203 net.cpp:66] Creating Layer conv1
I1203 net.cpp:76] conv1 <- data
I1203 net.cpp:101] conv1 -> conv1
I1203 net.cpp:116] Top shape: 20 24 24
I1203 net.cpp:127] conv1 needs backward computation. 这些信息告诉你每一层的细节,它的连接和它的输出模型。这些信息也许有助于你调试。
初始化之后,将会开始训练:
I1203 net.cpp:142] Network initialization done.
I1203 solver.cpp:36] Solver scaffolding done.
I1203 solver.cpp:44] Solving LeNet
基于solver的设置,每100次迭代我们将会打印训练损失函数;每500次迭代将会测试一次网络。
你会看到如下信息:
I1203 solver.cpp:204] Iteration 100, lr = 0.00992565
I1203 solver.cpp:66] Iteration 100, loss = 0.26044
...
I1203 solver.cpp:84] Testing net
I1203 solver.cpp:111] Test score #0: 0.9785
I1203 solver.cpp:111] Test score #1: 0.0606671
对于每一次训练迭代,lr是迭代的训练速率,loss是训练函数。对于测试期间的输出,
score 0 是精度, score 1是测试的损失。 几分钟之后就完成了。
I1203 solver.cpp:84] Testing net
I1203 solver.cpp:111] Test score #0: 0.9897
I1203 solver.cpp:111] Test score #1: 0.0324599
I1203 solver.cpp:126] Snapshotting to lenet_iter_10000
I1203 solver.cpp:133] Snapshotting solver state to lenet_iter_10000.solverstate
I1203 solver.cpp:78] Optimization Done.
最后的模型,会以二进制protobuf文件储存。储存在lenet_iter_1000 如果你用实际情况的数据训练,你可以部署你训练的模型在你的应用中。 原文网址:http://caffe.berkeleyvision.org/gathered/examples/mnist.html
 
 


 
 

基于LeNet的手写汉字识别(caffe)的更多相关文章

  1. <脱机手写汉字识别若干关键技术研究>

    脱机手写汉字识别若干关键技术研究 对于大字符集识别问题,一般采用模板匹配的算法,主要是因为该算法比较简单,识别速度快.但直接的模板匹配算法往往无法满足实际应用中对识别精度的需求.为此任俊玲编著的< ...

  2. 基于LeNet网络的中文验证码识别

    基于LeNet网络的中文验证码识别 由于公司需要进行了中文验证码的图片识别开发,最近一段时间刚忙完上线,好不容易闲下来就继上篇<基于Windows10 x64+visual Studio2013 ...

  3. 【Caffe 测试】Training LeNet on MNIST with Caffe

    Training LeNet on MNIST with Caffe We will assume that you have Caffe successfully compiled. If not, ...

  4. 基于Python使用SVM识别简单的字符验证码的完整代码开源分享

    关键字:Python,SVM,字符验证码,机器学习,验证码识别 1   概述 基于Python使用SVM识别简单的验证字符串的完整代码开源分享. 因为目前有了更厉害的新技术来解决这类问题了,但是本文作 ...

  5. 基于FPGA的肤色识别算法实现

    大家好,给大家介绍一下,这是基于FPGA的肤色识别算法实现. 我们今天这篇文章有两个内容一是实现基于FPGA的彩色图片转灰度实现,然后在这个基础上实现基于FPGA的肤色检测算法实现. 将彩色图像转化为 ...

  6. 基于MATLAB的人脸识别算法的研究

    基于MATLAB的人脸识别算法的研究 作者:lee神 现如今机器视觉越来越盛行,从智能交通系统的车辆识别,车牌识别到交通标牌的识别:从智能手机的人脸识别的性别识别:如今无人驾驶汽车更是应用了大量的机器 ...

  7. 基于FPGA的数字识别的实现

    欢迎大家关注我的微信公众号:FPGA开源工作室     基于FPGA的数字识别的实现二 作者:lee神 1 背景知识 1.1基于FPGA的数字识别的方法 通常,针对印刷体数字识别使用的算法有:基于模版 ...

  8. 【文智背后的奥秘】系列篇——基于CRF的人名识别

    版权声明:本文由文智原创文章,转载请注明出处: 文章原文链接:https://www.qcloud.com/community/article/133 来源:腾云阁 https://www.qclou ...

  9. 基于 OpenCV 的人脸识别

    基于 OpenCV 的人脸识别 一点背景知识 OpenCV 是一个开源的计算机视觉和机器学习库.它包含成千上万优化过的算法,为各种计算机视觉应用提供了一个通用工具包.根据这个项目的关于页面,OpenC ...

随机推荐

  1. 动图+源码,演示Java中常用数据结构执行过程及原理

    最近在整理数据结构方面的知识, 系统化看了下Java中常用数据结构, 突发奇想用动画来绘制数据流转过程. 主要基于jdk8, 可能会有些特性与jdk7之前不相同, 例如LinkedList Linke ...

  2. 为什么操作DOM会影响WEB应用的性能?

    面试官经常会问你:"平时工作中,你怎么优化自己应用的性能?" 你回答如下:"我平时遵循以下几条原则来优化我的项目.以提高性能,主要有:" a. 减少DOM操作的 ...

  3. JavaScript数组方法大全(第一篇)

    数组方法大全(第一篇) 注意:第一次写博客有点小紧张,如有错误欢迎指出,如有雷同纯属巧合,本次总结参考书籍JavaScript权威指南,有兴趣的小伙伴可以去翻阅一下哦 join()方法 该方法是将数组 ...

  4. 调度系统Airflow1.10.4调研与介绍和docker安装

    Airflow1.10.4介绍与安装 现在是9102年,8月中旬.airflow当前版本是1.10.4. 随着公司调度任务增大,原有的,基于crontab和mysql的任务调度方案已经不太合适了,需要 ...

  5. Unity 自定义Inspector面板时的数据持久化问题

    自定义Inspector面板的步骤: Unity内创建自定义的Inspector需要在Asset的任意文件夹下创建一个名字是Editor的文件夹,随后这个文件夹内的cs文件就会被放在vstu生成的Ed ...

  6. abp(net core)+easyui+efcore实现仓储管理系统——使用 WEBAPI实现CURD (十五)

    core)+easyui+efcore实现仓储管理系统目录 abp(net core)+easyui+efcore实现仓储管理系统——ABP总体介绍(一) abp(net core)+easyui+e ...

  7. 企查查app新增企业数据抓取

    企查查每日新增企业数据抓取尚未完成的工作: 需要自行抓包获取设备id,appid,sign等等 sign和时间戳保持一致即可 把所有的数据库.redis配置 无法自动登录,账号需要独立 redis数据 ...

  8. k好数(动态规划)

    问题描述 如果一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字,那么我们就说这个数是K好数.求L位K进制数中K好数的数目.例如K = 4,L = 2的时候,所有K好数为11.13.20.22 ...

  9. Postman系列五:Postman中电商网站cookie、token检验与参数传递实战

    一:Postman中电商网站cookie实战 Postman接口请求使用cookie两种方式: 1.直接在header(头域)中添加cookie,适用于已知请求cookie头域的情况 2.使用Post ...

  10. css3的@media

    都知道bootstrap响应式布局很酷,但是是怎么实现的呢?其官网首页有提到这一切的功劳都是来自于CSS 媒体查询(Media Query). 使用 @media 查询,你可以针对不同的媒体类型定义不 ...