转自:知乎

目录:

  • 保存模型与加载模型
  • 冻结一部分参数,训练另一部分参数
  • 采用不同的学习率进行训练

1.保存模型与加载

简单的保存与加载方法:

# 保存整个网络
torch.save(net, PATH)
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))

然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN,
'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

加载的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer):
if checkpoint != None:
model_CKPT = torch.load(checkpoint_PATH)
model.load_state_dict(model_CKPT['state_dict'])
print('loading checkpoint!')
optimizer.load_state_dict(model_CKPT['optimizer'])
return model, optimizer

其他的参数可以通过以字典的方式获得

但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != 'No':
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint['state_dict']
# 过滤操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出来,更新了多少的参数
print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新优化器那么设置为false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint['optimizer'])
print('loaded! optimizer')
else:
print('not loaded optimizer')
else:
print('No checkpoint is included')
return model, optimizer

2.冻结部分参数,训练另一部分参数

1)添加下面一句话到模型中

for p in self.parameters():
p.requires_grad = False

比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

class RESNET_MF(nn.Module):
def __init__(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...

同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
eps=1e-08, weight_decay=1e-5)

2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

查找的代码:

    model_dict = torch.load('net.pth.tar').state_dict()
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
print(i, p)

保存一下这个文件,可以看到大致是这个样子的:

0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....

同样在模型中添加这样的代码:

for i,p in enumerate(net.parameters()):
if i < 165:
p.requires_grad = False

在优化器中添加上面的那句话可以实现参数的屏蔽

[Pytorch]Pytorch 保存模型与加载模型(转)的更多相关文章

  1. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  2. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  3. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  4. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  5. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  6. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  7. MindSpore保存与加载模型

    技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...

  8. NeHe OpenGL教程 第三十一课:加载模型

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

  9. 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作

    生成检查点文件(chekpoint file),扩展名.ckpt,tf.train.Saver对象调用Saver.save()生成.包含权重和其他程序定义变量,不包含图结构.另一程序使用,需要重新创建 ...

随机推荐

  1. DEV中gridview常用属性

    1.隐藏最上面的GroupPanel: gridView1.OptionsView.ShowGroupPanel=false; 2.得到当前选定记录某字段的值: sValue=Table.Rows[g ...

  2. ntpdate同步更新时间

    Linux服务器运行久时,系统时间就会存在一定的误差,一般情况下可以使用date命令进行时间设置,但在做数据库集群分片等操作时对多台机器的时间差是有要求的,此时就需要使用ntpdate进行时间同步 1 ...

  3. Spring整合JUnit4进行AOP单元测试的时候,报:"C:\Program Files\Java\jdk1.8.0_191\bin\java.exe" -ea -Didea.test.cyclic.buffer.size=1048576 "-javaagent:C:\Program Files\JetBrains\IntelliJ IDEA 2018.3\lib\idea_rt.jar=64

    错误代码 "C:\Program Files\Java\jdk1.8.0_191\bin\java.exe" -ea -Didea.test.cyclic.buffer.size= ...

  4. (3.10)常用知识-T-SQL优化

    关键字:SQL优化 总结: 1.书写问题 2.表连接方式 3.索引的抉择 4.执行计划之参数嗅探,使用提示强制执行计划 5.子查询与表连接的效率 6.临时表.CTE.表变量的选择 7.常用sp与sel ...

  5. lamp环境的搭建和安装

    最近,部门有些系统需要迁移到新的机器上,因此需要在新的机器上安装lamp和lnmp的环境,因此在这里总结一下: 一. 安装lamp环境的步骤:  (1).因为是新的机器,因此需要安装gcc的各种环境: ...

  6. struct初始化

    C语言中struct初始化 • 普通结构体的初始化 假设我们有如下的一段代码,其中已有Student结构体,要求实例化一个Student对象并将其初始化. #include <stdio.h&g ...

  7. Mock Server 之 moco-runner 使用指南一

    文章出处http://ju.outofmemory.cn/entry/96866 用以下命令可以启动moco-runner 服务 java -jar moco-runner-<version&g ...

  8. 无线路由和无线AP的区别

    一.答疑解惑 1.什么是无线AP? AP:Access Point 接入点 无线AP:无线接入点是一个无线网络的接入点,俗称“热点”.主要有路由交换接入一体设备和纯接入点设备,一体设备执行接入和路由工 ...

  9. let 与 var

    前言let与var最大的区别就是var会变量提升.var会被覆盖.var变量没有块级作用域,而let都将弥补这些bug.传统语言都不会有‘变量提升.重复声明被覆盖.变量没有块级作用’这些问题,这是js ...

  10. [acm/icpc2016ChinaFinal][CodeforcesGym101194] Mr. Panda and Fantastic Beasts

    地址:http://codeforces.com/gym/101194 题目:略 思路: 这题做法挺多的,可以sam也可以后缀数组,我用sam做的. 1.我自己yy的思路(瞎bb的) 把第一个串建立s ...