import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. 转载:jQuery 获取屏幕高度、宽度

    做手机Web开发做浏览器兼容用到了,所以在网上找了些汇总下. alert($(window).height()); //浏览器当前窗口可视区域高度 alert($(document).height() ...

  2. 为什么 TCP 建立连接是三次握手,关闭连接确是四次挥手呢?

    Java技术栈 www.javastack.cn 优秀的Java技术公众号 作者:小书go https://blog.csdn.net/qzcsu/article/details/72861891 背 ...

  3. mssql查询表在哪个数据库中

    mssql查询表在哪个数据库中 EXEC sp_MSforeachdb @command1='IF object_id(''?'' + ''..表名'') IS NOT NULL PRINT ''?' ...

  4. python学习3—数据类型之整型、字符串和布尔值

    python学习3-数据类型之整型.字符串和布尔值 数据类型 python3支持的数据类型共有6种: 1 Number 2 String 3 List 4 Tuple 5 Set 6 Dictiona ...

  5. 批量调用百度地图API获取地址经纬度坐标

    1 申请密匙 注册百度地图API:http://lbsyun.baidu.com/index.php?title=webapi 点击左侧 “获取密匙” ,经过填写个人信息.邮箱注册等,成功之后在开放平 ...

  6. Winform 窗体闪烁 & 任务栏提示

    准备: [DllImport("user32.dll")] static extern bool FlashWindowEx(ref FLASHWINFO pwfi); [DllI ...

  7. 原来腾迅的QQ号竟然是个int变量

    今天有个人加我好友,我一点开申请界面 我惊异了.... 我擦,号码竟然是个负数,但是人物资料里面却是个正数 有编程经验的人,一眼就看得出来原因.而且一眼就看得出来,它们是什么 1857918296 + ...

  8. DIV+CSS网页布局常用的一些基础知识

    CSS命名规范 一.文件命名规范 全局样式:global.css:框架布局:layout.css:字体样式:font.css:链接样式:link.css:打印样式:print.css: 二.常用类/I ...

  9. eclipse导出说明文档

    选中项目--右键--Export--Java--Javadoc—Finish 1.为程序添加文档注释 2.选中项目--右键Export--Java--Javadoc--next, 3.next--在V ...

  10. ios 中倒计时计算,时间戳为NaN

    // 倒计时 daojishi(params) { let _this = this; let datetemp = this.servertimes; let lasttime = Date.par ...