实践案例丨CenterNet-Hourglass论文复现
摘要:本案例是CenterNet-Hourglass论文复现的体验案例,此模型是对Objects as Points 中提出的CenterNet进行结果复现。
本文分享自华为云社区《CenterNet-Hourglass (物体检测/Pytorch)》,作者:HWCloudAI。
目标检测常采用Anchor的方法来获取物体可能存在的位置,再对该位置进行分类,这样的做法耗时、低效,同时需要后处理(比如NMS)。CenterNet将目标看成一个点,即目标bounding box的中心点,整个问题转变成了关键点估计问题,其他目标属性,比如尺寸、3D位置、方向和姿态等都以估计的中心点为基准进行参数回归。
本案例是CenterNet-Hourglass论文复现的体验案例,此模型是对Objects as Points 中提出的CenterNet进行结果复现(原论文Table 2 最后一行)。本模型是以Hourglass网络架构作为backbone,以ExtremNet 作为预训练模型,在COCO数据集上进行50epochs的训练后得到的。本项目是基于原论文的官方代码进行针对ModelArts平台的修改来实现ModelArts上的训练与部署。
注意事项:
1.本案例使用框架:PyTorch1.4.0
2.本案例使用硬件:GPU: 1*NVIDIA-V100NV32(32GB) | CPU: 8 核 64GB
3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码
4.JupyterLab的详细用法: 请参考《ModelAtrs JupyterLab使用指导》
5.碰到问题的解决办法:请参考《ModelAtrs JupyterLab常见问题解决办法》
1.下载数据和代码
运行下面代码,进行数据和代码的下载和解压
本案例使用COCO数据集。
import os
#数据代码下载
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CenterNet.zip
# 解压缩
os.system('unzip CenterNet.zip -d ./')
--2021-06-25 17:50:11-- https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CenterNet.zip
Resolving proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)... 192.168.6.62
Connecting to proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)|192.168.6.62|:8083... connected.
Proxy request sent, awaiting response... 200 OK
Length: 1529663572 (1.4G) [application/zip]
Saving to: ‘CenterNet.zip’
CenterNet.zip 100%[===================>] 1.42G 279MB/s in 5.6s
2021-06-25 17:50:16 (261 MB/s) - ‘CenterNet.zip’ saved [1529663572/1529663572]
0
2.训练
2.1依赖库加载和安装
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
root_path = './CenterNet/'
os.chdir(root_path)
os.system('pip install pycocotools')
import _init_paths
import torch
import torch.utils.data
from opts import opts
from models.model import create_model, load_model, save_model
from models.data_parallel import DataParallel
from logger import Logger
from datasets.dataset_factory import get_dataset
from trains.train_factory import train_factory
from evaluation import test, prefetch_test, image_infer
USE_MODELARTS = True
INFO:root:Using MoXing-v2.0.0.rc0-19e4d3ab
INFO:root:Using OBS-Python-SDK-3.20.9.1
NMS not imported! If you need it, do
cd $CenterNet_ROOT/src/lib/external
make
2.2训练函数
def main(opt):
torch.manual_seed(opt.seed)
torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
Dataset = get_dataset(opt.dataset, opt.task)
opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
logger = Logger(opt)
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
print('Creating model...')
model = create_model(opt.arch, opt.heads, opt.head_conv)
optimizer = torch.optim.Adam(model.parameters(), opt.lr)
start_epoch = 0
if opt.load_model != '':
model, optimizer, start_epoch = load_model(
model, opt.load_model, optimizer, opt.resume, opt.lr, opt.lr_step)
Trainer = train_factory[opt.task]
trainer = Trainer(opt, model, optimizer)
trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
print('Setting up data...')
train_loader = torch.utils.data.DataLoader(
Dataset(opt, 'train'),
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.num_workers,
pin_memory=True,
drop_last=True
)
print('Starting training...')
best = 1e10
for epoch in range(start_epoch + 1, opt.num_epochs + 1):
mark = epoch if opt.save_all else 'last'
log_dict_train, _ = trainer.train(epoch, train_loader)
logger.write('epoch: {} |'.format(epoch))
for k, v in log_dict_train.items():
logger.scalar_summary('train_{}'.format(k), v, epoch)
logger.write('{} {:8f} | '.format(k, v))
save_model(os.path.join(opt.save_dir, 'model_last.pth'),
epoch, model)
logger.write('\n')
if epoch in opt.lr_step:
save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
epoch, model, optimizer)
lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
print('Drop LR to', lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
logger.close()
2.3开始训练
训练需要一点时间,请耐心等待
if __name__ == '__main__':
opt = opts().parse()
if USE_MODELARTS:
pwd = os.getcwd()
print('Copying dataset to work space...')
print('Listing directory: ')
print(os.listdir())
if not os.path.exists(opt.save_dir):
os.makedirs(opt.save_dir)
main(opt)
if USE_MODELARTS:
print("Processing model checkpoints & service config for deployment...")
if not opt.eval:
infer_dir = os.path.join(opt.save_dir, 'model')
os.makedirs(infer_dir)
os.system(f'mv ./trained_model/* {infer_dir}')
pretrained_pth = os.path.join(infer_dir, '*.pth')
ckpt_dir = os.path.join(opt.save_dir, 'checkpoints')
os.makedirs(ckpt_dir)
os.system(f'mv {pretrained_pth} {ckpt_dir}')
pth_files = os.path.join(opt.save_dir, '*.pth')
infer_pth = os.path.join(ckpt_dir, f'{opt.model_deploy}.pth')
os.system(f'mv {pth_files} {ckpt_dir}')
os.system(f'mv {infer_pth} {infer_dir}')
print(os.listdir(opt.save_dir))
print("ModelArts post-training work is done!")
Fix size testing.
training chunk_sizes: [8]
The output will be saved to ./output/exp/ctdet/default
Copying dataset to work space...
Listing directory:
['pre-trained_weights', '.ipynb_checkpoints', 'coco_eval.py', 'train.py', 'coco', 'output', 'training_logs', 'trained_model', '_init_paths.py', '__pycache__', 'coco_classes.py', 'lib', 'evaluation.py']
heads {'hm': 80, 'wh': 2, 'reg': 2}
Creating model...
loaded ./trained_model/epoch_50_mAP_42.7.pth, epoch 50
Setting up data...
==> initializing coco 2017 train data.
loading annotations into memory...
Done (t=0.54s)
creating index...
index created!
Loaded train 5000 samples
Starting training...
/home/ma-user/anaconda3/envs/Pytorch-1.4.0/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
ctdet/default| train: [1][0/625] |loss 1.7568 |hm_loss 1.3771 |wh_loss 1.9394 |off_loss 0.1857 |Data 0.384s (0.384s) |Net 5.019s (5.019s)
ctdet/default| train: [1][200/625] |loss 1.9275 |hm_loss 1.4429 |wh_loss 2.7269 |off_loss 0.2119 |Data 0.001s (0.003s) |Net 0.759s (0.779s)
ctdet/default| train: [1][400/625] |loss 1.9290 |hm_loss 1.4430 |wh_loss 2.7423 |off_loss 0.2118 |Data 0.001s (0.002s) |Net 0.760s (0.770s)
ctdet/default| train: [1][600/625] |loss 1.9276 |hm_loss 1.4397 |wh_loss 2.7623 |off_loss 0.2117 |Data 0.001s (0.002s) |Net 0.765s (0.767s)
Processing model checkpoints & service config for deployment...
['model', 'logs_2021-06-25-17-51', 'opt.txt', 'checkpoints']
ModelArts post-training work is done!
3.模型测试
3.1推理函数
# -*- coding: utf-8 -*-
# TODO 添加模型运行需要导入的模块
import os
import torch
import numpy as np
from PIL import Image
from io import BytesIO
from collections import OrderedDict
import cv2
import sys
sys.path.insert(0, './lib')
from opts import opts
from coco_classes import coco_class_map
from detectors.detector_factory import detector_factory
class ModelClass():
def __init__(self, model_path):
self.model_path = model_path # 本行代码必须保留,且无需修改
self.opt = opts().parse()
self.opt.num_classes = 80
self.opt.resume = True
self.opt.keep_res = True
self.opt.fix_res = False
self.opt.heads = {'hm': 80, 'wh': 2, 'reg': 2}
self.opt.load_model = model_path
self.opt.mean = np.array([0.40789654, 0.44719302, 0.47026115],
dtype=np.float32).reshape(1, 1, 3)
self.opt.std = np.array([0.28863828, 0.27408164, 0.27809835],
dtype=np.float32).reshape(1, 1, 3)
self.opt.batch_infer = False
# configurable varibales:
if 'BATCH_INFER' in os.environ:
print('Batch inference mode!')
self.opt.batch_infer = True
if 'FLIP_TEST' in os.environ:
print('Flip test!')
self.opt.flip_test = True
if 'MULTI_SCALE' in os.environ:
print('Multi scale!')
self.opt.test_scales = [0.5,0.75,1,1.25,1.5]
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
self.opt.gpus = [-1]
self.class_map = coco_class_map()
torch.set_grad_enabled(False)
Detector = detector_factory[self.opt.task]
self.detector = Detector(self.opt)
print('load model success')
def predict(self, file_name):
image = Image.open(file_name).convert('RGB')
img = np.array(image)
img = img[:, :, ::-1]
results = self.detector.run(img)['results']
image = cv2.cvtColor(np.asarray(image),cv2.COLOR_RGB2BGR)
if not self.opt.batch_infer:
for c_id, dets in results.items():
for det in dets:
if det[4] > self.opt.vis_thresh:
scores = str(round(float(det[4]), 4))
classes = self.class_map[c_id]
image = cv2.rectangle(image,(int(det[0]),int(det[1])),(int(det[2]),int(det[3])),(0,255,0),2)
image = cv2.putText(image,classes+':'+scores,(int(det[0]),int(det[1])),cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255),2)
else:
for c_id, dets in results.items():
for det in dets:
scores = str(round(float(det[4]), 4))
classes = self.class_map[c_id]
image = cv2.rectangle(image,(int(det[0]),int(det[1])),(int(det[2]),int(det[3])),(0,255,0),2)
image = cv2.putText(image,classes+':'+scores,(int(det[0]),int(det[1])),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),2)
return image
3.2开始推理
可以自行修改预测的图像路径
if __name__ == '__main__':
import matplotlib.pyplot as plt
img_path = './coco/train/000000021903.jpg'
model_path = './output/exp/ctdet/default/model/model_last.pth' #模型的保存路径,你可以自己找一下
# 以下代码无需修改
my_model = ModelClass(model_path)
result = my_model.predict(img_path)
result = Image.fromarray(cv2.cvtColor(result,cv2.COLOR_BGR2RGB))
plt.figure(figsize=(10,10)) #设置窗口大小
plt.imshow(result)
plt.show()
Fix size testing.
training chunk_sizes: [8]
The output will be saved to ./output/exp/ctdet/default
Creating model...
loaded ./output/exp/ctdet/default/model/model_last.pth, epoch 1
load model success

实践案例丨CenterNet-Hourglass论文复现的更多相关文章
- 实践案例丨基于ModelArts AI市场算法MobileNet_v2实现花卉分类
概述 MobileNetsV2是基于一个流线型的架构,它使用深度可分离的卷积来构建轻量级的深层神经网,此模型基于 MobileNetV2: Inverted Residuals and Linear ...
- 实践案例丨基于 Raft 协议的分布式数据库系统应用
摘要:简单介绍Raft协议的原理.以及存储节点(Pinetree)如何应用 Raft实现复制的一些工程实践经验. 1.引言 在华为分布式数据库的工程实践过程中,我们实现了一个计算存储分离. 底层存储基 ...
- 实践案例丨教你一键构建部署发布前端和Node.js服务
如何使用华为云服务一键构建部署发布前端和Node.js服务 构建部署,一直是一个很繁琐的过程 作为开发,最害怕遇到版本发布,特别是前.后端一起上线发布,项目又特别多的时候. 例如你有10个项目,前后端 ...
- 实践案例丨利用小熊派开发板获取土壤湿度传感器的ADC值
摘要:一文带你用小熊派开发板动手做土壤湿度传感器. 一.实验准备 1.实验环境 一块stm32开发板(推荐使用小熊派),以及数据线 已经安装STM32CubeMX 已经安装KeilMDK,并导入stm ...
- FCOS论文复现:通用物体检测算法
摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现.本代码支持FCOS + ResNet-101在MS-COCO数据集上 ...
- DDD实践案例:引入事件驱动与中间件机制来实现后台管理功能
DDD实践案例:引入事件驱动与中间件机制来实现后台管理功能 一.引言 在当前的电子商务平台中,用户下完订单之后,然后店家会在后台看到客户下的订单,然后店家可以对客户的订单进行发货操作.此时客户会在自己 ...
- 微服务实战(四):服务发现的可行方案以及实践案例 - DockOne.io
原文:微服务实战(四):服务发现的可行方案以及实践案例 - DockOne.io 这是关于使用微服务架构创建应用系列的第四篇文章.第一篇介绍了微服务架构的模式,讨论了使用微服务架构的优缺点.第二和第三 ...
- - 反编译 AndroidKiller 逆向 实践案例 MD
目录 目录 反编译 AndroidKiller 逆向 实践案例 MD AndroidKiller 简介 插件升级 基本使用 实践案例 修改清单文件 打印 debug 级别的日志 方式一:直接代理 Lo ...
- 《SaltStack技术入门与实践》—— 实践案例 <中小型Web架构>3 Memcached配置管理
实践案例 <中小型Web架构>3 Memcached配置管理 本章节参考<SaltStack技术入门与实践>,感谢该书作者: 刘继伟.沈灿.赵舜东 Memcached介绍 Me ...
- Visualizing and Understanding Convolutional Networks论文复现笔记
目录 Visualizing and Understanding Convolutional Networks 论文复现笔记 Abstract Introduction Approach Visual ...
随机推荐
- 空地一体化网络综述_Space-Air-Ground Integrated Network: A Survey
摘要 空地一体化网络(SAGIN)主要解决的是单一网络下的局限性问题,此综述文章从网络设计.资源分配.到性能的优化,对近几年SAGIN的总结. 引言 受限于网络容量和覆盖范围,仅依靠地面通信系统无法在 ...
- 解决IDEA中.properties文件中文变问号(???)的问题(已解决)
问题背景 构建SpringBoot项目时,项目结构中有一个application.properties文件.这个项目是Spring Boot一个特有的配置文件.内容如下(我写了一些日志的配置): 写到 ...
- ALSA Compress-Offload API
概述 从 ALSA API 的早期开始,它就被定义为支持 PCM,或考虑到了 IEC61937 等固定比特率的载荷.参数和返回值以帧计算是常态,这使得扩展已有的 API 以支持压缩数据流充满挑战. 最 ...
- python 执行脚本,并将输出打印到文件
转载请注明出处: 在使用 python 直接执行脚本时,执行的相关输出会打印到当前的控制台,如果想输出到指定的文件,可以采用以下几种方式: 1.在启动时,使用 > 操作符,该操作符会将输出写入到 ...
- 推荐一个Node.js多版本管理的可视化工具
关于Node.js的开发者来说,在开发机器上管理多个不同版本的Node.js是一个常见痛点.之前在开发者安全大全专栏中,提到过解决方法:使用nvm,如果对于nvm还不了解的话,可以前往了解. 对于TJ ...
- Istio 网格的出口定义者:深入了解 Egress Gateway
本文分享自华为云社区<Istio Egress 出口网关使用>,作者:k8s技术圈. 前面我们了解了位于服务网格内部的应用应如何访问网格外部的 HTTP 和 HTTPS 服务,知道如何通过 ...
- IDEA提示java_ 程序包org.apache.ibatis.session不存在
一.解决方案 1.问题原因: 这是因为配置Java的程序包这块出现了错误,同时可能你还没有设置让IDEA自动加载Jar包,才会报出这种错误的. 2.解决方案: 解决方式如下: File->Set ...
- serdes与PCIE的区别
serdes和PCIE是两种非常常见的总线.因为PCIE也是差分信号传输,所以做硬件时比较难区别PCIE和serdes的具体差异点. 两者之间的区别主要表现在以下几点: 1.PCIE使用了SERDES ...
- 探秘扫雷游戏的C语言实现
1 引言 1.1 为什么写这篇文章? 项目仓库地址:基于 C 语言实现的扫雷游戏 我决定写这篇文章的初衷是想分享我在使用C语言开发扫雷游戏的经验和心得.通过这篇文章,我希望能够向读者展示我是如何利用C ...
- 数据智慧:C#中编程实现自定义计算的Excel数据透视表
前言 数据透视表(Pivot Table)是一种数据分析工具,通常用于对大量数据进行汇总.分析和展示.它可以帮助用户从原始数据中提取关键信息.发现模式和趋势,并以可视化的方式呈现. 在数据透视表中,数 ...