参数初始化参

数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())

optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
    diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    pretrained_dict.update(diff)
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
    keys = []
    for k,v in pretrained_dict.items():
        keys.append(k)
    i = 0
    for k,v in model_dict.items():
        if v.size() == pretrained_dict[keys[i]].size():
            print(k, ',', keys[i])
            model_dict[k]=pretrained_dict[keys[i]]
        i = i + 1
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章

  1. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  2. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  3. 如何使用 opencv 加载 darknet yolo 预训练模型?

    如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...

  4. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  5. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  6. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  7. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  8. MindSpore保存与加载模型

    技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...

  9. NeHe OpenGL教程 第三十一课:加载模型

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

随机推荐

  1. mint-ui 关于有时候官网有时候打不开的问题

    难道就我只有这种情况,不知所措的,要不好看看是什么,看到别人说,是你路由的问题,换个路由的可以了,我觉得不是,那试试HTTP协议的问题 换个https试试就可以 这是什么原因呢? http是超文本传输 ...

  2. 【LCA+MST】BZOJ3732-Network

    [题目大意] 给你N个点的无向图 (1 <= N <= 15,000),记为:1…N.图中有M条边 (1<=M<=30,000) ,第j条边的长度:d_j (1<=d_j ...

  3. hdu 4445 37届金华赛区 D题

    题意:给一个坦克的高度,求炮弹能打中最多的数量 枚举角度,作为一名学霸虽然很快推出了公式,但是却没有考虑到,角度可以朝下的情况 #include<cstdio> #include<i ...

  4. hadoop 基础视频1

    hadoop 基础视频1 一, 大致内容: 1, 源起与体系结构2,实施Hadoop 集群3,分布式HDFS, 大数据存储实战4,Map-Reduce 体系架构5,Map-Reduce 数据分析之一 ...

  5. Azure虚机磁盘容量警报(邮件提醒)

    上周有个客户提出这样的需求:根据虚拟机磁盘的实际使用量,当达到某一阈值时设置邮件提醒. 在这个需求中我们只需要解决两点问题: 计算虚拟机磁盘实际使用量 发送邮件 使用VS新建一个名为Calculate ...

  6. UNICODE 区域对照表

    0000-007F Basic Latin 基本拉丁字母 0080-00FF Latin-1 Supplement 拉丁字母補充-1 0100-017F Latin Extended-A 拉丁字母擴充 ...

  7. android 控件: xml 设置 Button 按下背景

    本篇文章讲述了不使用java代码来改变 Button 按下和未按下时的背景. 首先准备两张图片, 分别是按钮按下和按钮未按下的. 在res/drawable 文件夹中创建一个button_select ...

  8. Android adb logcat使用技巧

    前言 新买的笔记本E431装了最新版的Eclipse,搞定了Android开发环境,可是logcat里查看东西居然仅仅显示level,没有错误的具体信息.我本身也不是一个愿意折腾图形界面,更喜欢纯命令 ...

  9. WebLogic使用总结(五)——Web项目使用Sigar在WebLogic服务器部署遇到的问题

    今天在WebLogic 12c服务器上部署Web项目时,碰到了一个问题.项目中使用到了"Sigar.jar"监控Window平台下的cpu使用率.内存使用率和硬盘信息,sigar. ...

  10. Android 数据存储04之Content Provider

    Content Provider 版本 修改内容 日期 修改人 V1.0 原始版本 2013/2/25 skywang 1 URI 通用资源标志符(Universal Resource Identif ...