[Pytorch]Pytorch加载预训练模型(转)
转自:https://blog.csdn.net/Vivianyzw/article/details/81061765
东风的地方
1. 直接加载预训练模型
在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:
-
net = SNet()
-
net.load_state_dict(torch.load("model_1599.pkl"))
这种方式是针对于之前保存模型时以保存参数的格式使用的:
torch.save(net.state_dict(), "model/model_1599.pkl")
pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。
下面介绍第二种模型保存和加载的方式:
-
net = SNet()
-
torch.save(net, "model_1599.pkl")
-
-
snet = torch.load("model_1599.pkl")
这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。
2. 加载一部分预训练模型
模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:
-
net = SNet()
-
model_dict = net.state_dict()
-
-
vgg16 = models.vgg16(pretrained=True)
-
pretrained_dict = vgg16.state_dict()
-
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
-
-
model_dict.update(pretrained_dict)
-
net.load_state_dict(model_dict)
也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。
3. 微调经典网络
因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。
首先比如加载vgg16(带有预训练参数的形式):
-
import torchvision.models as models
-
vgg16 = models.vgg16(pretrained=True)
比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:
-
import torch.nn as nn
-
vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)
4. 修改经典网络
这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。
先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:
-
def feature_layer():
-
layers = []
-
pool1 = ['4', '9', '16']
-
pool2 = ['23', '30']
-
vgg16 = models.vgg16(pretrained=True).features
-
for name, layer in vgg16._modules.items():
-
if isinstance(layer, nn.Conv2d):
-
layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
-
elif name in pool1:
-
layers += [layer]
-
elif name == pool2[0]:
-
layers += [nn.MaxPool2d(2, 1, 1)]
-
elif name == pool2[1]:
-
layers += [nn.MaxPool2d(2, 1, 0)]
-
else:
-
continue
-
features = nn.Sequential(*layers)
-
#feat3 = features[0:24]
-
return features
大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:
-
class SNet(nn.Module):
-
def __init__(self):
-
super(SNet, self).__init__()
-
self.features = feature_layer()
-
def forward(self, x):
-
x = self.features(x)
-
return x
[Pytorch]Pytorch加载预训练模型(转)的更多相关文章
- pytorch中修改后的模型如何加载预训练模型
问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...
- 使用Huggingface在矩池云快速加载预训练模型和数据集
作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...
- pytorch加载预训练模型参数的方式
1.直接使用默认程序里的下载方式,往往比较慢: 2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下: 通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标 ...
- Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- Tensorflow加载预训练模型和保存模型
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- PyTorch模型加载与保存的最佳实践
一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
随机推荐
- 自行颁发不受浏览器信任的SSL证书
ssh登陆到服务器上,终端输入以下命令,使用openssl生成RSA密钥及证书. # 生成一个RSA密钥 $ openssl genrsa -des3 -out 33iq.key 1024 # 拷贝一 ...
- oneThink的ArticleController控制,详看
本人新手小白,看下 onethink 的 ArticleController , 它里面写的方法,和一些自己以后改进的方向: <?php namespace Home\Controller; c ...
- jquery插件方式实现table查询功能
1.写插件部分,如下: ;(function($){ $.fn.plugin = function(options){ var defaults = { //各种属性,各种参数 } var optio ...
- Spring.Net的使用
1.Spring.Net的简单介绍 spring.net 框架是微软效仿java中的spring框架而推出的一种在.net中使用的框架,它使用配置的方式实现逻辑的解耦,它的主要功能集成在Spring. ...
- JS/Java/Python格式化金额
//java代码 public static void main(String[] args) { DecimalFormat myformat = new DecimalFormat(); ...
- CentOS7.2升级默认yum安装的php版本
CentOS7.2yum安装php默认版本为5.4,可以升级通过yum安装更高版本 设置yum源 rpm -Uvh https://mirror.webtatic.com/yum/el7/webtat ...
- Centos6.5安装JDK环境
1,系统版本查看 2,下载jdk1.8 wget http://download.oracle.com/otn-pub/java/jdk/8u144-b01/090f390dda5b47b9b721c ...
- Ora-1157 ora-1110错误解决案例一枚
1.数据库打开报错如下: SQL> alter database open; alter database open * ERROR at line 1: ORA-01157: cannot i ...
- Fast R-CNN论文详解 - CSDN博客
废话不多说,上车吧,少年 paper链接:Fast R-CNN &创新点 规避R-CNN中冗余的特征提取操作,只对整张图像全区域进行一次特征提取: 用RoI pooling层取代最后一层max ...
- Python开发【Django】:Admin配置管理
Admin创建登录用户 数据库结构表: from django.db import models # Create your models here. class UserProfile(models ...