具体代码见https://github.com/zhiyishou/py-faster-rcnn



这是我对cup, glasses训练的识别

faster-rcnn在fast-rcnn的基础上加了rpn来将整个训练都置于GPU内,以用来提高效率,这里我们将使用ImageNet的数据集来在faster-rcnn上来训练自己的分类器。从ImageNet上可下载到很多类别的Image与bounding box annotation来进行训练(每一个类别下的annotation都少于等于image的个数,所以我们从annotation来建立索引)。

lib/dataset/factory.py中提供了coco与voc的数据集获取方法,而我们要做的就是在这里加上我们自己的ImageNet获取方法,我们先来建立ImageNet数据获取主文件。coco与pascal_voc的获取都是继承于父类imdb,所以我们可根据pascal_voc的获取方法来做模板修改完成我们的ImageNet类。

创建ImageNet类

由于在faster-rcnn里使用rpn来代替了selective_search,所以我们可以在使用时直接略过有关selective_search的方法,根据pascal_voc类做模板,我们需要留下的方法有:

__init__ //初始化
image_path_at //根据数据集列表的index来取图片绝对地址
image_path_from_index //配合上面
_load_image_set_index //获取数据集列表
_gt_roidb //获取ground-truth数据
rpn_roidb //获取region proposal数据
_load_rpn_roidb //根据gt_roidb生成rpn_roidb数据并合成
_load_psacal_annotation //加载annotation文件并对bounding box进行数据整理

__init__:

def __init__(self, image_set):
imdb.__init__(self, 'imagenet')
self._image_set = image_set
self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
#类别与对应的wnid,可以修改成自己要训练的类别
self._class_wnids = {
'cup': 'n03147509',
'glasses': 'n04272054'
} #类别,修改类别时同时要修改这里
self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
#bounding box annotation 文件的目录
self._xml_path = os.path.join(self._data_path, "Annotations")
self._image_ext = '.JPEG'
#我们使用xml文件名来做数据集的索引
# the xml file name and each one corresponding to image file name
self._image_index = self._load_xml_filenames()
self._salt = str(uuid.uuid4())
self._comp_id = 'comp4' self.config = {'cleanup' : True,
'use_salt' : True,
'use_diff' : False,
'matlab_eval' : False,
'rpn_file' : None,
'min_size' : 2} assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)

image_path_at

def image_path_at(self, i):
#使用index来从xml_filenames取到filename,生成绝对路径
return self.image_path_from_image_filename(self._image_index[i])

image_path_from_image_filename(类似pascal_voc中的image_path_from_index)

def image_path_from_image_filename(self, image_filename):
image_path = os.path.join(self._data_path, 'Images',
image_filename + self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path

_load_xml_filenames(类似pascal_voc中的_load_image_set_index)

def _load_xml_filenames(self):
#从Annotations文件夹中拿取到bounding box annotation文件名
#用来做数据集的索引
xml_folder_path = os.path.join(self._data_path, "Annotations")
assert os.path.exists(xml_folder_path), \
'Path does not exist: {}'.format(xml_folder_path) for dirpath, dirnames, filenames in os.walk(xml_folder_path):
xml_filenames = [xml_filename.split(".")[0] for xml_filename in filenames] return xml_filenames

gt_roidb

def gt_roidb(self):
#Ground-Truth 数据缓存
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb #从xml中获取Ground-Truth数据
gt_roidb = [self._load_imagenet_annotation(xml_filename)
for xml_filename in self._image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb

rpn_roidb

def rpn_roidb(self):
#根据gt_roidb生成rpn_roidb,并进行合并
gt_roidb = self.gt_roidb()
rpn_roidb = self._load_rpn_roidb(gt_roidb)
roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb) return roidb

_load_rpn_roidb

def _load_rpn_roidb(self, gt_roidb):
filename = self.config['rpn_file']
print 'loading {}'.format(filename)
assert os.path.exists(filename), \
'rpn data not found at: {}'.format(filename)
with open(filename, 'rb') as f:
box_list = cPickle.load(f)
return self.create_roidb_from_box_list(box_list, gt_roidb)

_load_imagenet_annotation(类似于pascal_voc中的_load_pascal_annotation)

def _load_imagenet_annotation(self, xml_filename):
#从annotation的xml文件中拿取bounding box数据
filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')
#这里使用了ap,是我写的一个annotation parser,在后面贴出代码
#它会返回这个xml文件的wnid, 图像文件名,以及里面包含的注解物体
wnid, image_name, objects = ap.parse(filepath)
num_objs = len(objects) boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
seg_areas = np.zeros((num_objs), dtype=np.float32) # Load object bounding boxes into a data frame.
for ix, obj in enumerate(objects):
box = obj["box"]
x1 = box['xmin']
y1 = box['ymin']
x2 = box['xmax']
y2 = box['ymax']
# 如果这个bounding box并不是我们想要学习的类别,那则跳过
# go next if the wnid not exist in declared classes
try:
cls = self._class_to_ind[obj["wnid"]]
except KeyError:
print "wnid %s isn't show in given"%obj["wnid"]
continue
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) overlaps = scipy.sparse.csr_matrix(overlaps) return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False,
'seg_areas' : seg_areas}

annotation_parser.py文件

import os
import xml.dom.minidom def getText(node):
return node.firstChild.nodeValue def getWnid(node):
return getText(node.getElementsByTagName("name")[0]) def getImageName(node):
return getText(node.getElementsByTagName("filename")[0]) def getObjects(node):
objects = []
for obj in node.getElementsByTagName("object"):
objects.append({
"wnid": getText(obj.getElementsByTagName("name")[0]),
"box":{
"xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
"ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
"xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
"ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
}
})
return objects def parse(filepath):
dom = xml.dom.minidom.parse(filepath)
root = dom.documentElement
image_name = getImageName(root)
wnid = getWnid(root)
objects = getObjects(root) return wnid, image_name, objects

则对数据结构的要求是:

|---data
|---imagenet
|---Annotations
|---n03147509
|---n03147509_*.xml
|---...
|---n04272054
|---n04272054_*.xml
|---...
|---Images
|---n03147508_*.JPEG
|---...
|---n04272054_*.JPEG
|---...

同时我在github上也提供了draw方法,可以用来将bounding box画于Image文件上,用来甄别该annotation的正确性

训练

这样,我们的ImageNet类则是生成好了,下面我们则可以训练我们的数据,但是在开始之前,还有一件事情,那就是修改prototxt中的与类别数目有关的值,我将models/pascal_voc拷贝到了models/imagenet进行修改,比如我想要训练ZF,如果使用的是train_faster_rcnn_alt_opt.py,则需要修改models/imagenet/ZF/faster_rcnn_alt_opt/下的所有pt文件里的内容,用如下的法则去替换:

//num为类别的个数
input-data->num_classes = num
class_score->num_output = num
bbox_pred->num_output = num*4

我这里使用train_faster_rcnn_alt_opt.py进行的训练,这样的话则需要把添加的models/imagenet作为可选项

//pt_type 则是添加的选择项,默认使用psacal_voc的models
./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
--net_name ZF \
--weights data/imagenet_models/ZF.v2.caffemodel[optional] \
--imdb imagenet \
--cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
--pt_type imagenet

识别

这里我们则需要使用刚训练出来的模型进行识别

#就像demo.py一样,但是使用训练的models,我创建了tools/classify.py来单独识别
prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')

同样,在识别前我们要对识别方法里的Classes进行修改,修改成你自己训练的类别后

执行

./tools/classify.py --net zf

则可对data/demo下的图片文件使用训练的zf网络进行识别

Have fun

使用ImageNet在faster-rcnn上训练自己的分类网络的更多相关文章

  1. Faster RCNN算法训练代码解析(1)

    这周看完faster-rcnn后,应该对其源码进行一个解析,以便后面的使用. 那首先直接先主函数出发py-faster-rcnn/tools/train_faster_rcnn_alt_opt.py ...

  2. Faster RCNN算法训练代码解析(3)

    四个层的forward函数分析: RoIDataLayer:读数据,随机打乱等 AnchorTargetLayer:输出所有anchors(这里分析这个) ProposalLayer:用产生的anch ...

  3. Faster RCNN算法训练代码解析(2)

    接着上篇的博客,我们获取imdb和roidb的数据后,就可以搭建网络进行训练了. 我们回到trian_rpn()函数里面,此时运行完了roidb, imdb = get_roidb(imdb_name ...

  4. 目标检测(四)Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks

    作者:Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun SPPnet.Fast R-CNN等目标检测算法已经大幅降低了目标检测网络的运行时间. ...

  5. Faster R-CNN利用新的网络结构来训练

    前言 最近利用Faster R-CNN训练数据,使用ZF模型,效果无法有效提高.就想尝试对ZF的网络结构进行改造,记录下具体操作. 一.更改网络,训练初始化模型 这里为了方便,我们假设更换的网络名为L ...

  6. object detection[faster rcnn]

    这部分,写一写faster rcnn 0. faster rcnn 经过了rcnn,spp,fast rcnn,又到了faster rcnn,作者在对前面的模型回顾中发现,fast rcnn提出的ro ...

  7. 基于候选区域的深度学习目标检测算法R-CNN,Fast R-CNN,Faster R-CNN

    参考文献 [1]Rich feature hierarchies for accurate object detection and semantic segmentation [2]Fast R-C ...

  8. 【神经网络与深度学习】【计算机视觉】Faster R-CNN

    Faster R-CNN Fast-RCNN基本实现端对端(除了proposal阶段外),下一步自然就是要把proposal阶段也用CNN实现(放到GPU上).这就出现了Faster-RCNN,一个完 ...

  9. Paper Reading:Faster RCNN

    Faster R-CNN 论文:Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 发表时间: ...

随机推荐

  1. Js闭包函数

    一.变量的作用域要理解闭包,首先必须理解Javascript特殊的变量作用域.变量的作用域无非就是两种:全局变量和局部变量.Javascript语言的特殊之处,就在于函数内部可以直接读取全局变量. ( ...

  2. bzoj2764 基因补全

    Description 在生物课中我们学过,碱基组成了DNA(脱氧核糖核酸),他们分别可以用大写字母A,C,T,G表示,其中A总与T配对,C总与G配对.两个碱基序列能相互匹配,当且仅当它们等长,并且任 ...

  3. sql数据库带补全终端命令

    mysql pip install mycli pgsql pip install pgcli 都是python脚本,记录备忘.

  4. Softerra LDAP Browser 使用及配置 有图有真相

    Softerra LDAP Browser 4.5 我使用Softerra LDAP Browser的目的,是为了查找公司的人员信息.网上关于Softerra LDAP Browser配置太少了,所以 ...

  5. 【转】SQL SERVER CLR存储过程实现

    最近做一个项目,需要做一个SQL SERVER 2005的CLR的存储过程,研究了一下CLR的实现.为方便以后再使用,在这里总结一下我的实现流程,也供对CLR感兴趣但又不知道如何实现的朋友们做一下参考 ...

  6. (PowerShell) Managing Windows Registry

    http://powershell.com/cs/blogs/ebookv2/archive/2012/03/23/chapter-16-managing-windows-registry.aspx

  7. HDU Count the string+Next数组测试函数

    链接:http://www.cnblogs.com/jackge/archive/2013/04/20/3032942.html 题意:给定一字符串,求它所有的前缀出现的次数的和.这题很纠结,一开始不 ...

  8. [复变函数]第15堂课 4.3 解析函数的 Taylor 展式

    1.  Taylor 定理: 设 $f(z)$ 在 $K:|z-a|<R$ 内解析, 则 $$\bee\label{15:taylor} f(z)=\sum_{n=0}^\infty c_n(z ...

  9. listview当选中某一个item时设置背景色其他的不变

    listview当选中某一个item时设置背景色其他的不变: 可以使用listview.setOnFoucsChangeListener(listener) ; /** * listview获得焦点和 ...

  10. vs2015编译boost 64位

    ---恢复内容开始--- step 1: 打开Developer Command Prompt for VS2015命令行窗口 step 2: 执行bootstrap.bat,产生bjam.exe s ...