转自:https://blog.csdn.net/Vivianyzw/article/details/81061765

东风的地方

1. 直接加载预训练模型

在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:


  1. net = SNet()
  2. net.load_state_dict(torch.load("model_1599.pkl"))

这种方式是针对于之前保存模型时以保存参数的格式使用的:

torch.save(net.state_dict(), "model/model_1599.pkl")

pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。

下面介绍第二种模型保存和加载的方式:


  1. net = SNet()
  2. torch.save(net, "model_1599.pkl")
  3. snet = torch.load("model_1599.pkl")

这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。

2. 加载一部分预训练模型

模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:


  1. net = SNet()
  2. model_dict = net.state_dict()
  3. vgg16 = models.vgg16(pretrained=True)
  4. pretrained_dict = vgg16.state_dict()
  5. pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  6. model_dict.update(pretrained_dict)
  7. net.load_state_dict(model_dict)

也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。

3. 微调经典网络

因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。

首先比如加载vgg16(带有预训练参数的形式):


  1. import torchvision.models as models
  2. vgg16 = models.vgg16(pretrained=True)

比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:


  1. import torch.nn as nn
  2. vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)

4. 修改经典网络

这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。

先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:


  1. def feature_layer():
  2. layers = []
  3. pool1 = ['4', '9', '16']
  4. pool2 = ['23', '30']
  5. vgg16 = models.vgg16(pretrained=True).features
  6. for name, layer in vgg16._modules.items():
  7. if isinstance(layer, nn.Conv2d):
  8. layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
  9. elif name in pool1:
  10. layers += [layer]
  11. elif name == pool2[0]:
  12. layers += [nn.MaxPool2d(2, 1, 1)]
  13. elif name == pool2[1]:
  14. layers += [nn.MaxPool2d(2, 1, 0)]
  15. else:
  16. continue
  17. features = nn.Sequential(*layers)
  18. #feat3 = features[0:24]
  19. return features

大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:


  1. class SNet(nn.Module):
  2. def __init__(self):
  3. super(SNet, self).__init__()
  4. self.features = feature_layer()
  5. def forward(self, x):
  6. x = self.features(x)
  7. return x

[Pytorch]Pytorch加载预训练模型(转)的更多相关文章

  1. pytorch中修改后的模型如何加载预训练模型

    问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...

  2. 使用Huggingface在矩池云快速加载预训练模型和数据集

    作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...

  3. pytorch加载预训练模型参数的方式

    1.直接使用默认程序里的下载方式,往往比较慢: 2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下: 通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标 ...

  4. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  5. Tensorflow加载预训练模型和保存模型

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  6. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  7. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...

  8. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  9. pytorch数据加载器

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...

随机推荐

  1. jquery类似方法的比较(二)

    (1)append()&appendTo()&prepend()$prependTo() (2)after()&before()&insertAfter()&i ...

  2. 12.文件系统fs

    文件系统fs ==> 提供文件的读取, 写入, 更名, 删除, 遍历目录, 链接等POSIX文件系统操作 1.fs.readFile(filename, [encoding], [callbac ...

  3. 【BZOJ4922】[Lydsy六月月赛]Karp-de-Chant Number 贪心+动态规划

    [BZOJ4922][Lydsy六月月赛]Karp-de-Chant Number Description 卡常数被称为计算机算法竞赛之中最神奇的一类数字,主要特点集中于令人捉摸不透,有时候会让水平很 ...

  4. 【BZOJ1040】[ZJOI2008]骑士 树形DP

    [BZOJ1040][ZJOI2008]骑士 Description Z国的骑士团是一个很有势力的组织,帮会中汇聚了来自各地的精英.他们劫富济贫,惩恶扬善,受到社会各界的赞扬.最近发生了一件可怕的事情 ...

  5. CentOS7使用yum安装nginx

    CentOS默认没有nginx的yum源需要yum安装nginx可以使用一下方法 一,环境检测 二,设置yum源 rpm -Uvh http://nginx.org/packages/centos/7 ...

  6. 向Docx4j生成的word文档中添加布局--第二部分

    原文标题:Adding layout to your Docx4j-generated word documents, part 2 原文链接:http://blog.iprofs.nl/2012/1 ...

  7. [SQL] 让特定的数据 排在最前

    MYSQL目前常用的两种方法,如下: 让值为"张三" 的数据排在最前. -- 方法一 end asc -- 方法二 select * from tableName where co ...

  8. (2.8)Mysql之SQL基础——索引的分类与使用

    (2.8)Mysql之SQL基础——索引的分类与使用 关键字:mysql索引,mysql增加索引,mysql修改索引,mysql删除索引 按逻辑分类: 1.主键索引(聚集索引)(也是唯一索引,不允许有 ...

  9. centos DNS服务搭建 DNS原理 使用bind搭建DNS服务器 配置DNS转发 配置主从 安装dig工具 DHCP dhclient 各种域名解析记录 mydns DNS动态更新 第三十节课

    centos  DNS服务搭建  DNS原理  使用bind搭建DNS服务器 配置DNS转发 配置主从  安装dig工具  DHCP  dhclient  各种域名解析记录  mydns DNS动态更 ...

  10. 205-react SyntheticEvent 事件

    参看地址:https://reactjs.org/docs/events.html