参数初始化参

数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())

optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
    diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    pretrained_dict.update(diff)
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
    keys = []
    for k,v in pretrained_dict.items():
        keys.append(k)
    i = 0
    for k,v in model_dict.items():
        if v.size() == pretrained_dict[keys[i]].size():
            print(k, ',', keys[i])
            model_dict[k]=pretrained_dict[keys[i]]
        i = i + 1
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章

  1. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  2. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  3. 如何使用 opencv 加载 darknet yolo 预训练模型?

    如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...

  4. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  5. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  6. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  7. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  8. MindSpore保存与加载模型

    技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...

  9. NeHe OpenGL教程 第三十一课:加载模型

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

随机推荐

  1. 堆排序之Java实现

    堆排序之Java实现 代码: package cn.com.zfc.lesson21.sort; /** * * @title HeapSort * @describe 堆排序 * @author 张 ...

  2. codevs 1464 装箱问题 2

    题目描述 Description 一个工厂制造的产品形状都是长方体,它们的高度都是h,长和宽都相等,一共有六个型号,他们的长宽分别为1*1, 2*2, 3*3, 4*4, 5*5, 6*6.这些产品通 ...

  3. Gunicorn部署部分的翻译

    部署Gunicorn 文档建议Gunicorn最好是用在代理服务器后面.(等于前面最好加一个反向代理) Nginx Configuration 文档建议用Nginx,当然用其他也可以,但是要确保当你用 ...

  4. 【原】配置MySQL服务器端的字符集

    [简述] 通过直接配置my.cnf方式修改mysql的字符集,这种方式并不复杂,但是,在linux端配置时,特别容易出错,因此,记录之,以待后用. [配置步骤描述]Step 1:关闭当前的MySQL服 ...

  5. [多问几个为什么]为什么匿名内部类中引用的局部变量和参数需要final而成员字段不用?(转)

    昨天有一个比较爱思考的同事和我提起一个问题:为什么匿名内部类使用的局部变量和参数需要final修饰,而外部类的成员变量则不用?对这个问题我一直作为默认的语法了,木有仔细想过为什么(在分析完后有点印象在 ...

  6. 您该选择PRINCE2 还是 PMP认证

    您该选择PRINCE2 还是 PMP认证? 很多人都问我,这是一个非常常见的问题,作为一个项目经理,他们是否应选择PRINCE2或PMP认证,因为这两个认证都是全世界非常流行的. PRINCE2 是一 ...

  7. .net core下的dotnet全局工具

    .net core 2.1后支持了一个全新的部署和扩展命令,可以自己注册全局命令行. dotnet install tool -g dotnetsaydotnetsay 也可以自己构建自己的命令行,一 ...

  8. 开源 java CMS - FreeCMS2.2 菜单管理

    项目地址:http://www.freeteam.cn/ 菜单管理 FreeCMS在设计时定位于面向二次开发友好,所以FreeCMS提供了菜单管理功能.二次开发者能够自由添加新的功能菜单到FreeCM ...

  9. Revit API遍历系统族布置喷头

    系统族可以通过内参遍历,遍历出来是个FamilySymbol喷头属于系统族,但不能通过NewDuct();类似这样的方法布置.必须使用 NewFamilyInstance() );           ...

  10. Redis源代码分析(三十三)--- redis-cli.cclient命令行接口的实现(2)

    今天学习完了命令行client的兴许内容,总体感觉就是环绕着2个东西转,config和mode.为什么我会这么说呢,请继续往下看,client中的配置结构体和之前我们所学习的配置结构体,不是指的同一个 ...