import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。

莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章

  1. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  4. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  5. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  6. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  7. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(一)——torch or numpy

    Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...

随机推荐

  1. jquery控件的学习

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  2. 如何用Maven创建web项目(具体步骤)转载

    使用eclipse插件创建一个web project 首先创建一个Maven的Project如下图 本文转载于(http://blog.csdn.net/chuyuqing/article/detai ...

  3. 【ARC072E】Alice in linear land

    题目 瑟瑟发抖,这竟然只是个蓝题 题意大概就是初始在\(0\),要到坐标为\(D\)的地方去,有\(n\)条指令,第\(i\)条为\(d_i\).当收到一条指令\(x\)后,如果向\(D\)方向走\( ...

  4. YARN设计思路

  5. CF431E Chemistry Experiment

    题意:有n个试管,有高度为hi的水银.操作1:将试管x中的水银高度改成y.操作2:将体积为v的水注入试管,求水位的高度?n,q<=1e5. 标程: #include<bits/stdc++ ...

  6. iptbales无法正常重启

    新主机iptables无法启动关闭和重启 一般是由于没有配文件导致 解决办法 直接touch /etc/sysconfig/iptables 然后就可以正常启动. 备注:一般存在于centos6系列中

  7. 【JZOJ1667】【BZOJ1801】【luoguP2051】中国象棋

    description 在N行M列的棋盘上,放若干个炮可以是0个,使得没有任何一个炮可以攻击另一个炮.请问有多少种放置方法?中国象棋中炮的行走方式大家应该很清楚吧. analysis \(DP\),容 ...

  8. day28 import,from * import *,__name__

    Python之路,Day16 = Python基础16 一 module通常模块为一个文件,直接使用import来导入就好了.可以作为module的文件类型有".py"." ...

  9. csp-s模拟测试87

    csp-s模拟测试87 考场状态还可以$T1$我当时以为我秒切,$T2$确认自己思路不对后毅然决然码上,$T3$暴力挂了太可惜了. 03:01:28 03:16:07 03:11:38 140 03: ...

  10. csv文件格式

    弱渣今天第一次读Kaggle入门文章,知道train data,test data以及提供的result文件大都是以csv文件格式给出的. csv,全称 Comma-Separated Values, ...