# train_net.py
#!/usr/bin/env python # --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# -------------------------------------------------------- """Train a Fast R-CNN network on a region of interest database.""" import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
import datasets.imdb
import caffe
import argparse
import pprint
import numpy as np
import sys def parse_args():
# 运行时命令行参数
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=40000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args def combined_roidb(imdb_names):
# 融合roidb,roidb来自于数据集(实验可能用到多个),所以需要combine多个数据集的roidb
#
def get_roidb(imdb_name):
imdb = get_imdb(imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
# 设置proposal方法,这里是selective search(config.py)
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
     # 得到用于训练的roidb,定义在train.py,进行了水平翻转,以及为原始roidb添加了一些说明性的属性
roidb = get_training_roidb(imdb)
return roidb roidbs = [get_roidb(s) for s in imdb_names.split('+')]
roidb = roidbs[0]
# 这里进行combine roidb
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
imdb = datasets.imdb.imdb(imdb_names)
else:
# get_imdb方法定义在dataset/factory.py,通过名字得到imdb
imdb = get_imdb(imdb_names)
return imdb, roidb if __name__ == '__main__':
args = parse_args() print('Called with args:')
print(args)
#定义在config.py,见py-faster-rcnn代码阅读2-config.py if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs) cfg.GPU_ID = args.gpu_id print('Using config:')
pprint.pprint(cfg) if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED) # set up caffe
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id) imdb, roidb = combined_roidb(args.imdb_name)
print '{:d} roidb entries'.format(len(roidb)) output_dir = get_output_dir(imdb)
print 'Output will be saved to `{:s}`'.format(output_dir)
#定义在train.py
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
# train.py

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os from caffe.proto import caffe_pb2
import google.protobuf as pb2 class SolverWrapper(object):
# 对caffe solver的简单封装,封装允许我们控制snapshot过程,用于去除归一化学习到的bb回归权重
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over the snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
""" def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
# RPN can only use precomputed normalization because there are no
# fixed statistics to compute a priori assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
# 求bb回归器的target,返回box的均值和标准差
if cfg.TRAIN.BBOX_REG:
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
# solver,copy预训练模型的权重,set roidb
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param) self.solver.net.layers[0].set_roidb(roidb) def snapshot(self):
"""Take a snapshot of the network after unnormalizing the learned
bounding-box regression weights. This enables easy use at test-time.
"""
# 给snapshot的bb预测参数做去归一化
net = self.solver.net scale_bbox_params = (cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
net.params.has_key('bbox_pred')) if scale_bbox_params:
# save original values
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy() # scale and shift with bbox reg unnormalization; then save snapshot
# 去除归一化,乘标准差,加上均值
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
# 给snapshot命名
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
# 存snapshot
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename) # 只是存入的snapshot的bb参数做了去归一化用于测试,但是训练部分仍需要保持归一化的状态
if scale_bbox_params:
# restore net to original state
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
return filename def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
# 达到预设次数保存snapshot
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
# 整体迭代完成后也要存snapshot
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
# 水平翻转,得到更多的roidb
imdb.append_flipped_images()
print 'done' print 'Preparing training data...'
# 为原始数据集的roidb添加一些说明性的属性,max-overlap,max-classes...
rdl_roidb.prepare_roidb(imdb)
print 'done' return imdb.roidb def filter_roidb(roidb):
"""Remove roidb entries that have no usable RoIs."""
# 删掉没用的RoIs, 有效的图片必须各有前景和背景ROI
def is_valid(entry):
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
overlaps = entry['max_overlaps']
# find boxes with sufficient overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# image is only valid if such boxes exist
valid = len(fg_inds) > 0 or len(bg_inds) > 0
return valid num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
num, num_after)
return filtered_roidb def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
# 训练Fast R-CNN
roidb = filter_roidb(roidb)
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model) print 'Solving...'
# 训练...
model_paths = sw.train_model(max_iters)
print 'done solving'
return model_paths

py-faster-rcnn代码阅读1-train_net.py & train.py的更多相关文章

  1. py faster rcnn+ 1080Ti+cudnn5.0

    看了py-faster-rcnn上的issue,原来大家都遇到各种问题. 我要好好琢磨一下,看看到底怎么样才能更好地把GPU卡发挥出来.最近真是和GPU卡较上劲了. 上午解决了g++的问题不是. 然后 ...

  2. Faster RCNN代码理解(Python)

    转自http://www.infocool.net/kb/Python/201611/209696.html#原文地址 第一步,准备 从train_faster_rcnn_alt_opt.py入: 初 ...

  3. Faster rcnn代码理解(1)

    这段时间看了不少论文,回头看看,感觉还是有必要将Faster rcnn的源码理解一下,毕竟后来很多方法都和它有相近之处,同时理解该框架也有助于以后自己修改和编写自己的框架.好的开始吧- 这里我们跟着F ...

  4. Faster RCNN代码解析

    1.faster_rcnn_end2end训练 1.1训练入口及配置 def train(): cfg.GPU_ID = 0 cfg_file = "../experiments/cfgs/ ...

  5. Faster rcnn代码理解(2)

    接着上篇的博客,咱们继续看一下Faster RCNN的代码- 上次大致讲完了Faster rcnn在训练时是如何获取imdb和roidb文件的,主要都在train_rpn()的get_roidb()函 ...

  6. Faster R-CNN代码例子

    主要参考文章:1,从编程实现角度学习Faster R-CNN(附极简实现) 经常是做到一半发现收敛情况不理想,然后又回去看看这篇文章的细节. 另外两篇: 2,Faster R-CNN学习总结      ...

  7. Faster rcnn代码理解(4)

    上一篇我们说完了AnchorTargetLayer层,然后我将Faster rcnn中的其他层看了,这里把ROIPoolingLayer层说一下: 我先说一下它的实现原理:RPN生成的roi区域大小是 ...

  8. Faster R-CNN论文阅读摘要

    论文链接: https://arxiv.org/pdf/1506.01497.pdf 代码下载: https://github.com/ShaoqingRen/faster_rcnn (MATLAB) ...

  9. Faster rcnn代码理解(3)

    紧接着之前的博客,我们继续来看faster rcnn中的AnchorTargetLayer层: 该层定义在lib>rpn>中,见该层定义: 首先说一下这一层的目的是输出在特征图上所有点的a ...

  10. tensorflow faster rcnn 代码分析一 demo.py

    os.environ["CUDA_VISIBLE_DEVICES"]=2 # 设置使用的GPU tfconfig=tf.ConfigProto(allow_soft_placeme ...

随机推荐

  1. 《Linux内核分析》第一周学习小结 计算机是如何工作的?

    <Linux内核分析>第一周.计算机是如何工作的? 20135204 郝智宇  一.存储程序计算机工作模型 1.      冯诺依曼体系结构: 数字计算机的数制采用二进制:计算机应该按照程 ...

  2. ”数学口袋精灵“第二个Sprint计划---第二天

    “数学口袋精灵”第二个Sprint计划----第二天进度 任务分配: 冯美欣:欢迎界面的音效 吴舒婷:游戏界面的动作条,选择答案后的音效 林欢雯:完善算法代码的设计 进度:   冯美欣:找到了几个音乐 ...

  3. 从零开始学Kotlin-控制语句(4)

    从零开始学Kotlin基础篇系列文章 条件控制-if var a=10 var b=20 if(a>b) print(a) if(a>b){ print(a) }else{ print(b ...

  4. Mysql 间隙锁原理,以及Repeatable Read隔离级别下可以防止幻读原理(百度)

    Mysql知识实在太丰富了,前几天百度的面试官问我MySql在Repeatable Read下面是否会有幻读出现,我说按照事务的特性当然会有, 但是面试官却说 Mysql 在Repeatable Re ...

  5. CUDA ---- GPU架构(Fermi、Kepler)

    GPU架构 SM(Streaming Multiprocessors)是GPU架构中非常重要的部分,GPU硬件的并行性就是由SM决定的. 以Fermi架构为例,其包含以下主要组成部分: CUDA co ...

  6. BZOJ2653 middle(二分答案+主席树)

    与中位数有关的题二分答案是很常用的trick.二分答案之后,将所有大于它的看成1小于它的看成-1,那么只需要判断是否存在满足要求的一段和不小于0. 由于每个位置是1还是-1并不固定,似乎不是很好算.考 ...

  7. 学习Spring Boot:(十四)spring-shiro的密码加密

    前言 前面配置了怎么使用 shiro ,这次研究下怎么使用spring shiro的密码加密,并且需要在新增.更新用户的时候,实现生成盐,加密后的密码进行入库操作. 正文 配置凭证匹配器 @Bean ...

  8. CF1110C Meaningless Operations(构造题)

    这可能是我打那么多次CF比赛时,做出来的最难的一道题了……而且这题也是个绝世好题…… 题目链接:CF原网  洛谷 题目大意:$q$ 组询问,每次给定 $a$ 询问 $\gcd(a\&b,a\o ...

  9. (转)关于Class.getResource和ClassLoader.getResource的路径问题

    Java中取资源时,经常用到Class.getResource和ClassLoader.getResource,这里来看看他们在取资源文件时候的路径问题. 1 Class.getResource(St ...

  10. Python基础学习(二)

    前一段时间学习了Python数据类型,语句和函数,目前书写python的新特性,继续练手!!!! 一.切片 之前我们从python的list 或者 tuple中取得元素都是这样写,显然不够灵活 lis ...