1. 保存整个网络

torch.save(net, PATH)
model = torch.load(PATH)

2. 保存网络中的参数(速度快,占空间小)

torch.save(net.state_dict(),PATH)
model_dict = model.load_state_dict(torch.load(PATH))

model.state_dict函数会以有序字典OrderedDict形式返回模型训练过程中学习的权重weight和偏置bias参数,只有带有可学习参数的层(卷积层、全连接层等),以及注册的缓存(batchnorm的运行平均值)在state_dict 中才有记录。以下面的LeNet为例:

import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)
x = self.pool1(x) # output(16, 14, 14)
x = F.relu(self.conv2(x)) # output(32, 10, 10)
x = self.pool2(x) # output(32, 5, 5)
x = x.view(-1, 32 * 5 * 5) # output(32*5*5)
x = F.relu(self.fc1(x)) # output(120)
x = F.relu(self.fc2(x)) # output(84)
x = self.fc3(x) # output(10)
return x net = LeNet()
# 打印可学习层的参数
print(net.state_dict().keys())

上面的模型中,只有卷积层和全连接层具有可学习参数,所以net.state_dict()只会保存这两层的参数,而激活函数层的参数则不会保存。层的名字是上面实例化时确定的,如果是利用nn.Sequential定义多个层时,用层的位置索引表示每个层,如下所示:

示例:用nn.Sequential搭建模型时的state_dict

import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.feature = nn.Sequential(
nn.Conv2d(3, 16, 5),
nn.MaxPool2d(2, 2),
nn.Conv2d(16, 32, 5),
nn.MaxPool2d(2, 2)) self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.feature(x) # input(3, 32, 32)
x = x.view(-1, 32 * 5 * 5) # output(32*5*5)
x = F.relu(self.fc1(x)) # output(120)
x = F.relu(self.fc2(x)) # output(84)
x = self.fc3(x) # output(10)
return x net = LeNet()
# 打印可学习层的参数
print(net.state_dict().keys()) 

★模型加载

  • 当我们对网络模型结构进行优化改进时,如果改进的部分不包含可学习的层,那么可以直接加载预训练权重。如:如果我们对上述lenet模型进行改进,将激活函数层改为nn.Hardswish(),因为不包含可学习的参数,所以改进的模型的state_dict()没有改变,仍然可以直接加载lenet模型的权重文件。
  • 当我们改进的部分改变了可学习的参数时,如果直接加载预训练权重就会发生不匹配的错误,比如:卷积的维度改变后会报错 size mismatch for conv.weight...(2)新增一些层后会出现 Unexpected key(s) in state_dict等

解决方案:遍历预训练文件的每一层参数,将能够匹配成功的参数提取出来,再进行加载。

import torch
import torch.nn as nn
import torch.nn.functional as F class LeNet_new(nn.Module):
def __init__(self):
super(LeNet_new, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2) def forward(self, x):
x = F.hardswish(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)
x = self.pool1(x) # output(16, 14, 14)
x = F.hardswish(self.conv2(x)) # output(32, 10, 10)
x = self.pool2(x) # output(32, 5, 5)
return x def intersect_dicts(da, db):
return {k: v for k, v in da.items() if k in db and v.shape == db[k].shape} net = LeNet_new()
state_dict = torch.load("Lenet.pth") # 加载预训练权重
print(state_dict.keys())
state_dict = intersect_dicts(state_dict, net.state_dict()) # 筛选权重参数
print(state_dict.keys())
net.load_state_dict(state_dict, strict=False) # 模型加载预训练权重中可用的权重

3. 保存网络参数,同时保存优化器参数、损失值等(方便追加训练)

如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来

# 保存
save_file = {"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args}
torch.save(save_file, "save_weights/model_{}.pth".format(epoch)) # 加载
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1

4. 冻结训练

在加载预训练权重后,可能需要固定一部分模型的参数,只更新另一部分参数。有两种思路实现这个目标,一个是设置不要更新参数的网络层为requires_grad = False,另一个就是在定义优化器时只传入要更新的参数。最优写法时:将不更新的参数的requires_grad设置为False,同时不将该参数传入optimizer

示例:LeNet网络+MNIST手写识别+预训练模型加载+冻结训练

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST(root='../dataset', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_data = datasets.MNIST(root='../dataset', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False) class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.feature = nn.Sequential(
nn.Conv2d(1, 16, 5),
nn.MaxPool2d(2, 2),
nn.Conv2d(16, 32, 5),
nn.MaxPool2d(2, 2))
self.fc1 = nn.Linear(32 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.feature(x)
x = x.view(-1, 32 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def train(epoch):
loss_runtime = 0.0
for batch, data in enumerate(tqdm(train_loader, 0)):
x, y = data
x = x.to(device)
y = y.to(device)
y_pred = model(x)
loss = criterion(y_pred, y)
loss_runtime += loss.item()
loss_runtime /= x.size(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("after %s epochs, loss is %.8f" % (epoch + 1, loss_runtime))
save_file = {"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch}
torch.save(save_file, "model_{}.pth".format(epoch)) def test():
correct, total = 0, 0
with torch.no_grad():
for (x, y) in test_loader:
x = x.to(device)
y = y.to(device)
y_pred = model(x)
_, prediction = torch.max(y_pred.data, dim=1)
correct += (prediction == y).sum().item()
total += y.size(0)
acc = correct / total
print("accuracy on test set is :%5f" % acc) if __name__ == '__main__':
start_epoch = 0
freeze_epoch = 0
resume = "model_5.pth"
freeze = True model = LeNet()
device = ("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # 加载预训练权重
if resume:
checkpoint = torch.load(resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] # 冻结训练
if freeze:
freeze_epoch = 5
print("冻结前置特征提取网络权重,训练后面的全连接层")
for param in model.feature.parameters():
param.requires_grad = False # 将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.5)
for epoch in range(start_epoch, start_epoch + freeze_epoch):
train(epoch)
test()
print("解冻前置特征提取网络权重,接着训练整个网络权重")
for param in model.feature.parameters():
param.requires_grad = True
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.5) for epoch in range(start_epoch + freeze_epoch, 100):
train(epoch)
test()

  

参考:

1.加载预训练权重

模型权重保存、加载、冻结(pytorch)的更多相关文章

  1. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  2. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  3. tensorflow模型持久化保存和加载

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  4. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  5. docker 保存 加载(导入 导出镜像

    tensorflow 的docker镜像很大,pull一次由于墙经常失败.其实docker 可以将镜像导出再导入. 保存加载(tensorflow)镜像 1) 查看镜像 docker images 如 ...

  6. pytorch GPU训练好的模型使用CPU加载

    torch.load('tensors.pt') # 把所有的张量加载到CPU中 torch.load('tensors.pt', map_location=lambda storage, loc: ...

  7. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  8. 基于FBX SDK的FBX模型解析与加载 -(三)

    http://blog.csdn.net/bugrunner/article/details/7229416 6. 加载Camera和Light 在FBX模型中除了几何数据外较为常用的信息可能就是Ca ...

  9. 基于FBX SDK的FBX模型解析与加载 -(二)

    http://blog.csdn.net/bugrunner/article/details/7211515 5. 加载材质 Material是一个模型渲染时必不可少的部分,当然,这些信息也被存到了F ...

  10. Python之模型的保存和加载-5.3

    一.模型的保存,主要是我们在训练完成的时候把训练下来的数据保存下来,这个也就是我们后续需要使用的模型算法.模型的加载,在保存好的模型上面我们通过原生保存好的模型,去计算新的数据,这样不用每次都要去训练 ...

随机推荐

  1. [C++基础入门] 7、 指针

    文章目录 7 指针 7.1 指针的基本概念 7.2 指针变量的定义和使用 7.3 指针所占内存空间 7.4 空指针和野指针 7.5 const修饰指针 7.6 指针和数组 7.7 指针和函数 7.8 ...

  2. 如何在SpringBoot项目中兼容Jersey和SpringMVC框架?

    文章目录 Jersey框架介绍 常用的注解: SpringBoot中SpringMVC兼容Jersey 整合Jersey REST(Representational State Transfer)表象 ...

  3. windows10下编译32位和64位webrtc(m77)静态库

    1. windows10下编译32位和64位webrtc(m77)静态库 省略挂代理下载depot_tools以及webrtc代码的过程... 可参考webrtc编译 务必在 cmd 终端环境下进入到 ...

  4. JavaScript 如何判断一个对象中是否有某个属性?

    今天讲讲,JavaScript 如何判断一个对象中是否有某个属性? 我总结了5个方法: 方法1: if(Obj[a]) {} 缺点:对于参数值为 undefined 和 0 的无效. 方法2: if( ...

  5. golang基础面试题,不完整

    启动流程 Q.go的init函数是什么时候执行的? Q.多个init函数执行顺序能保证吗? Q.go init 的执行顺序,注意是不按导入规则的(这里是编译时按文件名的顺序执行的) Q.init函数能 ...

  6. 2021-10-14:被围绕的区域。给你一个 m x n 的矩阵 board ,由若干字符 ‘X‘ 和 ‘O‘ ,找到所有被 ‘X‘ 围绕的区域,并将这些区域里所有的 ‘O‘ 用 ‘X‘ 填充。力扣1

    2021-10-14:被围绕的区域.给你一个 m x n 的矩阵 board ,由若干字符 'X' 和 'O' ,找到所有被 'X' 围绕的区域,并将这些区域里所有的 'O' 用 'X' 填充.力扣1 ...

  7. linux 账户和权限

    目录 一.用户账户管理 二.组账号管理 三.用户账户文件和组账户文件 四.查询账户命令 五.设置目录与文件权限 六.设置命令与文件归属 七.默认文件属性umask 八.修改主机名 一.用户账户管理 u ...

  8. HTML渲染机制

    一直写页面但是很少对一些较深的运行机制的了解,这次趁休假查了一些相关的资料加上个人理解,记录一下关于html渲染的整个过程,也加深一下自己对html渲染的理解 一.先借一张图来看看html的整个加载过 ...

  9. ODOO升级模块后到系统进入不了,报错500

    有时候安装后者升级odoo相关模块后会导致系统进入不了,报错500,此时我们可以通过Odoo命令行卸载相关模块 此方法适用于在安装或升级某个模块后导致崩库,进不去桌面的情况下使用.原理是通过odoo- ...

  10. Python自学指南-第一章-安装运行

    1.1 [环境]快速安装 Python 与PyCharm "工欲善其事,必先利其器",为了自学之路的顺利顺利进行.首先需要搭建项目的开发环境. 1. 下载解释器 进入 Python ...