# train_net.py
#!/usr/bin/env python # --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# -------------------------------------------------------- """Train a Fast R-CNN network on a region of interest database.""" import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
import datasets.imdb
import caffe
import argparse
import pprint
import numpy as np
import sys def parse_args():
# 运行时命令行参数
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=40000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args def combined_roidb(imdb_names):
# 融合roidb,roidb来自于数据集(实验可能用到多个),所以需要combine多个数据集的roidb
#
def get_roidb(imdb_name):
imdb = get_imdb(imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
# 设置proposal方法,这里是selective search(config.py)
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
     # 得到用于训练的roidb,定义在train.py,进行了水平翻转,以及为原始roidb添加了一些说明性的属性
roidb = get_training_roidb(imdb)
return roidb roidbs = [get_roidb(s) for s in imdb_names.split('+')]
roidb = roidbs[0]
# 这里进行combine roidb
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
imdb = datasets.imdb.imdb(imdb_names)
else:
# get_imdb方法定义在dataset/factory.py,通过名字得到imdb
imdb = get_imdb(imdb_names)
return imdb, roidb if __name__ == '__main__':
args = parse_args() print('Called with args:')
print(args)
#定义在config.py,见py-faster-rcnn代码阅读2-config.py if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs) cfg.GPU_ID = args.gpu_id print('Using config:')
pprint.pprint(cfg) if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED) # set up caffe
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id) imdb, roidb = combined_roidb(args.imdb_name)
print '{:d} roidb entries'.format(len(roidb)) output_dir = get_output_dir(imdb)
print 'Output will be saved to `{:s}`'.format(output_dir)
#定义在train.py
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
# train.py

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os from caffe.proto import caffe_pb2
import google.protobuf as pb2 class SolverWrapper(object):
# 对caffe solver的简单封装,封装允许我们控制snapshot过程,用于去除归一化学习到的bb回归权重
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over the snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
""" def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
# RPN can only use precomputed normalization because there are no
# fixed statistics to compute a priori assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
# 求bb回归器的target,返回box的均值和标准差
if cfg.TRAIN.BBOX_REG:
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
# solver,copy预训练模型的权重,set roidb
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param) self.solver.net.layers[0].set_roidb(roidb) def snapshot(self):
"""Take a snapshot of the network after unnormalizing the learned
bounding-box regression weights. This enables easy use at test-time.
"""
# 给snapshot的bb预测参数做去归一化
net = self.solver.net scale_bbox_params = (cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
net.params.has_key('bbox_pred')) if scale_bbox_params:
# save original values
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy() # scale and shift with bbox reg unnormalization; then save snapshot
# 去除归一化,乘标准差,加上均值
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
# 给snapshot命名
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
# 存snapshot
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename) # 只是存入的snapshot的bb参数做了去归一化用于测试,但是训练部分仍需要保持归一化的状态
if scale_bbox_params:
# restore net to original state
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
return filename def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
# 达到预设次数保存snapshot
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
# 整体迭代完成后也要存snapshot
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
# 水平翻转,得到更多的roidb
imdb.append_flipped_images()
print 'done' print 'Preparing training data...'
# 为原始数据集的roidb添加一些说明性的属性,max-overlap,max-classes...
rdl_roidb.prepare_roidb(imdb)
print 'done' return imdb.roidb def filter_roidb(roidb):
"""Remove roidb entries that have no usable RoIs."""
# 删掉没用的RoIs, 有效的图片必须各有前景和背景ROI
def is_valid(entry):
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
overlaps = entry['max_overlaps']
# find boxes with sufficient overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# image is only valid if such boxes exist
valid = len(fg_inds) > 0 or len(bg_inds) > 0
return valid num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
num, num_after)
return filtered_roidb def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
# 训练Fast R-CNN
roidb = filter_roidb(roidb)
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model) print 'Solving...'
# 训练...
model_paths = sw.train_model(max_iters)
print 'done solving'
return model_paths

py-faster-rcnn代码阅读1-train_net.py & train.py的更多相关文章

  1. py faster rcnn+ 1080Ti+cudnn5.0

    看了py-faster-rcnn上的issue,原来大家都遇到各种问题. 我要好好琢磨一下,看看到底怎么样才能更好地把GPU卡发挥出来.最近真是和GPU卡较上劲了. 上午解决了g++的问题不是. 然后 ...

  2. Faster RCNN代码理解(Python)

    转自http://www.infocool.net/kb/Python/201611/209696.html#原文地址 第一步,准备 从train_faster_rcnn_alt_opt.py入: 初 ...

  3. Faster rcnn代码理解(1)

    这段时间看了不少论文,回头看看,感觉还是有必要将Faster rcnn的源码理解一下,毕竟后来很多方法都和它有相近之处,同时理解该框架也有助于以后自己修改和编写自己的框架.好的开始吧- 这里我们跟着F ...

  4. Faster RCNN代码解析

    1.faster_rcnn_end2end训练 1.1训练入口及配置 def train(): cfg.GPU_ID = 0 cfg_file = "../experiments/cfgs/ ...

  5. Faster rcnn代码理解(2)

    接着上篇的博客,咱们继续看一下Faster RCNN的代码- 上次大致讲完了Faster rcnn在训练时是如何获取imdb和roidb文件的,主要都在train_rpn()的get_roidb()函 ...

  6. Faster R-CNN代码例子

    主要参考文章:1,从编程实现角度学习Faster R-CNN(附极简实现) 经常是做到一半发现收敛情况不理想,然后又回去看看这篇文章的细节. 另外两篇: 2,Faster R-CNN学习总结      ...

  7. Faster rcnn代码理解(4)

    上一篇我们说完了AnchorTargetLayer层,然后我将Faster rcnn中的其他层看了,这里把ROIPoolingLayer层说一下: 我先说一下它的实现原理:RPN生成的roi区域大小是 ...

  8. Faster R-CNN论文阅读摘要

    论文链接: https://arxiv.org/pdf/1506.01497.pdf 代码下载: https://github.com/ShaoqingRen/faster_rcnn (MATLAB) ...

  9. Faster rcnn代码理解(3)

    紧接着之前的博客,我们继续来看faster rcnn中的AnchorTargetLayer层: 该层定义在lib>rpn>中,见该层定义: 首先说一下这一层的目的是输出在特征图上所有点的a ...

  10. tensorflow faster rcnn 代码分析一 demo.py

    os.environ["CUDA_VISIBLE_DEVICES"]=2 # 设置使用的GPU tfconfig=tf.ConfigProto(allow_soft_placeme ...

随机推荐

  1. java实现图像的直方图均衡以及灰度线性变化,灰度拉伸

    写了四个方法,分别实现图片的灰度化,直方图均衡,灰度线性变化,灰度拉伸,其中好多地方特别是灰度拉伸这一块觉得自己实现的有问题,请大大们多多指教. import java.awt.Image; impo ...

  2. JS开发之CommonJs和AMD/CMD规范

    CommonJS是主要为了JS在后端的表现制定的,他是不适合前端的,AMD(异步模块定义)出现了,它就主要为前端JS的表现制定规范. 在兼容CommonJS的系统中,你可以使用JavaScript开发 ...

  3. Maven 学习笔记——Maven和Eclipse(2)

    前面已经配置好Maven的环境和本地仓库已经准备好了,下面我们通过Eclipse创建Maven项目. 1.安装Maven集成于Eclipse IDE (Eclipse的版本中如果已经集成了Maven插 ...

  4. 转Git学习碰到的问题

    [原]Git学习碰到的问题 1.通过 windows 下的 gitbash 连接 github 时 $ ssh-keygen -t rsa -C "abcd@efgh.com" / ...

  5. CMS垃圾回收过程

    1.总体介绍: CMS(Concurrent Mark-Sweep)是以牺牲吞吐量为代价来获得最短回收停顿时间的垃圾回收器.对于要求服务器响应速度的应用上,这种垃圾回收器非常适合.在启动JVM参数加上 ...

  6. Ubuntu18.04 安装后的简单实用设置[未完成]

    1. 安装完成. 2. 更新 sudo apt-get update 3. 修改vi 放置键盘错位的问题 编辑文件/etc/vim/vimrc.tiny 将“compatible”改成“nocompa ...

  7. Orchard是如何运行的

    建立一个CMS网站(内容管理系统)是不同于建立一个普通的web站点:它更像是建立一个应用程序容器. 设计这样一个系统时,必须建立一流的可扩展性功能.这必需是一个非常开放式的构架,但是一个开放性的系统可 ...

  8. 使用AutoMapper实现Dto和Model的自由转换(上)

    在实际的软件开发项目中,我们的“业务逻辑”常常需要我们对同样的数据进行各种变换.例如,一个Web应用通过前端收集用户的输入成为Dto,然后将Dto转换成领域模型并持久化到数据库中.另一方面,当用户请求 ...

  9. Python【知识点】傻傻的函数内变量

    问题的由来 有个学生问我一个问题关于函数内部变量的我们来一起看下代码: Code1 x = 50 def func(): print(x) global x print("x修改前的值:&q ...

  10. 自学huawei之路-AC6005-8AP添加授权码

    返回自学Huawei之路 自学huawei之路-AC6005-8AP添加授权码