转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/8108466.html

参考网址:

http://pytorch.org/docs/master/notes/serialization.html

https://github.com/clcarwin/sphereface_pytorch

有两种方式保存和载入模型

1. 只保存和载入模型参数

保存:

torch.save(the_model.state_dict(), PATH)

载入:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

当model使用gpu训练时,可以将数据转换到cpu中,并保存(载入时,还是上面的方法。需要使用gpu时,加上.cuda()):

def save_model(model, filename):
state = model.state_dict()
for key in state: state[key] = state[key].clone().cpu()
torch.save(state, filename)

2. 保存和载入整个模型

保存:

torch.save(the_model, PATH)

载入:

the_model = torch.load(PATH)

However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

第二种方式,序列化后的数据使用特殊的结构,缺点就是当在其他工程中使用时,可能会碰到各种问题。

因而,官方更建议使用第一种方式。

(原+译)pytorch中保存和载入模型的更多相关文章

  1. (原)pytorch中使用TensorRT

    转载请注明出处: https://www.cnblogs.com/darkknightzh/p/11332155.html 代码网址: https://github.com/darkknightzh/ ...

  2. TensorFlow保存和载入模型

    首先定义一个tf.train.Saver类: saver = tf.train.Saver(max_to_keep=1) 其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后 ...

  3. (原)PyTorch中使用指定的GPU

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6836568.html PyTorch默认使用从0开始的GPU,如果GPU0正在运行程序,需要指定其他G ...

  4. pytorch中修改后的模型如何加载预训练模型

    问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...

  5. 第六节,TensorFlow编程基础案例-保存和恢复模型(中)

    在我们使用TensorFlow的时候,有时候需要训练一个比较复杂的网络,比如后面的AlexNet,ResNet,GoogleNet等等,由于训练这些网络花费的时间比较长,因此我们需要保存模型的参数. ...

  6. TensorFlow学习笔记:保存和读取模型

    TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变.今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题.Google 搜出来的 ...

  7. pytorch 中模型的保存与加载,增量训练

     让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...

  8. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  9. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

随机推荐

  1. PPPOE数据包转换及SharpPcap应用

    在最近写的一个程序中需要用到Sniffer功能,但由于通过.net自身的Socket做出来的Sniffer不能达到实际应用的要求(如不能监听WIFI数据包)所以找到了WinPCAP的.NET库Shar ...

  2. 用css3实现风车效果

    前面讲过css3可以替代很多js实现的效果,其实很多时候纯css3甚至可以替代图片,直接用css3就可以画出一些简单的图片.虽然css3画出来的图片效果可能不如直接用图片的好,实现起来也比较复杂,最麻 ...

  3. Datetimepicker配置参数

    jquery的datetimepicker时间控件除了样式有点不太美观,功能性还是相当强大的. 在正常情况下input的type应该设置为"text",可点击又可输入(mask,e ...

  4. 大数据开发实战:Stream SQL实时开发一

    1.流计算SQL原理和架构 流计算SQL通常是一个类SQL的声明式语言,主要用于对流式数据(Streams)的持续性查询,目的是在常见流计算平台和框架(如Storm.Spark Streaming.F ...

  5. javascript学习笔记——怎样改动<a href="#">url name</a>

    0.前言     使用了一段时间javascript,再花了点时间学习了jquery.可是总是感觉自己非常"迷糊",比如<a href="#">ur ...

  6. Python 各种测试框架简介

    转载:https://blog.csdn.net/yockie/article/details/47415265 一.doctest doctest 是一个 Python 发行版自带的标准模块.本篇将 ...

  7. mysql zerofill 的使用

    转自:http://www.jquerycn.cn/blog/mysql/ 那这个int[M]中M是什么意义喃,在定义数值型数据类型的时候,可以在关键字括号内指定整数值(如:int(M),M的最大值为 ...

  8. Solidworks公司电脑图纸被加密之后如何解密输出

    第一步:打开总装配的组件(该组件需要包含你所有需要的零件),比如打开其中一个:   第二步:Solidworks的菜单中依次:"文件"→"打包"(有的版本是pa ...

  9. asp.net集合类

    1.返回IEnumerable类型 protected void Page_Load(object sender, EventArgs e) { IEnumerable ie = AllGet(); ...

  10. PAT 1065 1066 1067 1068

    pat 1065 A+B and C                                          主要是注意一下加法溢出的情况,不要试图使用double,因为它的精度是15~16 ...