caffe团队用imagenet图片进行训练,迭代30多万次,训练出来一个model。这个model将图片分为1000类,应该是目前为止最好的图片分类model了。

假设我现在有一些自己的图片想进行分类,但样本量太小,可能只有几百张,而一般深度学习都要求样本量在1万以上,因此训练出来的model精度太低,根本用不上,那怎么办呢?

那就用caffe团队提供给我们的model吧。

因为训练好的model里面存放的就是一些参数,因此我们实际上就是把别人预先训练好的参数,拿来作为我们的初始化参数,而不需要再去随机初始化了。图片的整个训练过程,说白了就是将初始化参数不断更新到最优的参数的一个过程,既然这个过程别人已经帮我们做了,而且比我们做得更好,那为什么不用他们的成果呢?

使用别人训练好的参数,必须有一个前提,那就是必须和别人用同一个network,因为参数是根据network而来的。当然,最后一层,我们是可以修改的,因为我们的数据可能并没有1000类,而只有几类。我们把最后一层的输出类别改一下,然后把层的名称改一下就可以了。最后用别人的参数、修改后的network和我们自己的数据,再进行训练,使得参数适应我们的数据,这样一个过程,通常称之为微调(fine tuning).

既然前两篇文章我们已经讲过使用digits来进行训练和可视化,这样一个神器怎么能不使用呢?因此本文以此工具为例,讲解整个微调训练过程。

一、下载model参数

可以直接在浏览器里输入地址下载,也可以运行脚本文件下载。下载地址为:http://dl.caffe.berkeleyvision.org/bvlc_reference_caffenet.caffemodel

文件名称为:bvlc_reference_caffenet.caffemodel,文件大小为230M左右,为了代码的统一,将这个caffemodel文件下载到caffe根目录下的 models/bvlc_reference_caffenet/ 文件夹下面。也可以运行脚本文件进行下载:

# sudo ./scripts/download_model_binary.py models/bvlc_reference_caffenet

 二、准备数据

如果有自己的数据最好,如果没有,可以下载我的练习数据:http://pan.baidu.com/s/1MotUe

这些数据共有500张图片,分为大巴车、恐龙、大象、鲜花和马五个类,每个类100张。编号分别以3,4,5,6,7开头,各为一类。我从其中每类选出20张作为测试,其余80张作为训练。因此最终训练图片400张(放在train文件夹内,每个类一个子文件夹),测试图片100张(放在test文件夹内,每个类一个子文件夹)。

将图片下载下来后解压,放在一个文件夹内。比如我在当前用户根目录下创建了一个data文件夹,专门用来存放数据,因此我的训练图片路径为:/home/xxx/data/re/train

打开浏览器,运行digits,如果没有这个工具的,推荐安装,真的是学习caffe的神器。安装及使用可参见我的前两篇文章:Caffe学习系列(21):caffe图形化操作工具digits的安装与运行

新建一个classification dataset,设置如下图:

下面图片格式选为jpg, 为dataset取一个名字,就开始转换吧。结果如图:

三、设置model

回到digits根目录,新建一个classification model, 选中你的dataset, 开始设置最重要的network.

caffenet的网络配置文件,放在 caffe/models/bvlc_reference_caffenet/ 这个文件夹里面,名字叫train_val.prototxt。打开这个文件,将里面的内容复制到上图的Custom Network文本框里,然后进行修改,主要修改这几个地方:

1、修改train阶段的data层为:

layer {
name: "data"
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
mirror: true
crop_size: 227
}
}

即把均值文件(mean_file)、数据源文件(source)、批次大小(batch_size)和数据源格式(backend)这四项都删除了。因为这四项系统会根据dataset和页面左边“solver options"的设置自动生成。

2、修改test阶段的data层:

layer {
name: "data"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
mirror: false
crop_size: 227
}
}

和上面一样,也是删除那些项。

3、修改最后一个全连接层(fc8):

layer {
name: "fc8-re" #原来为"fc8"
type: "InnerProduct"
bottom: "fc7"
top: "fc8"
param {
lr_mult: 1.0
decay_mult: 1.0
}
param {
lr_mult: 2.0
decay_mult: 0.0
}
inner_product_param {
num_output: 5 #原来为"1000"
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0.0
}
}
}

看注释的地方,就只有两个地方修改,其它不变。

设置好后,就可以开始微调了(fine tuning).

训练结果就是一个新的model,可以用来单张图片和多张图片测试。具体测试方法前一篇文章已讲过,在此就不重复了。

在此,将别人训练好的model用到我们自己的图片分类上,整个微调过程就是这样了。如果你不用digits,而直接用命令操作,那就更简单,只需要修改一个train_val.prototxt的配置文件就可以了,其它都是一样的操作。

2016.12.6更新

这篇文章是将近一年前写的,digits版本已经升级,所以有些地方设置有点变化,导致很多网友出现错误。最多的错误提示如下:

ERROR: Layer 'accuracy' references bottom 'label' at the TEST stage however this blob is not included at that stage. Please consider using an include directive to limit the scope of this layer.

我当时用的版本是digits 3.0, 现在大家用的是digits 4.0, 因此会出现这个错误。修改如下:

最后四层的设置:

layer {
name: "re-fc8"
type: "InnerProduct"
bottom: "fc7"
top: "fc8"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param { weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "accuracy"
type: "Accuracy"
bottom: "fc8"
bottom: "label"
top: "accuracy"
include {
stage:"val"
}
}
layer {
name: "loss"
type: "SoftmaxWithLoss"
bottom: "fc8"
bottom: "label"
top: "loss"
exclude{
stage:"deploy"
}
}
layer {
name: "prob"
type: "Softmax"
bottom: "fc8"
top: "prob"
include{
stage:"deploy"
}
}

  

原来网络结构中的全连接层fc8, 需要改一下名字,如我的改成"re-fc8". 因为我们做的是微调。微调的意思就是先在别的数据集上进行训练,把训练好的权值,作为我们现在数据集的权值初始化,就不再需要随机初始化了。现在的数据和训练时的数据不一致,因此有些层数的设置就会有点区别。比如这个例子中,用来训练模型的数据集是imagenet,分为1000类,而我们的数据集就只有5类,因此在fc8这层上的num_output就会有区别,因此在这一层上就不能用人家的权值了,就需要把这层的名字改得和原来的网络结构不一样。

在digits 4.0版本中,最后的全连接层不再需要num_output这个参数了,因此大家需要把这行删除掉。digits会自动根据你的类别数把这个参数补充上。

也许原来的配置文件中没有Softmax层,现在需要加上这一层,因为digits会根据这里的设置自动生成train_test.prototxt和deploy.prototxt两个文件。其它需要修改的地方,就是最后三层的include和exclude了。

最后还有一个问题就是显存的问题。实话讲我的这个训练数据选得不太好,很吃显存,有些GPU不好的同学,运行起来很吃力。因此大家非要用这个数据的话,建议把batch_size调低些。我用的是nvidia k20, 4G显存,batch_size设置为16—32之间,运行得不错,1分钟左右运行完。

Caffe学习系列(23):如何将别人训练好的model用到自己的数据上的更多相关文章

  1. 【神经网络与深度学习】如何将别人训练好的model用到自己的数据上

    caffe团队用imagenet图片进行训练,迭代30多万次,训练出来一个model.这个model将图片分为1000类,应该是目前为止最好的图片分类model了. 假设我现在有一些自己的图片想进行分 ...

  2. Caffe学习系列(四)之--训练自己的模型

    前言: 本文章记录了我将自己的数据集处理并训练的流程,帮助一些刚入门的学习者,也记录自己的成长,万事起于忽微,量变引起质变. 正文: 一.流程 1)准备数据集  2)数据转换为lmdb格式  3)计算 ...

  3. caffe学习系列(2):训练和测试自己的图片

    参考:http://www.cnblogs.com/denny402/p/5083300.html 上述主要介绍的是从自己的原始图片转为lmdb数据,再到训练.测试的整个流程(另外可参考薛开宇的笔记) ...

  4. Caffe 学习系列

    学习列表: Google protocol buffer在windows下的编译 caffe windows 学习第一步:编译和安装(vs2012+win 64) caffe windows学习:第一 ...

  5. Caffe学习系列(12):训练和测试自己的图片

    学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中.因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测试模型的整个流程. 一.准备数据 有条件的同学,可以去 ...

  6. 转 Caffe学习系列(12):训练和测试自己的图片

    学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中.因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测试模型的整个流程. 一.准备数据 有条件的同学,可以去 ...

  7. Caffe学习系列(12):训练和测试自己的图片--linux平台

    Caffe学习系列(12):训练和测试自己的图片   学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中.因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测 ...

  8. Caffe学习系列(22):caffe图形化操作工具digits运行实例

    上接:Caffe学习系列(21):caffe图形化操作工具digits的安装与运行 经过前面的操作,我们就把数据准备好了. 一.训练一个model 右击右边Models模块的” Images" ...

  9. Caffe学习系列(21):caffe图形化操作工具digits的安装与运行

    经过前面一系列的学习,我们基本上学会了如何在linux下运行caffe程序,也学会了如何用python接口进行数据及参数的可视化. 如果还没有学会的,请自行细细阅读: caffe学习系列:http:/ ...

随机推荐

  1. BFS、DFS与选课问题(拓扑排序)

    1选课问题 Leetcode上有这样一道题:有代号0,1,2……n-1的n门课程.其中选择某些课程需要另一些课程作为前提条件.用一组pair来表示这些条件:[1,0],[1,2],表示如果要选修课程1 ...

  2. WPF学习之路(十四)样式和模板

    样式 实例: <Window.Resources> <Style x:Key="BtnStyle"> <Setter Property=" ...

  3. Tableau——BI software

    Tableau 8权威指南 (权威的数据可视化实战手册,中国传媒大学教授沈浩.北京大学研究员袁晓如 联袂推荐) 触手可及的大数据分析工具——Tableau案例集 写给专业数据分析师的丛书,无门槛的大数 ...

  4. 【转发】NPAPI开发详解,Windows版

    NPAPI开发详解,Windows版 9 jiaofeng601, +479 9人支持,来自Meteor.猪爪.hanyuxinting更多 .是非黑白 .Yuan Xulei.hyolin.Andy ...

  5. x01.os.9: 进程切换

    进入内核后,当然不能无所事事.先创建三个进程,分别打印 A,B,C.虽然只是简单的打印,但却是一切扩展的基础,不可等闲视之. 进程切换,涉及一系列的寄存器需要保护,于是,就有了 ProcessStac ...

  6. System V IPC(1)-消息队列

    一.概述                                                    System V三种IPC:消息队列,信号量,共享内存.这三种IPC最先出现在AT&am ...

  7. 荒芜的周六-PHP之面向对象(三)

    hi 又是开森的周六了.积攒的两周的衣服,终于是差不多洗完了.大下午的才来学点东西~~ 1.PHP面向对象(三) 四.OOP的高级实践 4.3 Static-静态成员 <?phpdate_def ...

  8. WinCE及Windows软件开发相关书籍转让

    从开始做WinCE开发到现在已经十多年了,最初可以学习和参考的资料并不多,那时候还没有Stack Overflow,Google也还可以正常访问.遇到问题时,一般都在Google Groups的mic ...

  9. 【Python数据分析】简单爬虫 爬取知乎神回复

    看知乎的时候发现了一个 “如何正确地吐槽” 收藏夹,里面的一些神回复实在很搞笑,但是一页一页地看又有点麻烦,而且每次都要打开网页,于是想如果全部爬下来到一个文件里面,是不是看起来很爽,并且随时可以看到 ...

  10. ExecutorType 的类型