torchvision.models 里包含了许多模型,用于解决不同的视觉任务:图像分类、语义分割、物体检测、实例分割、人体关键点检测和视频分类。

本文将介绍 torchvision 中模型的入门使用,一起来创建 Faster R-CNN 预训练模型,预测图像中有什么物体吧。

import torch
import torchvision
from PIL import Image

创建预训练模型

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

print(model) 可查看其结构:

FasterRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPN(
...
)
(rpn): RegionProposalNetwork(
(anchor_generator): AnchorGenerator()
(head): RPNHead(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
(bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
)
)
(roi_heads): RoIHeads(
(box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
(box_head): TwoMLPHead(
(fc6): Linear(in_features=12544, out_features=1024, bias=True)
(fc7): Linear(in_features=1024, out_features=1024, bias=True)
)
(box_predictor): FastRCNNPredictor(
(cls_score): Linear(in_features=1024, out_features=91, bias=True)
(bbox_pred): Linear(in_features=1024, out_features=364, bias=True)
)
)
)

此预训练模型是于 COCO train2017 上训练的,可预测的分类有:

COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

指定 CPU or GPU

获取支持的 device

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

模型移到 device

model.to(device)

读取输入图像

img = Image.open('data/bicycle.jpg').convert("RGB")
img = torchvision.transforms.ToTensor()(img)

准备模型入参 images

images = [img.to(device)]

例图 data/bicycle.jpg

进行模型推断

模型切为 eval 模式:

# For inference
model.eval()

模型在推断时,只需要给到图像数据,不用标注数据。推断后,会返回每个图像的预测结果 List[Dict[Tensor]]Dict 包含字段有:

  • boxes (FloatTensor[N, 4]): 预测框 [x1, y1, x2, y2], x 范围 [0,W], y 范围 [0,H]
  • labels (Int64Tensor[N]): 预测类别
  • scores (Tensor[N]): 预测评分
predictions = model(images)
pred = predictions[0]
print(pred)

预测结果如下:

{'boxes': tensor([[750.7896,  56.2632, 948.7942, 473.7791],
[ 82.7364, 178.6174, 204.1523, 491.9059],
...
[174.9881, 235.7873, 351.1031, 417.4089],
[631.6036, 278.6971, 664.1542, 353.2548]], device='cuda:0',
grad_fn=<StackBackward>), 'labels': tensor([ 1, 1, 2, 1, 1, 1, 2, 2, 1, 77, 1, 1, 1, 2, 1, 1, 1, 1,
1, 1, 27, 1, 1, 44, 1, 1, 1, 1, 27, 1, 1, 32, 1, 44, 1, 1,
31, 2, 38, 2, 2, 1, 1, 31, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2,
1, 1, 1, 1, 1, 1, 31, 2, 27, 1, 2, 1, 1, 31, 2, 77, 2, 1,
2, 2, 2, 44, 2, 31, 1, 1, 1, 1], device='cuda:0'), 'scores': tensor([0.9990, 0.9976, 0.9962, 0.9958, 0.9952, 0.9936, 0.9865, 0.9746, 0.9694,
0.9679, 0.9620, 0.9395, 0.8984, 0.8979, 0.8847, 0.8537, 0.8475, 0.7865,
0.7822, 0.6896, 0.6633, 0.6629, 0.6222, 0.6132, 0.6073, 0.5383, 0.5248,
0.4891, 0.4881, 0.4595, 0.4335, 0.4273, 0.4089, 0.4074, 0.3679, 0.3357,
0.3192, 0.3102, 0.2797, 0.2655, 0.2640, 0.2626, 0.2615, 0.2375, 0.2306,
0.2174, 0.2129, 0.1967, 0.1912, 0.1907, 0.1739, 0.1722, 0.1669, 0.1666,
0.1596, 0.1586, 0.1473, 0.1456, 0.1408, 0.1374, 0.1373, 0.1329, 0.1291,
0.1290, 0.1289, 0.1278, 0.1205, 0.1182, 0.1182, 0.1103, 0.1060, 0.1025,
0.1010, 0.0985, 0.0959, 0.0919, 0.0887, 0.0886, 0.0873, 0.0832, 0.0792,
0.0778, 0.0764, 0.0693, 0.0686, 0.0679, 0.0671, 0.0668, 0.0636, 0.0635,
0.0607, 0.0605, 0.0581, 0.0578, 0.0572, 0.0568, 0.0557, 0.0556, 0.0555,
0.0533], device='cuda:0', grad_fn=<IndexBackward>)}

绘制预测结果

获取 score >= 0.9 的预测结果:

scores = pred['scores']
mask = scores >= 0.9 boxes = pred['boxes'][mask]
labels = pred['labels'][mask]
scores = scores[mask]

引入 utils.plots.plot_image 绘制结果:

from utils.colors import golden
from utils.plots import plot_image lb_names = COCO_INSTANCE_CATEGORY_NAMES
lb_colors = golden(len(lb_names), fn=int, scale=0xff, shuffle=True)
lb_infos = [f'{s:.2f}' for s in scores]
plot_image(img, boxes, labels, lb_names, lb_colors, lb_infos,
save_name='result.png')

utils.plots.plot_image 函数实现可见后文源码,注意其要求 torchvision >= 0.9.0/nightly

源码

utils.colors.golden:

import colorsys
import random def golden(n, h=random.random(), s=0.5, v=0.95,
fn=None, scale=None, shuffle=False):
if n <= 0:
return [] coef = (1 + 5**0.5) / 2 colors = []
for _ in range(n):
h += coef
h = h - int(h)
color = colorsys.hsv_to_rgb(h, s, v)
if scale is not None:
color = tuple(scale*v for v in color)
if fn is not None:
color = tuple(fn(v) for v in color)
colors.append(color) if shuffle:
random.shuffle(colors)
return colors

utils.plots.plot_image:

from typing import Union, Optional, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image def plot_image(
image: Union[torch.Tensor, Image.Image, np.ndarray],
boxes: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
lb_names: Optional[List[str]] = None,
lb_colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
lb_infos: Optional[List[str]] = None,
save_name: Optional[str] = None,
show_name: Optional[str] = 'result',
) -> torch.Tensor:
"""
Draws bounding boxes on given image.
Args:
image (Image): `Tensor`, `PIL Image` or `numpy.ndarray`.
boxes (Optional[Tensor]): `FloatTensor[N, 4]`, the boxes in `[x1, y1, x2, y2]` format.
labels (Optional[Tensor]): `Int64Tensor[N]`, the class label index for each box.
lb_names (Optional[List[str]]): All class label names.
lb_colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of all class label names.
lb_infos (Optional[List[str]]): Infos for given labels.
save_name (Optional[str]): Save image name.
show_name (Optional[str]): Show window name.
"""
if not isinstance(image, torch.Tensor):
image = torchvision.transforms.ToTensor()(image) if boxes is not None:
if image.dtype != torch.uint8:
image = torchvision.transforms.ConvertImageDtype(torch.uint8)(image)
draw_labels = None
draw_colors = None
if labels is not None:
draw_labels = [lb_names[i] for i in labels] if lb_names is not None else None
draw_colors = [lb_colors[i] for i in labels] if lb_colors is not None else None
if draw_labels and lb_infos:
draw_labels = [f'{l} {i}' for l, i in zip(draw_labels, lb_infos)]
# torchvision >= 0.9.0/nightly
# https://github.com/pytorch/vision/blob/master/torchvision/utils.py
res = torchvision.utils.draw_bounding_boxes(image, boxes,
labels=draw_labels, colors=draw_colors)
else:
res = image if save_name or show_name:
res = res.permute(1, 2, 0).contiguous().numpy()
if save_name:
Image.fromarray(res).save(save_name)
if show_name:
plt.gcf().canvas.set_window_title(show_name)
plt.imshow(res)
plt.show() return res

参考

GoCoding 个人实践的经验分享,可关注公众号!

TorchVision 预训练模型进行推断的更多相关文章

  1. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  2. 【转载】最强NLP预训练模型!谷歌BERT横扫11项NLP任务记录

    本文介绍了一种新的语言表征模型 BERT--来自 Transformer 的双向编码器表征.与最近的语言表征模型不同,BERT 旨在基于所有层的左.右语境来预训练深度双向表征.BERT 是首个在大批句 ...

  3. pytorch预训练模型的下载地址以及解决下载速度慢的方法

    https://github.com/pytorch/vision/tree/master/torchvision/models 几乎所有的常用预训练模型都在这里面 总结下各种模型的下载地址: 1 R ...

  4. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  5. [Pytorch]Pytorch加载预训练模型(转)

    转自:https://blog.csdn.net/Vivianyzw/article/details/81061765 东风的地方 1. 直接加载预训练模型 在训练的时候可能需要中断一下,然后继续训练 ...

  6. 【tf.keras】tf.keras加载AlexNet预训练模型

    目录 从 PyTorch 中导出模型参数 第 0 步:配置环境 第 1 步:安装 MMdnn 第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件) 第 3 步:导出 PyTorc ...

  7. BERT预训练模型的演进过程!(附代码)

    1. 什么是BERT BERT的全称是Bidirectional Encoder Representation from Transformers,是Google2018年提出的预训练模型,即双向Tr ...

  8. XLNet预训练模型,看这篇就够了!(代码实现)

    1. 什么是XLNet XLNet 是一个类似 BERT 的模型,而不是完全不同的模型.总之,XLNet是一种通用的自回归预训练方法.它是CMU和Google Brain团队在2019年6月份发布的模 ...

  9. NLP预训练模型-百度ERNIE2.0的效果到底有多好【附用户点评】

    ERNIE是百度自研的持续学习语义理解框架,该框架支持增量引入词汇(lexical).语法 (syntactic) .语义(semantic)等3个层次的自定义预训练任务,能够全面捕捉训练语料中的词法 ...

随机推荐

  1. AtCoder - agc043_a 和 POJ - 2336 dp

    题意: 给你一个n行m列由'#'和'.'构成的矩阵,你需要从(1,1)点走到(n,m)点,你每次只能向右或者向下走,且只能走'.'的位置. 你可以执行操作改变矩阵: 你可以选取两个点,r0,c0;r1 ...

  2. 2019牛客多校 Round2

    Solved:2 Rank:136 A Eddy Walker 题意:T个场景 每个场景是一个长度为n的环 从0开始 每次要么向前走要么向后走 求恰好第一次到m点且其他点都到过的概率 每次的答案是前缀 ...

  3. Codeforces 102394I Interesting Permutation 思维

    题意: 你有一个长度为n的序列a(这个序列只能使用[1,n]区间内的数字,每个数字只能使用一次),通过a序列可以构造出来三个相同长度的序列f.g.h For each 1≤i≤n, fi=max{a1 ...

  4. VJ train1 I-彼岸

    一道递推题(我这个菜鸡刚开始以为是排列组合) 题目: 突破蝙蝠的包围,yifenfei来到一处悬崖面前,悬崖彼岸就是前进的方向,好在现在的yifenfei已经学过御剑术,可御剑轻松飞过悬崖.现在的问题 ...

  5. 2018-2019 ACM-ICPC, Asia Dhaka Regional Contest C.Divisors of the Divisors of An Integer (数论)

    题意:求\(n!\)的每个因子的因子数. 题解:我们可以对\(n!\)进行质因数分解,这里可以直接用推论快速求出:https://5ab-juruo.blog.luogu.org/solution-p ...

  6. 使用开源量子编程框架ProjectQ打印编译后的量子线路与绘制线路图

    技术背景 在量子计算领域,基于量子芯片的算法设计(或简称为量子算法)是基于量子线路来设计的,类似于传统计算中使用的与门和非门之类的逻辑门.因此研究一个量子线路输入后的编译(可以简化为数量更少的量子门组 ...

  7. Python小练习批量爬取下载歌曲

    import requests import os headers={ 'Cookie': '_ga=GA1.2.701818100.1612092981; _gid=GA1.2.748589379. ...

  8. Linux内核4.19.1编译

    linux内核编译 1.1 大致步骤 下载linux内核4.19.1 官网链接: https://www.kernel.org/ 官网下载经常速度太慢,无法下载,提供另一个链接: http://ftp ...

  9. 基于OpenCV全景拼接(Python)SIFT/SURF

    一.实验内容: 利用sift算法,实现全景拼接算法,将给定的两幅图片拼接为一幅. 二.实验环境: 主机配置: CPU :intel core i5-7300 2.50GHZ RAM :8.0GB 运行 ...

  10. SQL优化汇总

    今天面某家公司,然后问我SQL优化,感觉有点忘了,今天特此总结一下: 总结得是分两方面:索引优化和查询优化: 一. 索引优化: 1. 独立的列 在进行查询时,索引列不能是表达式的一部分,也不能是函数的 ...