本章代码:

这篇文章主要介绍了序列化与反序列化,以及 PyTorch 中的模型保存于加载的两种方式,模型的断点续训练。

序列化与反序列化

模型在内存中是以对象的逻辑结构保存的,但是在硬盘中是以二进制流的方式保存的。

  • 序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。

  • 反序列化是指将硬盘中的二进制序列加载到内存中,得到模型的对象。PyTorch 的模型加载就是反序列化。

PyTorch 中的模型保存与加载

torch.save

torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)

主要参数:

  • obj:保存的对象,可以是模型。也可以是 dict。因为一般在保存模型时,不仅要保存模型,还需要保存优化器、此时对应的 epoch 等参数。这时就可以用 dict 包装起来。
  • f:输出路径

其中模型保存还有两种方式:

保存整个 Module

这种方法比较耗时,保存的文件大

torch.savev(net, path)

只保存模型的参数

推荐这种方法,运行比较快,保存的文件比较小

state_sict = net.state_dict()
torch.savev(state_sict, path)

下面是保存 LeNet 的例子。在网络初始化中,把权值都设置为 2020,然后保存模型。

import torch
import numpy as np
import torch.nn as nn
from common_tools import set_seed class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
) def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x def initialize(self):
for p in self.parameters():
p.data.fill_(2020) net = LeNet2(classes=2019) # "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...]) path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl" # 保存整个模型
torch.save(net, path_model) # 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

运行完之后,文件夹中生成了``model.pklmodel_state_dict.pkl`,分别保存了整个网络和网络的参数

torch.load

torch.load(f, map_location=None, pickle_module, **pickle_load_args)

主要参数:

  • f:文件路径
  • map_location:指定存在 CPU 或者 GPU。

加载模型也有两种方式

加载整个 Module

如果保存的时候,保存的是整个模型,那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象,也不用知道模型的结构,代码如下:

path_model = "./model.pkl"
net_load = torch.load(path_model) print(net_load)

输出如下:

LeNet2(
(features): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=2019, bias=True)
)
)

只加载模型的参数

如果保存的时候,保存的是模型的参数,那么加载时就参数。这种方法需要事先创建一个模型对象,再使用模型的load_state_dict()方法把参数加载到模型中,代码如下:

path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net_new = LeNet2(classes=2019) print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])

模型的断点续训练

在训练过程中,可能由于某种意外原因如断点等导致训练终止,这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存模型的参数和优化器的参数,这样如果意外终止训练了,下次就可以重新加载最新的模型参数和优化器的参数,在这个基础上继续训练。

下面的代码中,每隔 5 个 epoch 就保存一次,保存的是一个 dict,包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时,就break模拟训练意外终止。关键代码如下:

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)

在 epoch 大于 5 时,就break模拟训练意外终止

    if epoch > 5:
print("训练意外中断...")
break

断点续训练的恢复代码如下:

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint) net.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] scheduler.last_epoch = start_epoch

需要注意的是,还要设置scheduler.last_epoch参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。

参考资料

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

[PyTorch 学习笔记] 7.1 模型保存与加载的更多相关文章

  1. 驱动开发学习笔记. 0.07 Uboot链接地址 加载地址 和 链接脚本地址

    驱动开发学习笔记. 0.07 Uboot链接地址 加载地址 和 链接脚本地址 最近重新看了乾龙_Heron的<ARM 上电启动及 Uboot 代码分析>(下简称<代码分析>) ...

  2. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  3. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  4. sklearn模型保存与加载

    sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...

  5. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

  6. Tensorflow模型保存与加载

    在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...

  7. 转 tensorflow模型保存 与 加载

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

  8. TensorFlow的模型保存与加载

    import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir=&qu ...

  9. Entity Framework学习笔记(五)----Linq查询(2)---贪婪加载

    请注明转载地址:http://www.cnblogs.com/arhat 在上一章中,我们使用了Linq对Entity Framework进行了一个查询,但是通过学习我们却发现了懒加载给我来的性能上的 ...

随机推荐

  1. itest(爱测试) 开源一站式敏捷测试管理平台&极简项目管理,重大升级(接口测试)6.0.0 发布

    itest 简介 itest 开源敏捷测试管理,testOps 践行者,极简的任务管理,测试管理,缺陷管理,测试环境管理,接口测试5合1,又有丰富的统计分析.可按测试包分配测试用例执行,也可建测试迭代 ...

  2. C#/VB.NET 比较两个Word文档差异

    本文以C#和VB.NET代码为例,来介绍如何对比两个Word文档差异.程序中使用最新版的Spire.Doc for .NET 版本8.8.2.编辑代码前,先在VS程序中添加引用Spire.Doc.dl ...

  3. JavaScript 用七种方式教你判断一个变量是否为数组类型

    JavaScript 如何判断一个变量是否为数组类型 引言 正文 方法一 方法二 方法三 方法四 方法五 方法六 方法七 结束语 引言 我们如何判断一个变量是否为数组类型呢? 今天来给大家介绍七种方式 ...

  4. puppeteer去掉同源策略及请求拦截

    puppeteer是一个功能强大的工具,在自动化测试和爬虫方面应用广泛,这里谈一下如何在puppeteer中关掉同源策略和进行请求拦截. 同源策略 同源策略为web 安全提供了有力的保障,但是有时候我 ...

  5. js的事件循环和任务队列

    js 异步.栈.事件循环.任务队列 在开发中经常遇到js的异步问题,为了方便理解,记录下来,随时回顾. 以下的所有代码都是在浏览器环境下运行 在浏览器中js的运行是依赖浏览器js引擎来解析的,并且是在 ...

  6. 12c RAC 用Rman 恢复到异机单实例

    准备工作 原服务器软件部署:Redhat 6.6 + Oracle 12.2.0.1 rac Oracle12c单实例安装 1.创建恢复服务器,设置大于原库数据大小的磁盘容量.设置相同的服务器主机名参 ...

  7. src rpm 下载地址

    drbd: http://mirror.rackspace.com/elrepo/elrepo/el7/SRPMS/ rabbitmq: https://dl.bintray.com/rabbitmq ...

  8. linux tmpfs及消耗内存脚本

    一.tmpfs介绍 tmpfs是一种虚拟内存文件系统,正如这个定义它最大的特点就是它的存储空间在VM里面VM是由linux内核里面的vm子系统管理的东西,现在大多数操作系统都采用了虚拟内存管理机制VM ...

  9. 兼容低版本IE浏览器的一些心得体会(持续更新)

    前言: 近期工作中,突然被要求改别人的代码,其中有一项就是兼容IE低版本浏览器,所以优雅降级吧. 我相信兼容低版本IE是许多前端开发的噩梦,尤其是改别人写的代码,更是痛不欲生. 本文将介绍一些本人兼容 ...

  10. seo如何发外链

    http://www.wocaoseo.com/thread-228-1-1.html 在做外链方面博主并没有什么太多的经验,做为一位seo,下面武汉seo把自己做外链的大条列出来,都是经过本身实践并 ...