github 上大神的代码 https://github.com/endernewton/tf-faster-rcnn.git

在自己跑的过程中的问题:

1. 数据集的问题:

作者实现了 voc,coco数据集接口。由于我要跑自己的数据,所以要重写数据接口。为了方便我将自己的数据格式改为voc的数据格式,使用原来voc的数据接口pascal_voc.py。

voc 数据格式中需要文件:

data

-----VOCdevkit2007  (自己可以改)

|

----VOC2007

|

-----Annotations (目标的标注文件.xml)

-----ImageSets

|

-----  trainval.txt  (用于训练的图像名)

----- test.txt     (用于测试的图像名)

-----JPEGImages  (jpg 图像)

具体  .xml 文件编写根据自己已有的数据

写xml 文件主要内容:

from  xml.dom.minidom import Document

doc=Document()
Annotation=doc.createElement('annotation') # 创建annotation 域
doc.appendChild(Annotation) # 写入annotation 域 object=doc.createElement('object')
Annotation.appendChild('object') # 写入name
object_name=doc.createElement('name')
object_name_text=doc.createTextNode('分类类别名')
object_name.appendChild(object_name_text)
object.appendChild(object_name) # 写入difficult,虽然不用,但是如果不加直接使用pascal_voc会出错
object_difficult=doc.createElement('difficult')
object_difficult_text=doc.createTextNode('0')
object_difficult.appendChild(object_difficult_text)
object.appendChild(object_difficult) # 写入box
bndbox=doc.createElement('bndbox')
object.appendChild(bndbox) object_box=doc.createElement('bndbox')
object_box_xmin=doc.createElement('xmin')
object_box_xmin_text=doc.createTextNode(str(image_box[0]))
object_box_xmin.appendChild(object_box_xmin_text)
bndbox.appendChild(object_box_xmin) object_box_ymin=doc.createElement('ymin')
object_box_ymin_text=doc.createTextNode(str(image_box[1]))
object_box_ymin.appendChild(object_box_ymin_text)
bndbox.appendChild(object_box_ymin) object_box_xmax=doc.createElement('xmax')
object_box_xmax_text=doc.createTextNode(str(image_box[2]))
object_box_xmax.appendChild(object_box_xmax_text)
bndbox.appendChild(object_box_xmax) object_box_ymax=doc.createElement('ymax')
object_box_ymax_text=doc.createTextNode(str(image_box[3]))
object_box_ymax.appendChild(object_box_ymax_text)
bndbox.appendChild(object_box_ymax) f=open(filename,"w")
f.write(doc.toprettyxml(indent=" "))
f.close()

  得到:

<annotation>
<object>
<name>abc</name>
<difficult>0</difficult>
<bndbox>
<xmin>107</xmin>
<ymin>155</ymin>
<xmax>193</xmax>
<ymax>214</ymax>
</bndbox>
</object>
</annotation>

改pascal_voc.py 文件,修改自己的classes,以及xml中对应域的名字等。

2. 数据完成之后,就可以用来训练了,此时出现问题:

Assign requires shapes of both tensors to match. lhs shape= [2048,124] rhs shape= [2048,84]

因为我现在变为30类,30+1 (背景),31*4=124 (4为box 的定位),而原来为84类。

怎么改最后的输出类别个数?在caffe中可以直接在prototxt 定义的网络结构中改,在tensorflow中怎么改呢?

  1. 我们执行train_faster_rcnn 传入了(gpuId, dataset, net) 调用tools/trainval_net.py
  2. 在trainval_net.py 中调用net=resnetv1, load 网络模型, 调用models/train_net
  3. 在train_net 中调用train_model 函数,定义计算图,在initialize 函数中对sess 进行初始化
  def initialize(self, sess):
# Initial file lists are empty
np_paths = []
ss_paths = []
# Fresh train directly from ImageNet weights
print('Loading initial model weights from {:s}'.format(self.pretrained_model))
variables = tf.global_variables()
# Initialize all variables first
sess.run(tf.variables_initializer(variables, name='init'))
var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
# Get the variables to restore, ignoring the variables to fix
variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
# 要加载的变量
restorer = tf.train.Saver(variables_to_restore)
# 进行加载。。出错的地方就是这里
restorer.restore(sess, self.pretrained_model)
print('Loaded.')
# Need to fix the variables before loading, so that the RGB weights are changed to BGR
# For VGG16 it also changes the convolutional weights fc6 and fc7 to
# fully connected weights
self.net.fix_variables(sess, self.pretrained_model)
print('Fixed.')
last_snapshot_iter = 0
rate = cfg.TRAIN.LEARNING_RATE
stepsizes = list(cfg.TRAIN.STEPSIZE) return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  要改正,就要不加载最后的 预测层和 box 回归层。

对要加载的文件进行选择,然后就可训练自己的数据了

tensorflow faster rann的更多相关文章

  1. tensorflow faster rcnn 代码分析一 demo.py

    os.environ["CUDA_VISIBLE_DEVICES"]=2 # 设置使用的GPU tfconfig=tf.ConfigProto(allow_soft_placeme ...

  2. Tensorflow faster rcnn系列一

    注意:本文主要是学习用,发现了一个在faster rcnn训练流程写的比较详细的博客. 大部分内容来自以下博客连接:https://blog.csdn.net/weixin_37203756/arti ...

  3. python3 + Tensorflow + Faster R-CNN训练自己的数据

    之前实现过faster rcnn, 但是因为各种原因,有需要实现一次,而且发现许多博客都不全面.现在发现了一个比较全面的博客.自己根据这篇博客实现的也比较顺利.在此记录一下(照搬). 原博客:http ...

  4. Faster_Rcnn在windows下运行踩坑总结

    Faster_Rcnn在windows下运行踩坑总结  20190524 今天又是元气满满的一天! 1.代码下载 2.编译 3.下载数据集 4.下载pre-train Model 5.运行train ...

  5. TensorFlow_Faster_RCNN中demo.py的运行(CPU Only)

    GitHub项目地址,https://github.com/endernewton/tf-faster-rcnnTensorflow Faster RCNN for Object Detection. ...

  6. Technology Document Guide of TensorRT

    Technology Document Guide of TensorRT Abstract 本示例支持指南概述了GitHub和产品包中包含的所有受支持的TensorRT 7.2.1示例.Tensor ...

  7. 新人如何运行Faster RCNN的tensorflow代码

    0.目的 刚刚学习faster rcnn目标检测算法,在尝试跑通github上面Xinlei Chen的tensorflow版本的faster rcnn代码时候遇到很多问题(我真是太菜),代码地址如下 ...

  8. Tensorflow版Faster RCNN源码解析(TFFRCNN) (2)推断(测试)过程不使用RPN时代码运行流程

    本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记第二篇   推断(测试)过程不使用RPN时代码运行流程 作者:Jiang Wu  原文见:https://hom ...

  9. TensorFlow Object Detection API中的Faster R-CNN /SSD模型参数调整

    关于TensorFlow Object Detection API配置,可以参考之前的文章https://becominghuman.ai/tensorflow-object-detection-ap ...

随机推荐

  1. Python3 与 C# 面向对象之~异常相关

      周末多码文,昨天晚上一篇,今天再来一篇: 在线编程:https://mybinder.org/v2/gh/lotapp/BaseCode/master 在线预览:http://github.les ...

  2. NowCoder--牛客练习赛30 C_小K的疑惑

    题目链接 :牛客练习赛30 C_小K的疑惑 i j k 可以相同 而且 距离%2 只有 0 1两种情况 我们考虑 因为要 d(i j)=d(i k)=d(j k) 所以我们只能找 要么三个点 任意两个 ...

  3. Python中pandas dataframe删除一行或一列:drop函数

    用法:DataFrame.drop(labels=None,axis=0, index=None, columns=None, inplace=False) 参数说明:labels 就是要删除的行列的 ...

  4. Linux:不同文件相同列字符合并文件(awk函数)

    存在file1.txt,其内容如下: H aa 0 0 1 -9 H bb 0 0 2 -9 H cc 0 0 2 -9 存在file2.txt,其内容如下: H aa 0 0 0 -9 asd qw ...

  5. java web整合office web apps

    1.下载安装vmware虚拟机 2.下载windows server 2012或者window server 2012 R2的iso镜像 http://www.xp85.com/html/Window ...

  6. POJ 3249 Test for Job (记忆化搜索)

    Test for Job Time Limit: 5000MS   Memory Limit: 65536K Total Submissions: 11830   Accepted: 2814 Des ...

  7. 新建工程时报错(26, 13) Failed to resolve: com.android.support:appcompat-v7:28.+ ,

    allprojects { repositories { jcenter() maven { url "https://maven.google.com" } } }

  8. octave基本操作

    参考: https://blog.csdn.net/iszhenyu/article/details/78712228:  吴恩达机器学习视频: 在学习机器学习的过程中,免不了要跟MATLAB.Oct ...

  9. 《Java 程序设计》第一周学习总结

    本周因为刚刚接触Linux和码云等等,所以在完成作业的时候遇到很多问题. 首先,在安装Linux没有安装盘片,在盘片安装之后成功建立虚拟机,建立虚拟机后首先要下载jdk,第一次下载时没有选对格式,Li ...

  10. 剑指Offer_编程题_11

    题目描述 输入一个整数,输出该数二进制表示中1的个数.其中负数用补码表示.   class Solution { public: int NumberOf1(int n) { int size = 3 ...