摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

FCOS:Fully Convolutional One-Stage Object Detection

本案例代码是FCOS论文复现的体验案例

此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

具体的算法介绍:https://marketplace.huaweicloud.com/markets/aihub/modelhub/detail/?id=ce7acc40-0540-45c9-a0c6-e2fda8d1ac7e

注意事项:

1.本案例使用框架: PyTorch1.0.0

2.本案例使用硬件: GPU

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

1.数据和代码下载

import os
import moxing as mox
# 数据代码下载
mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip','FCOS.zip')
# 解压缩
os.system('unzip FCOS.zip -d ./')

2.模型训练

2.1依赖库安装及加载

"""
Basic training script for PyTorch
"""
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
import os
import argparse
import torch
import shutil
src_dir = './FCOS/'
os.chdir(src_dir)
os.system('pip install -r ./pip-requirements.txt')
os.system('python -m pip install ./trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl')
os.system('python setup.py build develop')
from framework.utils.env import setup_environment
from framework.config import cfg
from framework.data import make_data_loader
from framework.solver import make_lr_scheduler
from framework.solver import make_optimizer
from framework.engine.inference import inference
from framework.engine.trainer import do_train
from framework.modeling.detector import build_detection_model
from framework.utils.checkpoint import DetectronCheckpointer
from framework.utils.collect_env import collect_env_info
from framework.utils.comm import synchronize, \
get_rank, is_pytorch_1_1_0_or_later
from framework.utils.logger import setup_logger
from framework.utils.miscellaneous import mkdir

2.2训练函数

def train(cfg, local_rank, distributed, new_iteration=False):
model = build_detection_model(cfg)
device = torch.device(cfg.MODEL.DEVICE)
model.to(device)
if cfg.MODEL.USE_SYNCBN:
assert is_pytorch_1_1_0_or_later(), \
"SyncBatchNorm is only available in pytorch >= 1.1.0"
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)
if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank,
# this should be removed if we update BatchNorm stats
broadcast_buffers=False,
)
arguments = {}
arguments["iteration"] = 0
output_dir = cfg.OUTPUT_DIR
save_to_disk = get_rank() == 0
checkpointer = DetectronCheckpointer(
cfg, model, optimizer, scheduler, output_dir, save_to_disk
)
print(cfg.MODEL.WEIGHT)
extra_checkpoint_data = checkpointer.load_from_file(cfg.MODEL.WEIGHT)
print(extra_checkpoint_data)
arguments.update(extra_checkpoint_data)
if new_iteration:
arguments["iteration"] = 0
data_loader = make_data_loader(
cfg,
is_train=True,
is_distributed=distributed,
start_iter=arguments["iteration"],
)
do_train(
model,
data_loader,
optimizer,
scheduler,
checkpointer,
device,
arguments,
)
return model

2.3设置参数,开始训练

def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument(
'--train_url',
default='./outputs',
type=str,
help='the path to save training outputs'
)
parser.add_argument(
"--config-file",
default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",
metavar="FILE",
help="path to config file",
type=str,
)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument('--train_iterations', default=0, type=int)
parser.add_argument('--warmup_iterations', default=500, type=int)
parser.add_argument('--train_batch_size', default=8, type=int)
parser.add_argument('--solver_lr', default=0.01, type=float)
parser.add_argument('--decay_steps', default='120000,160000', type=str)
parser.add_argument('--new_iteration',default=False, action='store_true')
args, unknown = parser.parse_known_args()
cfg.merge_from_file(args.config_file)
# load the model trained on MS-COCO
if args.train_iterations > 0:
cfg.SOLVER.MAX_ITER = args.train_iterations
if args.warmup_iterations > 0:
cfg.SOLVER.WARMUP_ITERS = args.warmup_iterations
if args.train_batch_size > 0:
cfg.SOLVER.IMS_PER_BATCH = args.train_batch_size
if args.solver_lr > 0:
cfg.SOLVER.BASE_LR = args.solver_lr
if len(args.decay_steps) > 0:
steps = args.decay_steps.replace(' ', ',')
steps = steps.replace(';', ',')
steps = steps.replace(';', ',')
steps = steps.replace(',', ',')
steps = steps.split(',')
steps = tuple([int(x) for x in steps])
cfg.SOLVER.STEPS = steps
cfg.freeze()
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = num_gpus > 1
if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend="nccl", init_method="env://"
)
synchronize()
output_dir = args.train_url
if output_dir:
mkdir(output_dir)
logger = setup_logger("framework", output_dir, get_rank())
logger.info("Using {} GPUs".format(num_gpus))
logger.info(args)
logger.info("Loaded configuration file {}".format(args.config_file))
train(cfg, args.local_rank, args.distributed, args.new_iteration)
if __name__ == "__main__":
main()

3.模型测试

3.1预测函数

from framework.engine.predictor import Predictor
from PIL import Image,ImageDraw
import numpy as np
import matplotlib.pyplot as plt
def predict(img_path,model_path):
config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHT = model_path
cfg.OUTPUT_DIR = None
cfg.freeze()
predictor = Predictor(cfg=cfg, min_image_size=800)
src_img = Image.open(img_path)
img = src_img.convert('RGB')
img = np.array(img)
img = img[:, :, ::-1]
predictions = predictor.compute_prediction(img)
top_predictions = predictor.select_top_predictions(predictions)
bboxes = top_predictions.bbox.int().numpy().tolist()
bboxes = [[x[1], x[0], x[3], x[2]] for x in bboxes]
scores = top_predictions.get_field("scores").numpy().tolist()
scores = [round(x, 4) for x in scores]
labels = top_predictions.get_field("labels").numpy().tolist()
labels = [predictor.CATEGORIES[x] for x in labels]
draw = ImageDraw.Draw(src_img)
for i,bbox in enumerate(bboxes):
draw.text((bbox[1],bbox[0]),labels[i] + ':'+str(scores[i]),fill=(255,0,0))
draw.rectangle([bbox[1],bbox[0],bbox[3],bbox[2]],fill=None,outline=(255,0,0))
return src_img

3.2开始预测

if __name__ == "__main__":
model_path = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth" # 训练得到的模型
image_path = "./trained_model/model/demo_image.jpg" # 预测的图像
img = predict(image_path,model_path)
plt.figure(figsize=(10,10)) #设置窗口大小
plt.imshow(img)
plt.show()
2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth

点击关注,第一时间了解华为云新鲜技术~

FCOS论文复现:通用物体检测算法的更多相关文章

  1. 物体检测算法 SSD 的训练和测试

    物体检测算法 SSD 的训练和测试 GitHub:https://github.com/stoneyang/caffe_ssd Paper: https://arxiv.org/abs/1512.02 ...

  2. 论文笔记:目标检测算法(R-CNN,Fast R-CNN,Faster R-CNN,FPN,YOLOv1-v3)

    R-CNN(Region-based CNN) motivation:之前的视觉任务大多数考虑使用SIFT和HOG特征,而近年来CNN和ImageNet的出现使得图像分类问题取得重大突破,那么这方面的 ...

  3. Detectron系统实现了最先进的物体检测算法https://github.com/facebookresearch/Detectron

    ,包括Mask R-CNN. 它是用Python编写的,支持Caffe2深度学习框架. 不久前,FAIR才开源了语音识别的工具wav2letter,戳这里看大数据文摘介绍<快讯 | Facebo ...

  4. 深度学习原理与框架-卷积网络细节-三代物体检测算法 1.R-CNN 2.Fast R-CNN 3.Faster R-CNN

    目标检测的选框操作:第一步:找出一些边缘信息,进行图像合并,获得少量的边框信息 1.R-CNN, 第一步:进行图像的选框,对于选出来的框,使用卷积计算其相似度,选择最相似ROI的选框,即最大值抑制RO ...

  5. 深度学习 目标检测算法 SSD 论文简介

    深度学习 目标检测算法 SSD 论文简介 一.论文简介: ECCV-2016 Paper:https://arxiv.org/pdf/1512.02325v5.pdf  Slides:http://w ...

  6. FCOS : 找到诀窍了,anchor-free的one-stage目标检测算法也可以很准 | ICCV 2019

    论文提出anchor-free和proposal-free的one-stage的目标检测算法FCOS,不再需要anchor相关的的超参数,在目前流行的逐像素(per-pixel)预测方法上进行目标检测 ...

  7. yolo类检测算法解析——yolo v3

    每当听到有人问“如何入门计算机视觉”这个问题时,其实我内心是拒绝的,为什么呢?因为我们说的计算机视觉的发展史可谓很长了,它的分支很多,而且理论那是错综复杂交相辉映,就好像数学一样,如何学习数学?这问题 ...

  8. 物体检测之FPN及Mask R-CNN

    对比目前科研届普遍喜欢把问题搞复杂,通过复杂的算法尽量把审稿人搞蒙从而提高论文的接受率的思想,无论是著名的残差网络还是这篇Mask R-CNN,大神的论文尽量遵循著名的奥卡姆剃刀原理:即在所有能解决问 ...

  9. 利用modelarts和物体检测方式识别验证码

    近来有朋友让老山帮忙识别验证码.在github上查看了下,目前开源社区中主要流行以下几种验证码识别方式: tesseract-ocr模块: 这是HP实验室开发由Google 维护的开源 OCR引擎,内 ...

  10. 转-------基于R-CNN的物体检测

    基于R-CNN的物体检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187029 作者:hjimce 一.相关理论 本篇博文主要讲解2014 ...

随机推荐

  1. Windows 2012 R2 iSCSI server

     Windows 2012 R2可以充当一台简单的SAN,提供iSCSI方式的连接,供客户端使用.不确定是否有人会这么使用,但至少在做实验的时候我觉得挺方便的.不用再像以前专门安装windows ...

  2. Echarts中tooltip解决显示指定数据

    今天开发中遇到一个问题,echarts图表触摸x轴触发tooltip会将x轴上所有的数据展示出来,但是有些场合只需要展示某些数据就可以,并不需要全部展示,如下图: 这里警戒线因为需要开关,所以使用填充 ...

  3. 官方文档----ProxySQL 1.4.2 现在支持原生集群!!!

    官方文档地址:https://proxysql.com/blog/proxysql-cluster/ 前言 ProxySQL 是一个去中心化的代理,建议靠近应用部署.这种方法甚至可以很好地扩展到数百个 ...

  4. 添加 K8S CPU limit 会降低服务性能

    文章转载自:https://mp.weixin.qq.com/s/cR6MpQu-n1cwMbXmVaXqzQ

  5. 各编程语言 + aardio 相互调用示例

    代码简单.复制可用.aardio 快速调用 C,C++,C#,Java,R,V,Python,JavaScript,Node.js,Rust,PHP,Ruby,PowerShell,Fortran,D ...

  6. P7476 苦涩 题解

    Link 一道很好的复杂度均摊题目. 只需要考虑删除操作时的时间复杂度.保证复杂度的重点之一是精确定位到所有包含最大值的区间,即不去碰多余的区间.每次删除操作会删除若干个整个区间,以及至多两个区间被删 ...

  7. Spring的同一个服务为什么会加载多次?

    问题现象 最近在本地调试公司的一个Web项目时,无意中发现日志中出现了两次同一个服务的init记录,项目都是基于Spring来搭建的,按理说服务都是单例的,应该只有一次服务加载日志才对,本着对工作认真 ...

  8. Java I/O(1):模型与流

    在1990年以前,有一帮工程师们认为未来(1990年以后)会有很多小型设备需要得到电脑操控(不得不说,想法非常超前),鉴于当时市面上并没有任何一款编程语言能够跨平台,而且能够在诸如烤面包机这种小型设备 ...

  9. golang中的errgroup

    0.1.索引 https://waterflow.link/articles/1665239900004 1.串行执行 假如我们需要查询一个课件列表,其中有课件的信息,还有课件创建者的信息,和课件的缩 ...

  10. 三、Kubernetes调度

    一.Kubernetes调度 Scheduler 是 kubernetes 的调度器,主要的任务是把定义的 pod 分配到集群的节点上.听起来非常简单,但有很多要考虑的问题: 公平:如何保证每个节点都 ...