莫烦PyTorch学习笔记(五)——模型的存取
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。
莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章
- 莫烦PyTorch学习笔记(五)——分类
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- 莫烦PyTorch学习笔记(六)——批处理
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
- 莫烦pytorch学习笔记(二)——variable
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
- 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )
CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- 莫烦PyTorch学习笔记(三)——激励函数
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
- 莫烦pytorch学习笔记(一)——torch or numpy
Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...
随机推荐
- Linux上 安装Sorl4.7 中间件用tomcat
最近需要用到solr,公司内部搭建了一个solr测试环境. 版本:solr4.7.2 ,tomcat 7.0.55 jdk:1.7_051 解压 solr 和tomcat 这里就不详说. 1.启动t ...
- Python中and_Or
自 http://www.cnblogs.com/BeginMan/p/3197123.html 一.and: 在Python 中,and 和 or 执行布尔逻辑演算,如你所期待的一样,但是它们并不返 ...
- 《转》python(7)列表
转自 http://www.cnblogs.com/BeginMan/p/3153842.html 一.序列类型操作符 1.切片[]和[:] 2.成员关系操作符(in ,not in ) 1: s1 ...
- MVC 传递数据 从前台到后台,包括单个对象,多个对象,集合
MVC 传递数据 从前台到后台,包括单个对象,多个对象,集合 1.基本数据类型 我们常见有传递 int, string, bool, double, decimal 等类型. 需要注意的是前台传递的参 ...
- wish - 简单的窗口式(windowing) shell
总览 wish [filename] [arg] [arg ...] 选项 -colormap new 指定窗口使用一个新的私有的调色板(colormap)而不使用给屏幕的缺省的调色板. -displ ...
- GetOpenFilename的基本用法
GetOpenFilename '一.概述基本语法 Application.GetOpenFilename 方法 显示标准的“打开”对话框,并获取用户文件名,而不必真正打开任何文件,只是把打开文件名称 ...
- 无法解析的外部符号 jpeg_std_error
1>dlib.lib(png_loader.obj) : error LNK2001: 无法解析的外部符号 png_set_sig_bytes 1>dlib.lib(png_loader. ...
- 解决vagrant上使用Homestead很慢(响应速度10s+)
说明: 使用vagrant和Homestead 在vBox上面跑laravel, 响应速度非常缓慢(大概在10+s), 尝试过增加虚拟机配置, 但是没有任何效果, 经验证也不是数据库的原因 . 通过网 ...
- [转]springmvc+mybatis需要的jar包与详解
1.antlr-2.7.6.jar: 项目中没有添加,hibernate不会执行hql语句 2.Aopalliance.jar: 这个包是AOP联盟的API包,里面包含了针对面向切面的接口,通常Sp ...
- [VS2008] Debug版本程序发布后 由于应用程序的配置不正确,应用程序未能启动,重新安装应用程序可能会纠正这个问题
转自VC错误:http://www.vcerror.com/?p=59 问题描述: [VS2008] 版本程序发布后,运行程序弹出错误框: 由于应用程序的配置不正确,应用程序未能启动,重新安装应用程序 ...