Reposted from: https://blog.csdn.net/u011311291/article/details/81121519

https://blog.csdn.net/qq_34564612/article/details/79138876

2018-07-19 19:43:58 姚贤贤
 
Copyright notice: this is an original article by the blogger and may not be reposted without the blogger's permission. https://blog.csdn.net/u011311291/article/details/81121519

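Index of the faster RCNN (keras version) code walkthrough posts:
1. faster RCNN (keras version) code walkthrough (1) - overview
2. faster RCNN (keras version) code walkthrough (2) - data preparation
3. faster RCNN (keras version) code walkthrough (3) - training workflow in detail
4. faster RCNN (keras version) code walkthrough (4) - shared convolutional layers in detail
5. faster RCNN (keras version) code walkthrough (5) - RPN layer in detail
6. faster RCNN (keras version) code walkthrough (6) - ROI Pooling layer in detail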

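I. Overview of the overall workflow
1. Input arguments: in fact only one is required (the dataset path, D:\tempFile\VOCdevkit); the other one, the ResNet weights, only serves to speed up training, as shown in the figure:

2. Read the data from the VOC2007 dataset and convert it into the required format.
3. Define the iterators that generate the data.
4. Define three networks: the shared ResNet convolutional layers, the RPN layer, and the classifier layer.
5. Enter the training loop; each iteration trains on a single image.
6. Optionally apply image augmentation.
7. From the feature map, the predefined anchor sizes and ratios, the IoU, etc., compute the y_train values that serve as the network labels.
8. Train the RPN layer, which outputs whether each box contains an object and the coordinates of the object boxes.
9. Then train the classifier layer.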
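
Step 7 can be illustrated with a small pure-Python sketch of anchor generation and IoU-based labeling. It assumes base sizes of 128, 256 and 512 pixels and three aspect ratios (1:1, 1:2, 2:1), giving 3 × 3 = 9 anchors per feature-map position, and it uses the IoU thresholds 0.7 (positive) and 0.3 (negative) from the Faster R-CNN paper. `anchors_at`, `iou` and `label_anchor` are hypothetical helpers, not keras_frcnn functions, and the sketch leaves out the extra rule that the best-matching anchor of each ground-truth box is also labeled positive.

```python
import math

# Assumed anchor settings (illustrative, not read from keras_frcnn):
# 3 base sizes x 3 aspect ratios = 9 anchors per position.
ANCHOR_SCALES = [128, 256, 512]
# (width factor, height factor); each pair keeps the anchor area equal to size**2
ANCHOR_RATIOS = [(1.0, 1.0),
                 (1.0 / math.sqrt(2), 2.0 / math.sqrt(2)),
                 (2.0 / math.sqrt(2), 1.0 / math.sqrt(2))]


def anchors_at(cx, cy):
    """Return the 9 anchors centred at (cx, cy) as (x1, y1, x2, y2) tuples."""
    boxes = []
    for size in ANCHOR_SCALES:
        for rw, rh in ANCHOR_RATIOS:
            w, h = size * rw, size * rh
            boxes.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2))
    return boxes


def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def label_anchor(anchor, gt_boxes, pos_thresh=0.7, neg_thresh=0.3):
    """1 = positive (object), 0 = negative (background), -1 = ignored."""
    best = max(iou(anchor, gt) for gt in gt_boxes)
    if best >= pos_thresh:
        return 1
    if best < neg_thresh:
        return 0
    return -1


if __name__ == "__main__":
    gt = [(236, 236, 364, 364)]
    for box in anchors_at(300, 300):
        print([round(v) for v in box], label_anchor(box, gt))
```

Anchors labeled -1 fall between the two thresholds and are simply left out of the RPN classification loss.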

II. Code (the key parts are commented)

from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser
import pickle

from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils

sys.setrecursionlimit(40000)

parser = OptionParser()

parser.add_option("-p", "--path", dest="train_path", help="Path to training data.")
parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",
                  default="pascal_voc")
parser.add_option("-n", "--num_rois", type="int", dest="num_rois", help="Number of RoIs to process at once.", default=32)
parser.add_option("--network", dest="network", help="Base network to use. Supports vgg or resnet50.", default='resnet50')
parser.add_option("--hf", dest="horizontal_flips", help="Augment with horizontal flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--vf", dest="vertical_flips", help="Augment with vertical flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--rot", "--rot_90", dest="rot_90", help="Augment with 90 degree rotations in training. (Default=false).",
                  action="store_true", default=False)
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=2000)
parser.add_option("--config_filename", dest="config_filename",
                  help="Location to store all the metadata related to the training (to be used when testing).",
                  default="config.pickle")
parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")

(options, args) = parser.parse_args()

if not options.train_path:  # if filename is not given
    parser.error('Error: path to training data must be specified. Pass --path to command line')

if options.parser == 'pascal_voc':
    from keras_frcnn.pascal_voc_parser import get_data
elif options.parser == 'simple':
    from keras_frcnn.simple_parser import get_data
else:
    raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")

# pass the settings from the command line, and persist them in the config object
C = config.Config()

C.use_horizontal_flips = bool(options.horizontal_flips)
C.use_vertical_flips = bool(options.vertical_flips)
C.rot_90 = bool(options.rot_90)

C.model_path = options.output_weight_path
C.num_rois = int(options.num_rois)

# two backbone networks are supported: VGG and ResNet50
if options.network == 'vgg':
    C.network = 'vgg'
    from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
    from keras_frcnn import resnet as nn
    C.network = 'resnet50'
else:
    print('Not a valid model')
    raise ValueError

# check if weight path was passed via command line
if options.input_weight_path:
    C.base_net_weights = options.input_weight_path
else:
    # set the path to weights based on backend and model
    C.base_net_weights = nn.get_weight_path()

all_imgs, classes_count, class_mapping = get_data(options.train_path)
print(len(all_imgs))       # info for every image: file name, path, etc.
print(len(classes_count))  # dict of class -> count, e.g. 'chair': 1432
print(len(class_mapping))  # dict of class -> label index (0-19), e.g. chair: 0, car: 1

# add the 'bg' (background) class
if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

# swap the keys and values of class_mapping (label index -> class name)
inv_map = {v: k for k, v in class_mapping.items()}

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))

config_output_filename = options.config_filename

with open(config_output_filename, 'wb') as config_f:
    pickle.dump(C, config_f)
    print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(config_output_filename))

# shuffle the data
random.shuffle(all_imgs)

num_imgs = len(all_imgs)

# split all_imgs into a training set and a validation set
train_imgs = [s for s in all_imgs if s['imageset'] == 'trainval']
val_imgs = [s for s in all_imgs if s['imageset'] == 'test']

print('Num train samples {}'.format(len(train_imgs)))
print('Num val samples {}'.format(len(val_imgs)))
# data generators that yield each image together with its anchor ground truth
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='val')

# check whether the backend uses Theano ('th') or TensorFlow ('tf') dimension ordering and set the input shape accordingly
if K.image_dim_ordering() == 'th':
    input_shape_img = (3, None, None)
else:
    input_shape_img = (None, None, 3)

img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))

# define the base network (resnet here, can be VGG, Inception, etc)
# i.e. the shared convolutional layers of Faster R-CNN, built on the input layer
shared_layers = nn.nn_base(img_input, trainable=True)
print("shared_layers", shared_layers.shape)

# define the RPN, built on the base layers
# number of anchors per position: 3 base sizes x 3 aspect ratios = 9
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
# RPN layers; returns [x_class, x_regr, base_layers]
rpn = nn.rpn(shared_layers, num_anchors)

classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)

# RPN model: takes the image and outputs a binary object/background score per anchor
# (the last activation is a sigmoid, not a softmax) plus the box regression
model_rpn = Model(img_input, rpn[:2])
# classifier model: takes the image and the RoIs
model_classifier = Model([img_input, roi_input], classifier)

# this is a model that holds both the RPN and the classifier, used to load/save weights for the models
model_all = Model([img_input, roi_input], rpn[:2] + classifier)

try:
    print('loading weights from {}'.format(C.base_net_weights))
    model_rpn.load_weights(C.base_net_weights, by_name=True)
    model_classifier.load_weights(C.base_net_weights, by_name=True)
except:
    print('Could not load pretrained model weights. Weights can be found in the keras application folder \
https://github.com/fchollet/keras/tree/master/keras/applications')

optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
model_all.compile(optimizer='sgd', loss='mae')

epoch_length = 1000
num_epochs = int(options.num_epochs)
iter_num = 0

# note: this rebinds the name `losses`, shadowing the keras_frcnn.losses module imported above;
# it only works because every compile() call has already run
losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()

best_loss = np.inf

class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')

vis = True

for epoch_num in range(num_epochs):

    progbar = generic_utils.Progbar(epoch_length)
    print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

    while True:
        try:

            if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
                mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
                rpn_accuracy_rpn_monitor = []
                print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
                if mean_overlapping_bboxes == 0:
                    print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')

            print("generating a batch from data_gen_train")
            # X: the original image, rescaled so that its shorter side is 600 pixels
            # Y: [validity and class (positive or negative) of every anchor,
            #     regression data of every anchor: the first 36 channels are positions,
            #     the last 36 are the box-to-ground-truth ratios]
            # img_data: the image information after augmentation
            # so the RPN regression output is also expressed as ratios, not absolute coordinates
            X, Y, img_data = next(data_gen_train)
            print(X.shape, Y[0].shape, Y[1].shape)

            loss_rpn = model_rpn.train_on_batch(X, Y)
            print("loss_rpn", len(loss_rpn))
            print("loss_rpn0", loss_rpn[0])
            print("loss_rpn1", loss_rpn[1])
            print("loss_rpn2", loss_rpn[2])

            P_rpn = model_rpn.predict_on_batch(X)
            # print("P_rpn_cls",P_rpn[0].reshape((P_rpn[0].shape[1],P_rpn[0].shape[2],P_rpn[0].shape[3]))[:,:,0])
            print("P_rpn_cls", P_rpn[0].shape)
            print("P_rpn_reg", P_rpn[1].shape)

            # get the finally selected boxes (proposals)
            R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
            # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
            # compute the IoU of the regressed boxes against the ground truth and filter once more,
            # keeping only background boxes and object boxes
            # X2: boxes converted from (x1,y1,x2,y2) to (x,y,w,h)
            # Y1: one-hot class label of each box
            # Y2: regression data of each box, shape (x, x, 160): the first 80 values flag whether
            #     each box target is valid, the last 80 hold the candidate boxes for the 20 classes
            X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)

            # check for None before touching X2.shape, otherwise an AttributeError is raised
            if X2 is None:
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue

            print("X2", X2.shape)
            # print("X2_0",X2[0,0,:])
            # print("X2_1",X2[0,1,:])
            print("Y1", Y1.shape)
            print("Y2", Y2.shape)

            # indices of negative and positive samples: background boxes are negatives, object boxes are positives
            neg_samples = np.where(Y1[0, :, -1] == 1)
            pos_samples = np.where(Y1[0, :, -1] == 0)
            print("neg_samples", len(neg_samples[0]))
            print("pos_samples", len(pos_samples[0]))

            if len(neg_samples) > 0:
                neg_samples = neg_samples[0]
            else:
                neg_samples = []

            if len(pos_samples) > 0:
                pos_samples = pos_samples[0]
            else:
                pos_samples = []

            rpn_accuracy_rpn_monitor.append(len(pos_samples))
            rpn_accuracy_for_epoch.append((len(pos_samples)))

            # with num_rois=32, at most num_rois//2 positives are used; the rest is filled with negatives
            if C.num_rois > 1:
                if len(pos_samples) < C.num_rois//2:
                    selected_pos_samples = pos_samples.tolist()
                else:
                    selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
                print("selected_pos_samples", len(selected_pos_samples))
                try:
                    selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
                except:
                    selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
                print("selected_neg_samples", len(selected_neg_samples))

                sel_samples = selected_pos_samples + selected_neg_samples
            else:
                # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                selected_pos_samples = pos_samples.tolist()
                selected_neg_samples = neg_samples.tolist()
                if np.random.randint(0, 2):
                    sel_samples = random.choice(neg_samples)
                else:
                    sel_samples = random.choice(pos_samples)

            # sel_samples is a single index when num_rois == 1, so np.size is used instead of len
            print("sel_samples", np.size(sel_samples))
            print("sel_samples", sel_samples)
            loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
            # P_classifier = model_classifier.predict([X, X2[:, sel_samples, :]])
            # #[out_class, out_regr]
            # print("P_classifier_out_class",P_classifier[0].shape)
            # print("P_classifier_out_regr",P_classifier[1].shape)
            # import cv2
            # cv2.waitKey(0)

            losses[iter_num, 0] = loss_rpn[1]
            losses[iter_num, 1] = loss_rpn[2]

            losses[iter_num, 2] = loss_class[1]
            losses[iter_num, 3] = loss_class[2]
            losses[iter_num, 4] = loss_class[3]

            iter_num += 1

            progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                      ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])

            if iter_num == epoch_length:
                loss_rpn_cls = np.mean(losses[:, 0])
                loss_rpn_regr = np.mean(losses[:, 1])
                loss_class_cls = np.mean(losses[:, 2])
                loss_class_regr = np.mean(losses[:, 3])
                class_acc = np.mean(losses[:, 4])

                mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                rpn_accuracy_for_epoch = []

                if C.verbose:
                    print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
                    print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                    print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                    print('Loss RPN regression: {}'.format(loss_rpn_regr))
                    print('Loss Detector classifier: {}'.format(loss_class_cls))
                    print('Loss Detector regression: {}'.format(loss_class_regr))
                    print('Elapsed time: {}'.format(time.time() - start_time))

                curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                iter_num = 0
                start_time = time.time()

                if curr_loss < best_loss:
                    if C.verbose:
                        print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
                    best_loss = curr_loss
                    model_all.save_weights(C.model_path)

                break

        except Exception as e:
            print('Exception: {}'.format(e))
            continue

print('Training complete, exiting.')
