pytorch 7 save_reload 保存和提取神经网络
import torch
import matplotlib.pyplot as plt
# torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
训练网络,并用两种方式保存网络
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100): # 训练网络
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 2 ways to save the net 保存网络的两种方式,第二种只保存参数更快一点
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
加载第一种(含所有信息的)网络:torch.load('net.pkl')
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
加载第二种(只含有参数的)网络:net3.load_state_dict(torch.load('net_params.pkl'))
def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
运行上面的函数
# save net1 训练网络,并用两种方式保存网络
save()
# restore entire net (may slow) 所有信息的网络
restore_net()
# restore only the net parameters 只含参数信息的网络,加载前需要重新构造与之前一模一样的网络
restore_params()
三个网络绘制的图片
END
pytorch 7 save_reload 保存和提取神经网络的更多相关文章
- pytorch1.0神经网络保存、提取、加载
pytorch1.0网络保存.提取.加载 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot ...
- Pytorch学习笔记(二)---- 神经网络搭建
记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...
- tensorflow学习之路----保存和提取数据
#保存数据注意他只能保存变量,不能保存神经网络的框架.#保存数据的作用:保存权重有利于下一次的训练,或者可以用这个数据进行识别#np.arange():arange函数用于创建等差数组,使用频率非常高 ...
- Keras(六)Autoencoder 自编码 原理及实例 Save&reload 模型的保存和提取
Autoencoder 自编码 压缩与解压 原来有时神经网络要接受大量的输入信息, 比如输入信息是高清图片时, 输入信息量可能达到上千万, 让神经网络直接从上千万个信息源中学习是一件很吃力的工作. 所 ...
- cookie的保存与提取
爬虫过程中,cookie可以保留用户与服务器之间的交互信息,使服务器与用户相互能够识别.由于HTTP协议是无状态协议,即不能够识别客户端身份,即使客户端多次请求同一个url服务器仍然响应.这种协议导致 ...
- Tensorflow学习教程------参数保存和提取重利用
#coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mni ...
- 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络
滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...
- pytorch实战(7)-----卷积神经网络
一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...
- delphi保存和提取ini文件信息
procedure TLoginForm.FormShow(Sender: TObject);var ini:TIniFile; name:string;begin //实现动态提取数据库的登录用户名 ...
随机推荐
- SwipeRefreshLayout与ViewPager滑动事件冲突解决
问题描写叙述: 开发中发现,SwipeRefreshLayout的下拉刷新,与ViewPager开发的banner的左右滑动事件有一点冲突,导致banner的左右滑动不够顺畅. 非常easy在bann ...
- Tom和Jerry来了,Tom和Jerry走了——北漂18年(38)
上次讲到跟我同一时候入职的女销售走了. 回忆起来,她的问题多半是技巧足够,脑子不足够,走了之后再没联系.不久之后,在老板的要求之下.LilyG又招聘了两位男销售,英文名字非常登对一个叫Tom,一个叫J ...
- nyist oj 115 城市平乱 (最短路径)
城市平乱 时间限制:1000 ms | 内存限制:65535 KB 难度:4 描写叙述 南将军统领着N个部队.这N个部队分别驻扎在N个不同的城市. 他在用这N个部队维护着M个城市的治安.这M个城市 ...
- 【Swift】学习笔记(六)——函数
函数 懂编程语言的来说这个是最主要的了,不论什么语言都有函数这个概念.函数就是完毕特定任务的独立代码块. 函数怎么创建: 1.创建一个无參无返回值的函数(实际上全部的函数都有返回值,这个函数返回vo ...
- XCODE插件 之 Code Pilot 无鼠标化
什么是Code Pilot? Code Pilot 是一个 Xcode 5 插件.同意你不许使用鼠标就能高速地查找项目内的文件.方法和标识符. 它使用模糊查询匹配(fuzzy query matchi ...
- TNS-12555 / TNS-12560 / TNS-00525 Error listening on: (DESCRIPTION=(ADDRESS=(PROTOCOL=IPC)(KEY=EXTPR
TNS-12555 / TNS-12560 / TNS-00525 Error listening on: (DESCRIPTION=(ADDRESS=(PROTOCOL=IPC)(KEY=EXTPR ...
- elasticsearch源码分析之search模块(server端)
elasticsearch源码分析之search模块(server端) 继续接着上一篇的来说啊,当client端将search的请求发送到某一个node之后,剩下的事情就是server端来处理了,具体 ...
- hdoj--1171--Number Sequence(KMP)
Number Sequence Time Limit: 10000/5000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others) ...
- Oracle动态性能表-V$SESSION_WAIT,V$SESSION_EVENT
(1)-V$SESSION_WAIT 这是一个寻找性能瓶颈的关键视图.它提供了任何情况下session在数据库中当前正在等待什么(如果session当前什么也没在做,则显示它最后的等待事件).当系统存 ...
- Java基础——GridBagLayout布局
1.GridBagLayout布局管理器非常灵活,每个 GridBagLayout 对象维持一个动态的矩形单元网格: 2.需要和它的约束类(GridBagConstraints类)一起使用: 3.Gr ...