[Pytorch]Pytorch加载预训练模型(转)
转自:https://blog.csdn.net/Vivianyzw/article/details/81061765
东风的地方
1. 直接加载预训练模型
在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:
-
net = SNet()
-
net.load_state_dict(torch.load("model_1599.pkl"))
这种方式是针对于之前保存模型时以保存参数的格式使用的:
torch.save(net.state_dict(), "model/model_1599.pkl")
pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。
下面介绍第二种模型保存和加载的方式:
-
net = SNet()
-
torch.save(net, "model_1599.pkl")
-
-
snet = torch.load("model_1599.pkl")
这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。
2. 加载一部分预训练模型
模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:
-
net = SNet()
-
model_dict = net.state_dict()
-
-
vgg16 = models.vgg16(pretrained=True)
-
pretrained_dict = vgg16.state_dict()
-
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
-
-
model_dict.update(pretrained_dict)
-
net.load_state_dict(model_dict)
也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。
3. 微调经典网络
因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。
首先比如加载vgg16(带有预训练参数的形式):
-
import torchvision.models as models
-
vgg16 = models.vgg16(pretrained=True)
比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:
-
import torch.nn as nn
-
vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)
4. 修改经典网络
这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。
先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:
-
def feature_layer():
-
layers = []
-
pool1 = ['4', '9', '16']
-
pool2 = ['23', '30']
-
vgg16 = models.vgg16(pretrained=True).features
-
for name, layer in vgg16._modules.items():
-
if isinstance(layer, nn.Conv2d):
-
layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
-
elif name in pool1:
-
layers += [layer]
-
elif name == pool2[0]:
-
layers += [nn.MaxPool2d(2, 1, 1)]
-
elif name == pool2[1]:
-
layers += [nn.MaxPool2d(2, 1, 0)]
-
else:
-
continue
-
features = nn.Sequential(*layers)
-
#feat3 = features[0:24]
-
return features
大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:
-
class SNet(nn.Module):
-
def __init__(self):
-
super(SNet, self).__init__()
-
self.features = feature_layer()
-
def forward(self, x):
-
x = self.features(x)
-
return x
[Pytorch]Pytorch加载预训练模型(转)的更多相关文章
- pytorch中修改后的模型如何加载预训练模型
问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...
- 使用Huggingface在矩池云快速加载预训练模型和数据集
作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...
- pytorch加载预训练模型参数的方式
1.直接使用默认程序里的下载方式,往往比较慢: 2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下: 通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标 ...
- Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- Tensorflow加载预训练模型和保存模型
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
- PyTorch模型加载与保存的最佳实践
一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
- pytorch数据加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...
随机推荐
- WEB安全第三篇--控制请求的艺术:CSRF和SSRF
零.前言 最近做专心web安全有一段时间了,但是目测后面的活会有些复杂,涉及到更多的中间件.底层安全.漏洞研究与安全建设等越来越复杂的东东,所以在这里想写一个系列关于web安全基础以及一些讨巧的pay ...
- 【BZOJ3932】[CQOI2015]任务查询系统 主席树
[BZOJ3932][CQOI2015]任务查询系统 Description 最近实验室正在为其管理的超级计算机编制一套任务管理系统,而你被安排完成其中的查询部分.超级计算机中的 任务用三元组(Si, ...
- PMP 笔记
项目: 为创造独特的产品.服务或结果而进行的临时性工作. 项目特征: 独特性:Unique.临时性:Temporary.渐进明细. 渐进明细:预算越来越精细.比如三峡工程中,预算从10亿级的误差到1亿 ...
- mysql load data导入脚本
# !/bin/bash ############中文说明###################### #本程序的一些提示需要中文支持,如系统没有安装中文包,请先安装:yum -y groupinst ...
- 树链剖分+线段树+离线(广州网选赛第八题hdu5029)
http://acm.hdu.edu.cn/showproblem.php?pid=5029 Relief grain Time Limit: 10000/5000 MS (Java/Others) ...
- Yii2框架添加API Modules
原文链接:http://www.itnose.net/detail/6459353.html : 一.环境部署 1. read fucking Yii Documents. http://www.yi ...
- poj1584 A round peg in a ground hole【计算几何】
含[判断凸包],[判断点在多边形内],[判断圆在多边形内]模板 凸包:即凸多边形 用不严谨的话来讲,给定二维平面上的点集,凸包就是将最外层的点连接起来构成的凸多边形,它能包含点集中所有的点. The ...
- 系统中同时有 python2和 python3,怎么让 ipython 选择不同的版本启动?
已经安装的情况下: > which ipython /usr/local/bin/ipython > cat /usr/local/bin/ipython #!/usr/local/op ...
- PHP关于函数的参数问题
可能是自己以前写程序太规范了,今天发现个PHP函数参数个数的问题,定义的函数有三个参数,但是使用函数的时候竟然传了四个参数,更意外的是程序运行没有错误,甚至没有警告.于是依靠搜索引擎和PHP文档仔细查 ...
- stark - 分页、search、actions
一.分页 效果图 知识点 1.分页 {{ showlist.pagination.page_html|safe }} 2.page.py class Pagination(object): def _ ...