使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口。在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等。

模型保存/加载

1.所有模型参数

训练过程中,有时候会由于各种原因停止训练,这时候我们训练过程中就需要注意将每一轮epoch的模型保存(一般保存最好模型与当前轮模型)。一般使用pytorch里面推荐的保存方法。该方法保存的是模型的参数。

#保存模型到checkpoint.pth.tar
torch.save(model.module.state_dict(), ‘checkpoint.pth.tar’)

对应的加载模型方法为(这种方法需要先反序列化模型获取参数字典,因此必须先load模型,再load_state_dict):

mymodel.load_state_dict(torch.load(‘checkpoint.pth.tar’))

有了上面的保存后,现以一个例子说明如何在inference AND/OR resume train使用。

#保存模型的状态,可以设置一些参数,后续可以使用
state = {'epoch': epoch + 1,#保存的当前轮数
'state_dict': mymodel.state_dict(),#训练好的参数
'optimizer': optimizer.state_dict(),#优化器参数,为了后续的resume
'best_pred': best_pred#当前最好的精度
,....,...} #保存模型到checkpoint.pth.tar
torch.save(state, ‘checkpoint.pth.tar’)
#如果是best,则复制过去
if is_best:
shutil.copyfile(filename, directory + 'model_best.pth.tar') checkpoint = torch.load('model_best.pth.tar')
model.load_state_dict(checkpoint['state_dict'])#模型参数
optimizer.load_state_dict(checkpoint['optimizer'])#优化参数
epoch = checkpoint['epoch']#epoch,可以用于更新学习率等 #有了以上的东西,就可以继续重新训练了,也就不需要担心停止程序重新训练。
train/eval
....
....

上面是pytorch建议使用的方法,当然还有第二种方法。这种方法灵活性不高,不推荐。

#保存
torch.save(mymodel,‘checkpoint.pth.tar’) #加载
mymodel = torch.load(‘checkpoint.pth.tar’)

2.部分模型参数

在很多时候,我们加载的是已经训练好的模型,而训练好的模型可能与我们定义的模型不完全一样,而我们只想使用一样的那些层的参数。

有几种解决方法:

(1)直接在训练好的模型开始搭建自己的模型,就是先加载训练好的模型,然后再它基础上定义自己的模型;

model_ft = models.resnet18(pretrained=use_pretrained)
self.conv1 = model_ft.conv1
self.bn = model_ft.bn
... ...

(2) 自己定义好模型,直接加载模型

#第一种方法:
mymodelB = TheModelBClass(*args, **kwargs)
# strict=False,设置为false,只保留键值相同的参数
mymodelB.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) #第二种方法:
# 加载模型
model_pretrained = models.resnet18(pretrained=use_pretrained) # mymodel's state_dict,
# 如: conv1.weight
# conv1.bias
mymodelB_dict = mymodelB.state_dict() # 将model_pretrained的建与自定义模型的建进行比较,剔除不同的
pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict}
# 更新现有的model_dict
mymodelB_dict.update(pretrained_dict) # 加载我们真正需要的state_dict
mymodelB.load_state_dict(mymodelB_dict) # 方法2可能更直观一些

参数初始化

第二个问题是参数初始化问题,在很多代码里面都会使用到,毕竟不是所有的都是有预训练参数。这时就需要对不是与预训练参数进行初始化。pytorch里面的每个Tensor其实是对Variabl的封装,其包含data、grad等接口,因此可以用这些接口直接赋值。这里也提供了怎样把其他框架(caffe/tensorflow/mxnet/gluonCV等)训练好的模型参数直接赋值给pytorch.其实就是对data直接赋值。

pytorch提供了初始化参数的方法:

 def weight_init(m):
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0,math.sqrt(2./n))
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

但一般如果没有很大需求初始化参数,也没有问题(不确定性能是否有影响的情况下),pytorch内部是有默认初始化参数的。

Fintune

最后就是精调了,我们平时做实验,至少backbone是用预训练的模型,将其用作特征提取器,或者在它上面做精调。

用于特征提取的时候,要求特征提取部分参数不进行学习,而pytorch提供了requires_grad参数用于确定是否进去梯度计算,也即是否更新参数。以下以minist为例,用resnet18作特征提取:

#加载预训练模型
model = torchvision.models.resnet18(pretrained=True) #遍历每一个参数,将其设置为不更新参数,即不学习
for param in model.parameters():
param.requires_grad = False # 将全连接层改为mnist所需的10类,注意:这样更改后requires_grad默认为True
model.fc = nn.Linear(512, 10) # 优化
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

用于全局精调时,我们一般对不同的层需要设置不同的学习率,预训练的层学习率小一点,其他层大一点。这要怎么做呢?

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 10) # 参考:https://blog.csdn.net/u012759136/article/details/65634477
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters()) # 对不同参数设置不同的学习率
params_list = [{'params': base_params, 'lr': 0.001},]
params_list.append({'params': model.fc.parameters(), 'lr': 0.01}) optimizer = torch.optim.SGD(params_list,
0.001,
momentum=args.momentum,
weight_decay=args.weight_decay)

最后整理一下目前,pytorch预训练的基础模型:

(1)torchvision

torchvision里面已经提供了不同的预训练模型,一般也够用了。

pytorch/visiongithub.com

包含了alexnet/densenet各种版本(densenet121/densenet169/densenet201/densenet161)/inception_v3/resnet各种版本(resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152')/SqueezeNet各种版本( 'squeezenet1_0', 'squeezenet1_1')/VGG各种版本( 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19')

(2)其他预训练好的模型,如,SENet/NASNet等。

Cadene/pretrained-models.pytorchgithub.com

(3)gluonCV转pytorch的模型,包括,分类网络,分割网络等,这里的精度均比其他框架高几个百分点。

zhanghang1989/gluoncv-torchgithub.com

PyTorch模型读写、参数初始化、Finetune的更多相关文章

  1. [PyTorch]PyTorch中模型的参数初始化的几种方法(转)

    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本文目录 1. xavier初始化 2. kaiming初始化 3. 实际使用中看到的初始化 3.1 ResNeXt,de ...

  2. Pytorch基础(6)----参数初始化

    一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...

  3. pytorch对模型参数初始化

    1.使用apply() 举例说明: Encoder :设计的编码其模型 weights_init(): 用来初始化模型 model.apply():实现初始化 # coding:utf- from t ...

  4. pytorch和tensorflow的爱恨情仇之参数初始化

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...

  5. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

  6. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

  7. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  8. ubuntu之路——day15.1 只用python的numpy在底层检验参数初始化对模型的影响

    首先感谢这位博主整理的Andrew Ng的deeplearning.ai的相关作业:https://blog.csdn.net/u013733326/article/details/79827273 ...

  9. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

随机推荐

  1. 演练:创建和使用自己的动态链接库 (C++)

    此分布演练演示如何使用 Visual Studio IDE 通过 Microsoft C++ (MSVC) 编写自己的动态链接库 (DLL). 然后,该演练演示如何从其他 C++ 应用中使用 DLL. ...

  2. js对url进行编码和解码

    编码 只有 0-9[a-Z] $ - _ . + ! * ' ( ) , 以及某些保留字,才能不经过编码直接用于 URL. 例如:搜索的中文关键字,复制网址之后再粘贴就会发现该URL已经被转码. 1. ...

  3. 刷题-力扣-518. 零钱兑换 II

    518. 零钱兑换 II 题目链接 来源:力扣(LeetCode) 链接:https://leetcode-cn.com/problems/coin-change-2/ 著作权归领扣网络所有.商业转载 ...

  4. Jmeter HTML 报告、Jenkins 配置

    目录 Jmeter 生成 HTML 测试报告 Jenkins 配置 Jmeter 生成 HTML 测试报告 JMeter 支持生成 HTML 测试报告, 以便从测试计划中获得图表和统计信息. 以上定义 ...

  5. reids在linux上的安装《四》

    linux 安装redis 完整步骤 红色字体在我的Centos上没有设置,因为我设置了密码 安装: 1.获取redis资源 wget http://download.redis.io/release ...

  6. promise加载图片

    实现一个图片的加载:设置第一张图片加载1s之后加载第二张图片: <!DOCTYPE html> <html> <head> <meta charset=&qu ...

  7. 网络协议之TCP和UDP

    TCP/IP协议: 传输控制协议/因特网互联协议( Transmission Control Protocol/Internet Protocol),是Internet最基本.最广泛的协议.它定义了计 ...

  8. 《通过刷leetcode学习Go语言》之(1):序言

    Author       : Email         : vip_13031075266@163.com Date          : 2021.03.07 Version     : 北京 C ...

  9. 远程桌面无法复制粘贴 rdpclip.exe

    在一些意外情况下,远程桌面无法与桌面共享复制内容,这时候需要杀掉一个进程并重新启动 远程桌面复制之后,无法在本地桌面粘贴   在远程桌面中右键点击,选择启动任务管理器   找到一个进行rdpclip. ...

  10. 使用私有gitlab发布自动生成版本号和标签(version和tag)(骚)

    设置 semantic ,自动生成版本号和标签 FROM node:14-buster-slim LABEL maintainer="wangyunpeng" COPY sourc ...