# train_net.py
#!/usr/bin/env python # --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# -------------------------------------------------------- """Train a Fast R-CNN network on a region of interest database.""" import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
import datasets.imdb
import caffe
import argparse
import pprint
import numpy as np
import sys def parse_args():
# 运行时命令行参数
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=40000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args def combined_roidb(imdb_names):
# 融合roidb,roidb来自于数据集(实验可能用到多个),所以需要combine多个数据集的roidb
#
def get_roidb(imdb_name):
imdb = get_imdb(imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
# 设置proposal方法,这里是selective search(config.py)
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
     # 得到用于训练的roidb,定义在train.py,进行了水平翻转,以及为原始roidb添加了一些说明性的属性
roidb = get_training_roidb(imdb)
return roidb roidbs = [get_roidb(s) for s in imdb_names.split('+')]
roidb = roidbs[0]
# 这里进行combine roidb
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
imdb = datasets.imdb.imdb(imdb_names)
else:
# get_imdb方法定义在dataset/factory.py,通过名字得到imdb
imdb = get_imdb(imdb_names)
return imdb, roidb if __name__ == '__main__':
args = parse_args() print('Called with args:')
print(args)
#定义在config.py,见py-faster-rcnn代码阅读2-config.py if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs) cfg.GPU_ID = args.gpu_id print('Using config:')
pprint.pprint(cfg) if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED) # set up caffe
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id) imdb, roidb = combined_roidb(args.imdb_name)
print '{:d} roidb entries'.format(len(roidb)) output_dir = get_output_dir(imdb)
print 'Output will be saved to `{:s}`'.format(output_dir)
#定义在train.py
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
# train.py

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os from caffe.proto import caffe_pb2
import google.protobuf as pb2 class SolverWrapper(object):
# 对caffe solver的简单封装,封装允许我们控制snapshot过程,用于去除归一化学习到的bb回归权重
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over the snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
""" def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
# RPN can only use precomputed normalization because there are no
# fixed statistics to compute a priori assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
# 求bb回归器的target,返回box的均值和标准差
if cfg.TRAIN.BBOX_REG:
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
# solver,copy预训练模型的权重,set roidb
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param) self.solver.net.layers[0].set_roidb(roidb) def snapshot(self):
"""Take a snapshot of the network after unnormalizing the learned
bounding-box regression weights. This enables easy use at test-time.
"""
# 给snapshot的bb预测参数做去归一化
net = self.solver.net scale_bbox_params = (cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
net.params.has_key('bbox_pred')) if scale_bbox_params:
# save original values
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy() # scale and shift with bbox reg unnormalization; then save snapshot
# 去除归一化,乘标准差,加上均值
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
# 给snapshot命名
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
# 存snapshot
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename) # 只是存入的snapshot的bb参数做了去归一化用于测试,但是训练部分仍需要保持归一化的状态
if scale_bbox_params:
# restore net to original state
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
return filename def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
# 达到预设次数保存snapshot
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
# 整体迭代完成后也要存snapshot
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
# 水平翻转,得到更多的roidb
imdb.append_flipped_images()
print 'done' print 'Preparing training data...'
# 为原始数据集的roidb添加一些说明性的属性,max-overlap,max-classes...
rdl_roidb.prepare_roidb(imdb)
print 'done' return imdb.roidb def filter_roidb(roidb):
"""Remove roidb entries that have no usable RoIs."""
# 删掉没用的RoIs, 有效的图片必须各有前景和背景ROI
def is_valid(entry):
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
overlaps = entry['max_overlaps']
# find boxes with sufficient overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# image is only valid if such boxes exist
valid = len(fg_inds) > 0 or len(bg_inds) > 0
return valid num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
num, num_after)
return filtered_roidb def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
# 训练Fast R-CNN
roidb = filter_roidb(roidb)
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model) print 'Solving...'
# 训练...
model_paths = sw.train_model(max_iters)
print 'done solving'
return model_paths

py-faster-rcnn代码阅读1-train_net.py & train.py的更多相关文章

  1. py faster rcnn+ 1080Ti+cudnn5.0

    看了py-faster-rcnn上的issue,原来大家都遇到各种问题. 我要好好琢磨一下,看看到底怎么样才能更好地把GPU卡发挥出来.最近真是和GPU卡较上劲了. 上午解决了g++的问题不是. 然后 ...

  2. Faster RCNN代码理解(Python)

    转自http://www.infocool.net/kb/Python/201611/209696.html#原文地址 第一步,准备 从train_faster_rcnn_alt_opt.py入: 初 ...

  3. Faster rcnn代码理解(1)

    这段时间看了不少论文,回头看看,感觉还是有必要将Faster rcnn的源码理解一下,毕竟后来很多方法都和它有相近之处,同时理解该框架也有助于以后自己修改和编写自己的框架.好的开始吧- 这里我们跟着F ...

  4. Faster RCNN代码解析

    1.faster_rcnn_end2end训练 1.1训练入口及配置 def train(): cfg.GPU_ID = 0 cfg_file = "../experiments/cfgs/ ...

  5. Faster rcnn代码理解(2)

    接着上篇的博客,咱们继续看一下Faster RCNN的代码- 上次大致讲完了Faster rcnn在训练时是如何获取imdb和roidb文件的,主要都在train_rpn()的get_roidb()函 ...

  6. Faster R-CNN代码例子

    主要参考文章:1,从编程实现角度学习Faster R-CNN(附极简实现) 经常是做到一半发现收敛情况不理想,然后又回去看看这篇文章的细节. 另外两篇: 2,Faster R-CNN学习总结      ...

  7. Faster rcnn代码理解(4)

    上一篇我们说完了AnchorTargetLayer层,然后我将Faster rcnn中的其他层看了,这里把ROIPoolingLayer层说一下: 我先说一下它的实现原理:RPN生成的roi区域大小是 ...

  8. Faster R-CNN论文阅读摘要

    论文链接: https://arxiv.org/pdf/1506.01497.pdf 代码下载: https://github.com/ShaoqingRen/faster_rcnn (MATLAB) ...

  9. Faster rcnn代码理解(3)

    紧接着之前的博客,我们继续来看faster rcnn中的AnchorTargetLayer层: 该层定义在lib>rpn>中,见该层定义: 首先说一下这一层的目的是输出在特征图上所有点的a ...

  10. tensorflow faster rcnn 代码分析一 demo.py

    os.environ["CUDA_VISIBLE_DEVICES"]=2 # 设置使用的GPU tfconfig=tf.ConfigProto(allow_soft_placeme ...

随机推荐

  1. 12.16daily_scrum

    这个阶段,我们组需要攻克的技术难题一个是测试及美化界面,另一个是在M1阶段的基础上进一步细化和完善悬浮窗的功能,具体的工作内容如下: 具体工作: 小组成员 今日任务 明日任务 工作时间 李睿琦 图片笔 ...

  2. 20170929php

    这是之前学习PHP类使用的代码 <?phpclass animal{ var $name="1"; var $sex="2"; public static ...

  3. Java Socket 多线程聊天室

    本来这次作业我是想搞个图形界面的,然而现实情况是我把题意理解错了,于是乎失去了最初的兴致,还是把程序变成了功能正确但是“UI”不友好的console了,但是不管怎么样,前期的图形界面的开发还是很有收获 ...

  4. 老李的blog使用日记(3)

    匆匆忙忙.碌碌无为,这是下一个作业,VS,多么神圣高大上,即使这样,有多少人喜欢你就有多少人烦你,依然逃不了被推销的命运,这抑或是它喜欢接受的,但是作为被迫接受者,能做的的也只有接受,而已. 既来之则 ...

  5. css之3D变换

    3D变换的x,y,z轴是分别效果是: x轴旋转的话,就是头和脚进行转动 y轴旋转的话,就是左右手进行转动 z轴旋转的话,就是整个身体平铺在旋转. 上面是针对旋转的意思去,但是对于其他的类似一样,就是这 ...

  6. node.js处理url常用方法

    处理非阻塞I/O /* *回调函数的方法 异步 */ /* function f(cb){ fs.readFile('./4',(err,data)=>{ cb(data.toString()) ...

  7. Node.js使用UDP通讯

    Node.js 的 dgram 模块可以方便的创建udp服务,,以下是使用 dgram模块创建的udp服务和客户端的一个简单例子. 一.创建 UDP Server var dgram = requir ...

  8. 洛谷P13445 [USACO5.4]奶牛的电信Telecowmunication(网络流)

    题目描述 农夫约翰的奶牛们喜欢通过电邮保持联系,于是她们建立了一个奶牛电脑网络,以便互相交流.这些机器用如下的方式发送电邮:如果存在一个由c台电脑组成的序列a1,a2,...,a(c),且a1与a2相 ...

  9. Java之Map的使用场景

    总结之 Map接口 的使用场景(day04) Map: Map中的集合,元素是成对存在的(理解为夫妻).每个元素由键与值两部分组成,通过键可以找对所对应的值 Map中的集合不能包含重复的键,值可以重复 ...

  10. spring boot 系列之六:深入理解spring boot的自动配置

    我们知道,spring boot自动配置功能可以根据不同情况来决定spring配置应该用哪个,不应该用哪个,举个例子: Spring的JdbcTemplate是不是在Classpath里面?如果是,并 ...