pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取。
假设网络为model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr), 假设在某个epoch,我们要保存模型参数,优化器参数以及epoch
一、
1. 先建立一个字典,保存三个参数:
state = {‘net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
2.调用torch.save():
torch.save(state, dir)
其中dir表示保存文件的绝对路径+保存文件名,如'/home/qinying/Desktop/modelpara.pth'
二、
当你想恢复某一阶段的训练(或者进行测试)时,那么就可以读取之前保存的网络模型参数等。
checkpoint = torch.load(dir)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件的更多相关文章
- PyTorch保存模型、冻结参数等
此外可以参考PyTorch模型保存.https://zhuanlan.zhihu.com/p/73893187 查看模型每层输出详情 Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- pytorch加载和保存模型
在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢? 方法一(推荐): 第一种方法也 ...
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- 使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题
最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却. 首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应 ...
随机推荐
- 前端之CSS盒模型介绍
css盒模型 css盒模型是css的基石,盒模型由content(主体内容),padding(补白,填充),border(边框),margin(外间距); 1.content: width:数值+单位 ...
- Flask应用运行流程
当我们运行项目后,Flask内部都经历了什么 1.app.run()启动项目,ctrl点进源码 app.py: 1)执行了run_simple() 2)注意第三个参数,这里是Flask实例化的对象,在 ...
- 查看Windows激活信息
使用 Windows + R组合快捷键打开运行命令框 1.运行: slmgr.vbs -dlv 可以查询到Win10的激活信息,包括:激活ID.安装ID.激活截止日期等信息. 2.运行: slmgr. ...
- (七)maven之阿里云镜像提高jar下载速度
阿里云国内镜像,提高jar包下载速度 镜像 maven默认会从中央仓库下载包,但是下载过几次就知道,下载速度非常慢.镜像就相当于是中央仓库的一个副本,内容和中央仓库完全一样,而且同时也能保证下载速度, ...
- c#写出乘法口诀
显然是显得无聊五分钟写的乘法口诀 static void Main(string[] args) { int dq; int[] array ...
- [BZOJ2938]病毒 (AC自动机+dfs)
题目描述 二进制病毒审查委员会最近发现了如下的规律:某些确定的二进制串是病毒的代码.如果某段代码中不存在任何一段病毒代码,那么我们就称这段代码是安全的.现在委员会已经找出了所有的病毒代码段,试问,是否 ...
- CAS (Compare and Swap)
synchronized是悲观锁 注意:实现了CAS的有原子类(AtomicInteger,AtomicLong,等等原子类) CAS 是乐观锁,一种高效实现线程安全性的方法 1.支持原子更新操作,适 ...
- shell脚本,通过一个shell程序计算n的阶乘。
[root@localhost ~]# cat jiechen.sh #!/bin/bash #设计一个shell程序计算n的阶乘,要求: #.从命令行接收参数n; #.在程序开始后立即判断n的合法性 ...
- CentOS 7 升级gcc/g++编译器
gcc的升级必须要使用源码进行升级,也就说,必须要使用源码进行编译才行.我的7.2的CentOS目前自带的gcc是4.8.5的,gcc从4.8之后开始支持C++11,但是鉴于现在C++14.C++17 ...
- Python爬虫环境常用库安装
1:urllib urllib.request这两个库是python自带的库,不需要重新安装,在python中输入如下代码: import urllibimport urllib.requestres ...