import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. ActiveMQ任意文件写入漏洞(CVE-2016-3088)

    上传webshell 容器用vulhub的 PUT一个jsp文件 MOVE到api目录 默认的ActiveMQ账号密码均为admin,首先访问http://your-ip:8161/admin/tes ...

  2. nutch2.2.1+mysql抓取数据

    基本环境:linux centos6.5 nutch2.2.1 源码包, mysql 5.5 ,elasticsearch1.1.1, jdk1.7 1.下载地址http://mirror.bjtu. ...

  3. Spring AOP源码分析(二):AOP的三种配置方式与内部解析实现

    AOP配置 在应用代码中,可以通过在spring的XML配置文件applicationContext.xml或者基于注解方式来配置AOP.AOP配置的核心元素为:pointcut,advisor,as ...

  4. 用python, PIL在图像上添加文字(可以控制,调节为水印等)

    最近想在图像上,添加想要的文字,首先想到的是matplotlib,但是这个更加倾向于画图(柱状图,折线图之类) opencv这个库肯定也行,但是为了和我现有程序连接在一起,我选择了PIL 其中字体的设 ...

  5. 3-Ubuntu下python3安装pymysql模块(1)

    命令:sudo pip3 install pymysql

  6. 网页开发人员收藏的16款HTML5工具

    本文收集的20款优秀的 HTML5 Web 应用程序,值得添加到您的 HTML5 的工具箱中,他们能够帮助你开发前端项目更快.更容易. Initializr Initializr 是一个可以让你创建 ...

  7. element-ui表格点击一行展开

    转载:https://www.cnblogs.com/xiaochongchong/p/8127282.html <template> <el-table :data="t ...

  8. gradle配置全局仓库

    1.在系统环境变量中配置: GRADLE_USER_HOME=D:\gradleRepository 2.在配置的路径中,增加文件init.gradle allprojects{ repositori ...

  9. [SNOI 2017] 炸弹

    题目描述: 给定炸弹和爆炸范围,求对于每个炸弹连锁爆炸的炸弹总和对\(1e9+7\)取膜 思路: 为啥都是线段树+TS+tarjan呢? 实在是搞不懂~~ 线性\(O(n)\)递推即可. #inclu ...

  10. VS2010-MFC(图形图像:CDC类及其屏幕绘图函数)

    转自:http://www.jizhuomi.com/software/244.html 上一节讲了文本输出的知识,本节的主要内容是CDC类及其屏幕绘图函数. CDC类简介 CDC类是一个设备上下文类 ...