代码结构概览

核心部分

  • configs:储存各种网络的yaml配置文件
  • datasets:存放数据集的地方
  • detectron2:运行代码的核心组件
  • tools:提供了运行代码的入口以及一切可视化的代码文件。

Tutorial部分

  • demo:显而易见就是demo
  • docs: 同样显而易见。。
  • tests:提供了一些测试代码
  • projects:提供了真实的项目代码示例,之后自己的代码结构可参照这个结构写。

代码逻辑分析

超参数配置

进入tools/train_net.pymain函数,第一行cfg = setup(args)是配置参数。Detectron2中的参数配置使用了yacs这个库,这个库能够很好地重用和拼接超参数文件配置。

我们先看一下detrctron2/config/的文件结构:

  • compat.py: 应该是对之前的Detectron库的兼容吧,可忽略。
  • config.py: 定义了一个CfgNode类,这个类继承自fvcore库(fb写的一个共公共库,提供一些共享的函数,方便各种不同项目使用)中定义的CfgNode,总之就是不断继承。。。继承关系是这样的:

    detrctron2.config.CfgNode->fcvore.common.config.CfgNode->yacs.config.CfgNode->dict

    另外该文件还提供了get_cfg()方法,该方法会返回一个含有默认配置的CfgNode,而这些默认的配置值在下面的default.py中定义了,之所以这样做是因为要配置的默认值太多了,所以为了文档清晰才写到了一个新的文件中去,不过,yacs库的作者也建议这样做。
  • default.py: 如上面所说,该文件定义了各种参数的默认值。

了解配置函数的方法后我们再回到tools/train_net.py,我们一行一行的来理解。

  • tools/train_net.py
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
... def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
  • cfg = get_cfg(): 获取已经配置好默认参数的cfg
  • cfg.merge_from_file(args.config_file):config_file是指定的yaml配置文件,通过merge_from_file这个函数会将yaml文件中指定的超参数对默认值进行覆盖。
  • cfg.merge_from_list(args.opts):merge_from_list作用同上面的类似,只不过是通过命令行的方式覆盖。

    例如
opts = ["SYSTEM.NUM_GPUS", 8, "TRAIN.SCALES", "(1, 2, 3, 4)"]
cfg.merge_from_list(opts)
print("cfg\n",cfg)

那么最后会有

cfg
... (一些默认值超参数)
SYSTEM:
NUM_GPUS: 8
TRAIN:
SCALES: (1,2,3,4)
  • cfg.freeze(): freeze函数的作用是将超参数值冻结,避免被程序不小心修改。
  • default_setup(cfg, args):default_setupdetectron2/engine/default.py中提供的一个默认配置函数,具体是怎么配置的这里不详细说明了。不过需要知道的值这个文件中还提供了很多其他的配置函数,例如还提供了两个类:DefaultPredictorDefaultTrainer

Trainer

既然上面提到了DefaultTrainer,那么我们就从这个类入手了解一下detectron2.engine,其代码结构如下:

  • train_loop.py: 这个函数主要作用是提供了三个重要的类:

    • HookBase: 这是一个Hook的基类,用于指定在训练前后或者每一个step前后需要做什么事情,所以根据特定的需求需要对如下四种方法做不同的定义:before_train,after_train,before_step,after_step。以before_step
    • TrainerBase: 该类中定义的函数可以归纳成三种:
      • register_hooks:这个很好理解,就是将用户定义的一些hooks进行注册,说大白话就是把若干个Hook放在一个list里面去。之后只需要遍历这个list依次执行就可以了。
      • 第二类其实就是上面提到的遍历hook list并执行hook,不过这个遍历有四种,分别是before_train,after_train,before_step,after_step。还有一个就是run_step,这个函数其实就是平常我们在编写训练过程的代码,例如读数据,训练模型,获取损失值,求导数,反向梯度更新等,只不过在这个类里面没有定义。
      • 第三类就是train函数,它有两个参数,分别是开始的迭代数和最大的迭代数。之后就是重复依次执行第二类中的函数指定迭代次数。
    • SimpleTrainer:其实就是继承自TrainerBase,然后定义了run_step等方法。我们后面也可以继承这个类做进一步的自定义。
  • defaults.py: 上面已介绍,提供了两个类:DefaultPredictorDefaultTrainer,这个DefaultTrainer就继承自SimpleTrainer,所以存在如下继承关系:

    detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase

  • hooks.py:定义了很多继承自train_loop.HookBase的Hook。

  • launch.py: 前面提到过,可以理解成代码启动器,可以根据命令决定是否采用分布式训练(或者单机多卡)或者单机单卡训练。

好了,我们继续回到tools/train_net.py的main函数,代码如下所示。

def main(args):
cfg = setup(args) if args.eval_only:
...
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
if cfg.TEST.AUG.ENABLED:
trainer.register_hooks(
[hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
)
return trainer.train()

可以看到下面定义了一个Trainer,它继承自detectron2.engine.default.DefaultTrainer,这个父类会自动解析cfg。之后只需要调用trainer.train()就可以开始训练了。

小结

至此我们对detectron2的逻辑有了大致的了解了,那么接下来我们来了解一下detectron2.engine.default.DefaultTrainer是如何解析cfg的,这部分内容请参见Detectron2代码阅读笔记-(二)

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-10-15 10:37:50

Detectron2源码阅读笔记-(一)Config&Trainer的更多相关文章

  1. Detectron2源码阅读笔记-(二)Registry&build_*方法

    ​ Trainer解析 我们继续Detectron2代码阅读笔记-(一)中的内容. 上图画出了detectron2文件夹中的三个子文件夹(tools,config,engine)之间的关系.那么剩下的 ...

  2. Detectron2源码阅读笔记-(三)Dataset pipeline

    构建data_loader原理步骤 # engine/default.py from detectron2.data import ( MetadataCatalog, build_detection ...

  3. CI框架源码阅读笔记4 引导文件CodeIgniter.php

    到了这里,终于进入CI框架的核心了.既然是“引导”文件,那么就是对用户的请求.参数等做相应的导向,让用户请求和数据流按照正确的线路各就各位.例如,用户的请求url: http://you.host.c ...

  4. CI框架源码阅读笔记3 全局函数Common.php

    从本篇开始,将深入CI框架的内部,一步步去探索这个框架的实现.结构和设计. Common.php文件定义了一系列的全局函数(一般来说,全局函数具有最高的加载优先权,因此大多数的框架中BootStrap ...

  5. CI框架源码阅读笔记2 一切的入口 index.php

    上一节(CI框架源码阅读笔记1 - 环境准备.基本术语和框架流程)中,我们提到了CI框架的基本流程,这里再次贴出流程图,以备参考: 作为CI框架的入口文件,源码阅读,自然由此开始.在源码阅读的过程中, ...

  6. Apollo源码阅读笔记(二)

    Apollo源码阅读笔记(二) 前面 分析了apollo配置设置到Spring的environment的过程,此文继续PropertySourcesProcessor.postProcessBeanF ...

  7. Apollo源码阅读笔记(一)

    Apollo源码阅读笔记(一) 先来一张官方客户端设计图,方便我们了解客户端的整体思路. 我们在使用Apollo的时候,需要标记@EnableApolloConfig来告诉程序开启apollo配置,所 ...

  8. CI框架源码阅读笔记5 基准测试 BenchMark.php

    上一篇博客(CI框架源码阅读笔记4 引导文件CodeIgniter.php)中,我们已经看到:CI中核心流程的核心功能都是由不同的组件来完成的.这些组件类似于一个一个单独的模块,不同的模块完成不同的功 ...

  9. 源码阅读笔记 - 1 MSVC2015中的std::sort

    大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...

随机推荐

  1. Webdriver get(url)加载时间太长

    运行Selenium脚本时,发现有时候由于网络或性能问题,加载网页时间太长,无法继续执行后续操作,但是实际上元素都已经加载出来了. 解决 # 设置页面加载超时时间 d.set_page_load_ti ...

  2. CDN惹的祸:记一次使用OSS设置跨域资源共享(CORS)不生效的问题

    原文: https://www.lastupdate.net/4669.html 昨天H5组的开发反馈了一个问题,说浏览器收不到跨域的配置,提示:Failed to load https://nnmj ...

  3. ValueError: Graph disconnected: cannot obtain value for tensor Tensor

    一般是Input和下面的变量重名了,导致model里面的input变成了第二次出现的Input变量,而不是最开始模型中作为输入的Input变量 改正方法:给第二个变量赋一个新名字即可

  4. 【IntelliJ IDEA学习之五】IntelliJ IDEA 搭建项目

    版本:IntelliJIDEA2018.1.4 一.同一窗口展示多个应用(弊端:耗内存) idea没有eclipse workspace的概念,如果想在同一窗口显示多个应用,可以按照如下方式来做:1. ...

  5. Sql Server 数据库出现“可疑”的解决办法

    问题:数据库名称出现“可疑”字样 解决: 方法一: 重新还原数据库 方法二: 贴上语句:(DB_CS:你的数据库名) 第一步: ALTER DATABASE DB_CS SET EMERGENCY 第 ...

  6. Java stream 并发应用案例

    在磁盘目录下有几十个txt文件,里面存储着XML格式的数据,每个文件在2-3M左右,现在需要将以上文件解析出来保存到mysql数据库,总数据量大概在30万条左右. (1)首先通过stream并发解析T ...

  7. HashMap源码1

    jdk1.8之前是数组+链表的形式,后面会介绍jdk1.8对hashMap的改动:数组+链表+红黑树 transient是Java语言的关键字,用来表示一个域不是该对象串行化的一部分. 当一个对象被串 ...

  8. 前端与算法 leetcode 344. 反转字符串

    目录 # 前端与算法 leetcode 344. 反转字符串 题目描述 概要 提示 解析 解法一:双指针 解法二:递归 算法 传入测试用例的运行结果 执行结果 GitHub仓库 # 前端与算法 lee ...

  9. JMeter分布式测试环境搭建(禁用SSL)

    JMeter分布式环境,一台Master,一到多台Slave,Master和Slave可以是同一台机器. 前提条件: 所有机器,包括master和slave的机器: 1.运行相同版本的JMeter 2 ...

  10. Linux下c语言TCP多线程聊天室

    开发环境:Linux,GCC 相关知识:TCP(博客:传送门),线程 附加:项目可能还有写不足之处,有些bug没调出来(如:对在线人数的控制),希望大佬赐教. 那么话不多说,放码过来: 码云:传送门, ...