pytorch搭建网络,保存参数,恢复参数
这是看过莫凡python的学习笔记。
搭建网络,两种方式
(1)建立Sequential对象
import torch
net = torch.nn.Sequential(
torch.nn.Linear(2,10),
torch.nn.ReLU(),
torch.nn.Linear(10,2))
输出网络结构
Sequential(
(0): Linear(in_features=2, out_features=10, bias=True)
(1): ReLU()
(2): Linear(in_features=10, out_features=2, bias=True)
)
(2)建立网络类,继承torch.nn.module
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.hidden = torch.nn.Linear(2,10)
self.predict = torch.nn.Linear(10,2)
def forward(self,x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
输出和上面基本一样,略微不同
Net(
(hidden): Linear(in_features=2, out_features=10, bias=True)
(predict): Linear(in_features=10, out_features=2, bias=True)
)
保存模型,两种方式
(1)保存整个网络,及网络参数
torch.save(net,'net.pkl')
(2)只保存网络参数
torch.save(net.state_dict(),'net_params.pkl')
恢复模型,两种方式
(1)加载整个网络,及参数
net2 = torch.load('net.pkl')
(2)加载参数,但需实现网络
net3 = torch.nn.Sequential(
torch.nn.Linear(2,10),
torch.nn.ReLU(),
torch.nn.Linear(10,2))
net3.load_state_dict(torch.load('net_params.pkl'))
pytorch搭建网络,保存参数,恢复参数的更多相关文章
- 一文弄懂pytorch搭建网络流程+多分类评价指标
讲在前面,本来想通过一个简单的多层感知机实验一下不同的优化方法的,结果写着写着就先研究起评价指标来了,之前也写过一篇:https://www.cnblogs.com/xiximayou/p/13700 ...
- TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数
模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...
- TensorFlow 训练好模型参数的保存和恢复代码
TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...
- Pytorch从0开始实现YOLO V3指南 part2——搭建网络结构层
本节翻译自:https://blog.paperspace.com/how-to-implement-a-yolo-v3-object-detector-from-scratch-in-pytorch ...
- pytorch基础-搭建网络
搭建网络的步骤大致为以下: 1.准备数据 2. 定义网络结构model 3. 定义损失函数4. 定义优化算法 optimizer5. 训练 5.1 准备好tensor形式的输入数据和标签(可选) 5. ...
- pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用
retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...
- caffe-windows之网络描述文件和参数配置文件注释(mnist例程)
caffe-windows之网络描述文件和参数配置文件注释(mnist例程) lenet_solver.prototxt:在训练和测试时涉及到一些参数配置,训练超参数文件 <-----lenet ...
- react native 网络get请求方式参数不可为undefined或null
react native 网络get请求方式参数不可为undefined(为空的话默认变为)或null 错误写法: export function addToCartAction(isRefreshi ...
- loadrunner 脚本开发-参数化之将内容保存为参数、参数数组及参数值获取Part 2
脚本开发-参数化之将内容保存为参数.参数数组及参数值获取 by:授客 QQ:1033553122 ----------------接 Part 1--------------- 把内容保存到参数数组 ...
随机推荐
- div的作用
<div></div>主要是用来设置涵盖一个区块为主,所谓的区块是包含一行以上的数据,所以在<div></div>的开始之前与结束后,浏览都会自动换行, ...
- #pragma execution_character_set("utf-8")
VC2010增加了“#pragma execution_character_set("utf-8")”,指示char的执行字符集是UTF-8编码. VS2010 设置 字符编码: ...
- 转载-你应该知道的 RPC 原理
在校期间大家都写过不少程序,比如写个hello world服务类,然后本地调用下,如下所示.这些程序的特点是服务消费方和服务提供方是本地调用关系. 而一旦踏入公司尤其是大型互联网公司就会发现,公司的系 ...
- 探究QA职能
测试人员一般是被外界普遍认为是QC,即对产品的质量进行检测,找出质量问题并配合相关人员解决问题,从而管控产品质量,说通俗点就是帮开发找漏洞,给开发擦屁股:如果线上出现bug,就是你没有测试完整,最累的 ...
- 基于C++求两个数的最大公约数最小公倍数
求x,y最大公约数的函数如下: int gys(int x,int y) { int temp; while(x) {temp=x; x=y%x; y=temp;} return y; } x=y的时 ...
- 最短路dijkstra堆优化
demo: #include<bits/stdc++.h> #define max_v 102000 #define inf 0x3f3f3f3f using namespace std; ...
- java格式化数字、货币、金钱
网上摘来的,以后可能会用到 java开发中经常会有数字.货币金钱等格式化需求,货币保留几位小数,货币前端需要加上货币符号等.可以用java.text.NumberFormat和java.text.De ...
- C++输出斐波那契数列的几种方法
定义: 斐波那契数列指的是这样一个数列:0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, ... 这个数列从第三项开始,每一项都等于前两项之和. 以输出斐波那 ...
- ZROI2018提高day1t3
传送门 分析 考场上想到了先枚举p的长度,在枚举这个长度的所有子串,期望得分40~50pts,但是最终只得了20pts,这是因为我写的代码在验证中总是不断删除s'中的第一个p,而这种方式不能解决形如a ...
- 1020C Elections
传送门 题目大意 现在有 n个人,m个党派,第i个人开始想把票投给党派pi,而如果想让他改变他的想法需要花费ci元.你现在是党派1,问你最少花多少钱使得你的党派得票数大于其它任意党派. 分析 我们枚举 ...