【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号【机器学习炼丹术】。一个ai专业研究生的个人学习分享公众号
文章目录:
torchvision
官网上的介绍(翻墙):The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.
翻译过来就是:
torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是常用数据集+常见模型+常见图像增强方法
这个torchvision中主要有包组成:
torchvision.datasetstorchvision.modelstorchvision.transforms
1 torchvision.datssets
包含贼多的数据集,包含下面的:
官方说明了:All the datasets have almost similar API. They all have two common arguments: transform and target_transform to transform the input and target respectively.
翻译过来就是:每一个数据集的API都是基本相同的。他们都有两个相同的参数:transform和target_transform(后面细讲)
我们就用最经典最简单的MNIST手写数字数据集作为例子,先看这个的API:
包含5个参数:
- root:就是你想要保存MNIST数据集的位置,如果download是Flase的话,则会从目标位置读取数据集;
- download:True的话就会自动从网上下载这个数据集,到root的位置;
- train:True的话,数据集下载的是训练数据集;False的话则下载测试数据集(真方便,都不用自己划分了)
- transform:这个是对图像进行处理的transform,比方说旋转平移缩放,输入的是PIL格式的图像(不是tensor矩阵);
- target_transform:这个是对图像标签进行处理的函数(这个我没用过不太确定,也许是做标签平滑那种的处理?)
【下面用代码进一步理解】
import torchvision
mydataset = torchvision.datasets.MNIST(root='./',
train=True,
transform=None,
target_transform=None,
download=True)
运行结果如下,表示下载完毕(我不太确定这个下载数据集是否需要翻墙,我会把这次需要用的代码和数据集放到公众号,后台回复【torchvision】获取,下载出现问题请务必私戳我)
之后我们需要用到上一节课讲到的dataloader的内容:
from torch.utils.data import Dataset,DataLoader
myloader = DataLoader(dataset=mydataset,
batch_size=16)
for i,(data,label) in enumerate(myloader):
print(data.shape)
print(label.shape)
break
这时候会抛出一个错误:
大致看一看,就是pytorch的这个dataloader不是可以把数据集分成batch嘛,这个dataloder只能把tensor或者numpy这样的组合成batch,而现在的数据集的格式是PIL格式。这里验证了之前说到的,transform这个输入是PIL格式的图片,解决方法是:transform不能是None,我们需要将PIL转化成tensor才可以
所以我们把上面的transform稍作修改:
mydataset = torchvision.datasets.MNIST(root='./',
train=True,
transform=torchvision.transforms.ToTensor(),
target_transform=None,
 download=True)
重新运行的时候可以得到结果:
结果中,16表示一个batch有16个样本,1表示这是单通道的灰度图片,28表示MNIST数据集图片是\(28\times 28\)的大小,然后每一个图片有一个label。
想要获取其他的数据集也是一样的,不过这里就用MNIST作为举例,其他的相同。
2 torchvision.models
预训练模型中torchvision提供了很多种,大体分成下面四类:
分别是分类模型,语义模型,目标检测模型和视频分类模型。这里呢因为分类模型比较常见也比较基础,就主要介绍这个好啦。
在torch1.6.0版本中(应该是比较近的版本),主要包含下面的预训练模型:
构建模型可以通过下面的代码:
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
这样构建的模型的权重值是随机的,只有结构是保存的。想要获取预训练的模型,则需要设置参数pretrained:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
我看官网的英文讲解,提到了一点:似乎这些模型的预训练数据集都是ImageNet的那个数据集,输入图片都是3通道的,并且要求输入图片的宽高不小于224像素,并且要求输入图片像素值的范围在0到1之间,然后做一个normalization标准化。
不知道各位在看一些案例的时候,有没有看到这个标准化:mean = [0.485, 0.456, 0.406] 和 std = [0.229, 0.224, 0.225],这个应该是ImageNet的图片的标准化的参数。
这些预训练的模型参数不确定能不能直接下载,我也就把这些模型存起来一并放在了公众号的后台,依然是回复【torchvision】获取。
得到了.pth文件之后使用torch.load来加载即可。
# torch.save(model, 'model.pth')
model = torch.load('model.pth')
模型比较
最后呢,torchvision官方提供了一个不同模型在Imagenet 1-crop 的一个错误率的比较。可以一起来看看到底哪个模型比较好使。这里我放了一些常见的模型。。像是Wide ResNet这种变种我就不放了。
| 网络 | Top-1 error | Top-5 error |
|---|---|---|
| AlexNet | 43.45 | 20.91 |
| VGG-11 | 30.98 | 11.37 |
| VGG-13 | 30.07 | 10.75 |
| VGG-16 | 28.41 | 9.62 |
| VGG-19 | 27.62 | 9.12 |
| VGG-13 with BN | 28.45 | 9.63 |
| VGG-19 with BN | 25.76 | 8.15 |
| Resnet-18 | 30.24 | 10.92 |
| Resnet-34 | 26.70 | 8.58 |
| Resnet-50 | 23.85 | 7.13 |
| Resnet-101 | 22.63 | 6.44 |
| Resnet-152 | 21.69 | 5.94 |
| SqueezeNet 1.1 | 41.81 | 19.38 |
| Densenet-161 | 22.35 | 6.2 |
整体来看,还是Resnet残差网络效果好。不过EfficientNet效果更好,不过Torchvision中没有预训练,在之后会讲解EfficientNet的预训练模型的代码方便使用(先挖坑)。
【小白学PyTorch】5 torchvision预训练模型与数据集全览的更多相关文章
- 【小白学PyTorch】20 TF2的eager模式与求导
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 使用Huggingface在矩池云快速加载预训练模型和数据集
作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...
- 【小白学PyTorch】7 最新版本torchvision.transforms常用API翻译与讲解
文章来自:微信公众号[机器学习炼丹术].欢迎关注支持原创 也欢迎添加作者微信:cyx645016617. 参考目录: 目录 1 基本函数 1.1 Compose 1.2 RandomChoice 1. ...
- [Pytorch]Pytorch加载预训练模型(转)
转自:https://blog.csdn.net/Vivianyzw/article/details/81061765 东风的地方 1. 直接加载预训练模型 在训练的时候可能需要中断一下,然后继续训练 ...
- 【小白学PyTorch】11 MobileNet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...
- 【小白学PyTorch】13 EfficientNet详解及PyTorch实现
参考目录: 目录 1 EfficientNet 1.1 概述 1.2 把扩展问题用数学来描述 1.3 实验内容 1.4 compound scaling method 1.5 EfficientNet ...
- TorchVision 预训练模型进行推断
torchvision.models 里包含了许多模型,用于解决不同的视觉任务:图像分类.语义分割.物体检测.实例分割.人体关键点检测和视频分类. 本文将介绍 torchvision 中模型的入门使用 ...
- 【小白学PyTorch】10 pytorch常见运算详解
参考目录: 目录 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 近似值运算 7 剪裁运算 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些 ...
- 【小白学PyTorch】15 TF2实现一个简单的服装分类任务
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
随机推荐
- Pr剪辑
目录 Pr剪辑教程 入门基础 创建序列类别 处理非正常序列 导出文件 导出设置 导入各类别素材 简单使用: 剪辑素材常用方法 剃刀工具 选择工具 波纹编辑工具 打入点和出点 剪辑速度 整个素材视频速度 ...
- js的栈内存和堆内存
栈内存和堆内存在了解一门语言底层数据结构上,挺重要的,做了个总结 JS中的栈内存堆内存 JS的内存空间分为栈(stack).堆(heap).池(一般也会归类为栈中). 其中栈存放变量,堆存放复杂对象, ...
- 小波变换检测信号突变点的MATLAB实现
之前在不经意间也有接触过求突变点的问题.在我看来,与其说是求突变点,不如说是我们常常玩的"找不同".给你两幅图像,让你找出两个图像中不同的地方,我认为这其实也是找突变点在生活中的应 ...
- Netty(一):server启动流程解析
netty作为一个被广泛应用的通信框架,有必要我们多了解一点. 实际上netty的几个重要的技术亮点: 1. reactor的线程模型; 2. 安全有效的nio非阻塞io模型应用; 3. pipeli ...
- Redis实现商品热卖榜
Redis系列 redis相关介绍 redis是一个key-value存储系统.和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list(链表).set(集合 ...
- QueryRunner使用总结
使用JDBC技术是一件繁琐的事情,为了使数据库更加高效,有一种简化jdbc技术的操作--DBUtils.DbUtils(org.apache.commons.dbutils.DbUtils)是Apac ...
- Android 给服务器发送网络请求
今天听得有点蒙,因为服务器的问题,这边建立服务器的话,学长用的是Idea建立的Spring之类的方法去搞服务器. 然后就是用Android去给这个服务器发送请求,大致效果还是懂的,就是像网站发送请求, ...
- 高效c/c++日志工具zlog使用介绍
1. zlog简介 zlog的资料网上很多,这里不在详细说明:zlog是用c写的一个日志工具,非常小,而且高效,可以同时向控制台和文件中输出,日志接口与printf使用基本一样,所以使用起来很简单. ...
- Java集合最全解析,学集合,看这篇就够用了!!!
在看集合类之前, 我们要先明白一下概念: 1.数据结构 (1):线性表 [1]:顺序存储结构(也叫顺序表) 一个线性表是n个具有相同特性的数据元素的有限序列.数据元素是一个抽象的符号,其具体含义在不同 ...
- 发布新版首页“外婆新家”升级版:全新的UI,熟悉的味道
在7月30日我们我们忐忑不安地发布了新版网站首页,发布后迎接我们的不是新颜新风貌的惊喜,而是我们最担心的残酷现实——“让我们等这么多年,等来的就是这个新的丑容颜”,在大家的批评声中我们深深地认识到我们 ...