我看很多人都遇到了这个问题,有很多解决了的。我就把这篇博文再完善一下,让大家对mmdetection使用得心应手。


mmdetection训练自己的数据集时报错 ️ :

# AssertionError: The `num_classes` (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 80) in CocoDataset

你可能已经修改了以下两个文件,但是还是报错:

mmdetection-master\mmdet\core\evaluation\class_names.py

mmdetection-master\mmdet\datasets\coco.py

意思就是你指定的类别(3种)与CocoDataset的类别(80种)不匹配。

如果是报错翻过来的话,也就是你指定的类别(80种)与CocoDataset的类别(3种)不匹配。一定是配置文件里设置错了,去你的配置文件搜索num_classes,然后修改好。


废话不多说,直接上方法。有以下几种方法【经过我多次使用后,推荐第四种,方便的很】:

1️⃣ 是修改最少的,假设你有2个类,你就把上边两处地方,前2个类替换成你的类别。方法比较简单,但是可能存在隐患。【不推荐】

2️⃣ 第二种方法就是修改完 class_names.py 和 voc.py 之后一定要重新编译代码(运行python setup.py install),再进行训练。

我试了,有时候可以,有时候不行,可以尝试一下。

参考:

新版 MMDetection V2.3.0训练测试笔记 - it610.com

mmdetectionV2.x版本 训练自己的VOC数据集_桃子酱momo的博客-CSDN博客

3️⃣ 第三种方法,我之前使用的方法,其实跟重新编译一样,重新编译的原因就是因为环境里的源文件没有修改,所以你才会报错。mmdetection-master目录下只是一些python文件,真正运行程序时,运行的还是环境里的源文件,因为我们直接去环境里修改源文件。

假设我的conda环境名为conda_env_name,因此去下面的目录下,分别修改两个文件:

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\core\evaluation\class_names.py

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\datasets\coco.py

在conda环境里把这两个文件里的类别修改了,就可以了,这一招一定可以。

4️⃣ 第四种方法,更简单,更方便,我现在使用的方法。直接在mmdetection配置文件中指定好所有要指定的东西,因为在mmdetection中配置文件的参数值优先级是最高的,所以不用管环境里有没有修改,配置文件里修改了,就可以了。我写了个脚本,把脚本放到mmdetection根目录,根据自己要用的模型,把脚本中的变量都改成自己的。

我以cascade_mask_rcnn_r101为例:

# 在mmdetection的根目录下运行,如果报错:没有那个参数,就把create_mm_config中那个参数赋值给注释掉。生成配置文件后,直接修改配置文件就可以了。
import os
from mmcv import Config ################################# 下边是要修改的内容 #################################### root_path = os.getcwd()
model_name = 'cascade_mask_rcnn_r101' # 改成自己要使用的模型名字
work_dir = os.path.join(root_path, "work_dirs", model_name) # 训练过程中,保存日志权重文件的路径,。
baseline_cfg_path = os.path.join('configs', 'cascade_rcnn', 'cascade_mask_rcnn_r101_fpn_mstrain_3x_coco.py')
# 改成自己要使用的模型的配置文件路径
save_cfg_path = os.path.join(work_dir, 'config.py') # 生成的配置文件保存的路径 train_data_images = os.path.join(root_path, 'data', 'train', 'images') # 改成自己训练集图片的目录。
val_data_images = os.path.join(root_path, 'data', 'train', 'images') # 改成自己验证集图片的目录。
test_data_images = os.path.join(root_path, 'data', 'val', 'images') # 改成自己测试集图片的目 train_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_train.json') # 修改为自己的数据集的训练集json
val_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_val.json') # 修改为自己的数据集的验证集json
test_ann_file = os.path.join(root_path, 'data', 'val', 'annotations', 'new_test.json') # 修改为自己的数据集的验证集json录。 # 去找个网址里找你对应的模型的网址: https://github.com/open-mmlab/mmdetection/blob/master/README_zh-CN.md
load_from = os.path.join(work_dir, 'checkpoint.pth') # 修改成自己的checkpoint.pth路径 # File config
num_classes = 50 # 改成自己的类别数。
classes = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10',
'11', '12', '13', '14', '15', '16', '17', '18', '19',
'20', '21', '22', '23', '24', '25', '26', '27', '28',
'29', '30', '31', '32', '33', '34', '35', '36', '37',
'38', '39', '40', '41', '42', '43', '44', '45', '46',
'47', '48', '49', '50') # 改成自己的类别 ############### 下边一些参数包含不全,可以在生成的配置文件中再对其他参数进行修改 ##################### # Train config # 根据自己的需求对下面进行配置
gpu_ids = range(0, 1) # 改成自己要用的gpu
gpu_num = 1
total_epochs = 20 # 改成自己想训练的总epoch数
batch_size = 2 ** 1 # 根据自己的显存,改成合适数值,建议是2的倍数。
num_worker = 1 # 比batch_size小,就行
log_interval = 300 # 日志打印的间隔
checkpoint_interval = 7 # 权重文件保存的间隔
lr = 0.02 * batch_size * gpu_num / 16 # 学习率
ratios = [0.5, 1.0, 2.0]
strides = [4, 8, 16, 32, 64] cfg = Config.fromfile(baseline_cfg_path) if not os.path.exists(work_dir):
os.makedirs(work_dir) cfg.work_dir = work_dir
print("Save config dir:", work_dir) # swin和mmdetection的训练集配置不在一个地方,那个不报错用哪个:
cfg.classes = classes
# mmdetection用这个:
cfg.data.train.img_prefix = train_data_images
cfg.data.train.classes = classes
cfg.data.train.ann_file = train_ann_file
# swin用这个,注释上边那个
# cfg.data.train.dataset.img_prefix = train_data_images
# cfg.data.train.dataset.classes = classes
# cfg.data.train.dataset.ann_file = train_ann_file cfg.data.val.img_prefix = val_data_images
cfg.data.val.classes = classes
cfg.data.val.ann_file = val_ann_file cfg.data.test.img_prefix = test_data_images
cfg.data.test.classes = classes
cfg.data.test.ann_file = test_ann_file cfg.data.samples_per_gpu = batch_size
cfg.data.workers_per_gpu = num_worker
cfg.log_config.interval = log_interval # 有些配置文件num_classes可能不在这个地方,生成之后去配置文件里搜索一下,看看都修改了没
for head in cfg.model.roi_head.bbox_head:
head.num_classes = num_classes
if "mask_head" in cfg.model.roi_head:
cfg.model.roi_head.mask_head.num_classes = num_classes cfg.load_from = load_from
cfg.runner.max_epochs = total_epochs
cfg.total_epochs = total_epochs
cfg.optimizer.lr = lr
cfg.checkpoint_config.interval = checkpoint_interval
cfg.model.rpn_head.anchor_generator.ratios = ratios
cfg.model.rpn_head.anchor_generator.strides = strides cfg.dump(save_cfg_path)
print(save_cfg_path)
print("—" * 50)
print(f'CONFIG:\n{cfg.pretty_text}')
print("—" * 50)

生成配置文件后,路径在 ./work_dirs/cascade_mask_rcnn_r101/config.py,在mmdetection根目录下,又可以愉快的进行训练了。

训练命令(在mmdetection根目录):4GPU训练

./tools/dist_train.sh work_dirs/cascade_mask_rcnn_r101/config.py 4

最终也功夫不负有心人,解决掉了这个bug,写此博客,以帮助大家少走弯路。

# AssertionError: The `num_classes` (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 80) in CocoDataset的更多相关文章

  1. Hadoop 2.6.0编译on mac

    花了一个晚上的时间弄了下hadoop的编译环境,碰到些错误,这里保存下. 需要编译Hadoop,不但需要安装Maven,还需要安装protobuf 安装Maven 下载:apache-maven-3. ...

  2. 学习笔记TF020:序列标注、手写小写字母OCR数据集、双向RNN

    序列标注(sequence labelling),输入序列每一帧预测一个类别.OCR(Optical Character Recognition 光学字符识别). MIT口语系统研究组Rob Kass ...

  3. Pytorch版本yolov3源码阅读

    目录 Pytorch版本yolov3源码阅读 1. 阅读test.py 1.1 参数解读 1.2 data文件解析 1.3 cfg文件解析 1.4 根据cfg文件创建模块 1.5 YOLOLayer ...

  4. Eclipse 导入Hadoop 2.6.0 源码

    1. 首先前往 官网(Hadoop 2.6 下载地址)上下载Hadoop的源码文件,并解压 2. 事先请确定已经安装好jdk以及maven(Maven安装教程 这是其他人写的一篇博文,保存profil ...

  5. 目标检测之车辆行人(tensorflow版yolov3)

    背景: 在自动驾驶中,基于摄像头的视觉感知,如同人的眼睛一样重要.而目前主流方案基本都采用深度学习方案(tensorflow等),而非传统图像处理(opencv等). 接下来我们就以YOLOV3为基本 ...

  6. 卷积神经网络学习笔记——Siamese networks(孪生神经网络)

    完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 在整理这些知识点之前,我 ...

  7. 开着idea,死机了,关机重启。重启之后,重新打开idea报错java.lang.AssertionError:upexpected content storage modification

    开着idea,死机了,关机重启.重启之后,重新打开idea报错java.lang.AssertionError:upexpected content storage modification. goo ...

  8. [改善Java代码]在switch的default代码块中增加AssertionError错误

    switch的后跟枚举类型,case后列出所有的枚举项,这是一个使用枚举的主流写法,那留着default语句似乎没有任何作用了,程序已经列举出了所有的可能选项,肯定不会执行到default语句,. 错 ...

  9. 【Flask】报错解决方法:AssertionError: View function mapping is overwriting an existing endpoint function: main.user

    运行Flask时出现了一个错误, AssertionError: View function mapping is overwriting an existing endpoint function: ...

  10. 解决用try except 捕获assert函数产生的AssertionError异常时,导致断言失败的用例在测试报告中通过的问题

    在使用Python3做自动化测试过程中可能会遇到,assert函数不加try  except,就可以正常在报告里体现用例不通过,加上变成通过. 这是因为在使用try except 时,捕获了asser ...

随机推荐

  1. 2、mysql存储引擎

    存储引擎 1 存储引擎概述 和大多数的数据库不同, MySQL中有一个存储引擎的概念, 针对不同的存储需求可以选择最优的存储引擎. 存储引擎就是存储数据,建立索引,更新查询数据等等技术的实现方式 .存 ...

  2. 为什么现在连Date类都不建议使用了?

    一.有什么问题吗java.util.Date? java.util.Date(Date从现在开始)是一个糟糕的类型,这解释了为什么它的大部分内容在 Java 1.1 中被弃用(但不幸的是仍在使用). ...

  3. 时间同步 ntp服务器

    目录 一. 定义 二. 项目要求 三. 部署服务端 四. 部署客户端 一. 定义 #01 简介:ntp全名 network time protocol .NTP服务器可以为其他主机提供时间校对服务 # ...

  4. Zabbix6.0使用教程 (三)—zabbix6.0的安装要求

    接上篇,我们继续为大家详细介绍zabbix6.0的使用教程之zabbix6.0的安装部署.接下来我们将从zabbix部署要求到四种不同的安装方式逐一详细的为大家介绍.本篇讲的是部署zabbix6.0的 ...

  5. rst文件查看(Sphinx)

    reStructuredText ( RST . ReST 或 reST )是一种用于文本数据的文件格式,主要用于 Python 编程语言社区的技术文档. 在下载了别人的Python源文件里面有rst ...

  6. 什么是k8s中的sidecar模式

    在Kubernetes中,Sidecar模式是一种将辅助容器与主应用程序容器一起部署在同一个Pod中的设计模式.这种模式的目的是将辅助功能与主应用程序解耦,并提供独立发布.能力重用以及共享资源和网络的 ...

  7. 手动将jar包安装到本地仓库并使用tomcat运行

    参考,欢迎点击原文:https://www.cnblogs.com/panchanggui/p/10696458.html 公司有个老项目是spring的,要我们自己本地安装,我发现我maven一直报 ...

  8. 11_使用SDL播放WAV

    使用命令播放WAV 对于WAV文件来说,可以直接使用ffplay命令播放,而且不用像PCM那样增加额外的参数.因为WAV的文件头中已经包含了相关的音频参数信息. ffplay in.wav 接下来演示 ...

  9. NJUPT第二次积分赛小结与视觉部分开源

    NJUPT第二次积分赛小结与视觉部分开源 跟队友连肝一周多积分赛,写了一堆屎山,总算是今天完赛了.结果也还行,80分到手.其实题目是全做完了的,但验收时我nt了没操作好导致丢了不少分,而且整个控制流程 ...

  10. 记录--产品:请给我实现一个在web端截屏的功能!

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 一.故事的开始 最近产品又开始整活了,本来是毫无压力的一周,可以小摸一下鱼的,但是突然有一天跟我说要做一个在网页端截屏的功能. 作为一个工 ...