PyTorch保存模型与加载模型+Finetune预训练模型使用
参数初始化参
数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
model.parameters())
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.fc.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 如何使用 opencv 加载 darknet yolo 预训练模型?
如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- keras模型的保存与重新加载
# 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
- MindSpore保存与加载模型
技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...
- NeHe OpenGL教程 第三十一课:加载模型
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...
随机推荐
- mint-ui 关于有时候官网有时候打不开的问题
难道就我只有这种情况,不知所措的,要不好看看是什么,看到别人说,是你路由的问题,换个路由的可以了,我觉得不是,那试试HTTP协议的问题 换个https试试就可以 这是什么原因呢? http是超文本传输 ...
- 【LCA+MST】BZOJ3732-Network
[题目大意] 给你N个点的无向图 (1 <= N <= 15,000),记为:1…N.图中有M条边 (1<=M<=30,000) ,第j条边的长度:d_j (1<=d_j ...
- hdu 4445 37届金华赛区 D题
题意:给一个坦克的高度,求炮弹能打中最多的数量 枚举角度,作为一名学霸虽然很快推出了公式,但是却没有考虑到,角度可以朝下的情况 #include<cstdio> #include<i ...
- hadoop 基础视频1
hadoop 基础视频1 一, 大致内容: 1, 源起与体系结构2,实施Hadoop 集群3,分布式HDFS, 大数据存储实战4,Map-Reduce 体系架构5,Map-Reduce 数据分析之一 ...
- Azure虚机磁盘容量警报(邮件提醒)
上周有个客户提出这样的需求:根据虚拟机磁盘的实际使用量,当达到某一阈值时设置邮件提醒. 在这个需求中我们只需要解决两点问题: 计算虚拟机磁盘实际使用量 发送邮件 使用VS新建一个名为Calculate ...
- UNICODE 区域对照表
0000-007F Basic Latin 基本拉丁字母 0080-00FF Latin-1 Supplement 拉丁字母補充-1 0100-017F Latin Extended-A 拉丁字母擴充 ...
- android 控件: xml 设置 Button 按下背景
本篇文章讲述了不使用java代码来改变 Button 按下和未按下时的背景. 首先准备两张图片, 分别是按钮按下和按钮未按下的. 在res/drawable 文件夹中创建一个button_select ...
- Android adb logcat使用技巧
前言 新买的笔记本E431装了最新版的Eclipse,搞定了Android开发环境,可是logcat里查看东西居然仅仅显示level,没有错误的具体信息.我本身也不是一个愿意折腾图形界面,更喜欢纯命令 ...
- WebLogic使用总结(五)——Web项目使用Sigar在WebLogic服务器部署遇到的问题
今天在WebLogic 12c服务器上部署Web项目时,碰到了一个问题.项目中使用到了"Sigar.jar"监控Window平台下的cpu使用率.内存使用率和硬盘信息,sigar. ...
- Android 数据存储04之Content Provider
Content Provider 版本 修改内容 日期 修改人 V1.0 原始版本 2013/2/25 skywang 1 URI 通用资源标志符(Universal Resource Identif ...