莫烦PyTorch学习笔记(五)——模型的存取
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。
莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章
- 莫烦PyTorch学习笔记(五)——分类
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- 莫烦PyTorch学习笔记(六)——批处理
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
- 莫烦pytorch学习笔记(二)——variable
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
- 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )
CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- 莫烦PyTorch学习笔记(三)——激励函数
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
- 莫烦pytorch学习笔记(一)——torch or numpy
Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...
随机推荐
- HDU1285-确定比赛名次-拓扑排序板子题
有N个比赛队(1<=N<=500),编号依次为1,2,3,....,N进行比赛,比赛结束后,裁判委员会要将所有参赛队伍从前往后依次排名,但现在裁判委员会不能直接获得每个队的比赛成绩,只知道 ...
- 剑指offer——20删除链表中重复的结点
题目描述 在一个排序的链表中,存在重复的结点,请删除该链表中重复的结点,重复的结点不保留,返回链表头指针. 例如,链表1->2->3->3->4->4->5 处理后 ...
- grep 的一些常用用法
打印匹配到的上下5行 grep -C 5 'root' /etc/passwd 上下5行 grep -A 5 'root' /etc/passwd afte ...
- Git的故事
目录 Git Git的概念 Git的安装 Git的配置 Git的指令 Git Git的概念 首先我们要知道git是什么,最根本的概念是版本控制,顾名思义,就是git可以帮助我们控制自己写的代码或者文档 ...
- OC开发系列-成员变量的作用域
成员变量的作用域 OC中成员变量有四种作用域,同时每一种作用域对应着响应的关键字. * @private:自能在当前类的实现@implementation中访问 * @protected: 可以在当前 ...
- 【第五周读书笔记】我是一只IT小小鸟
读了第一个同学的自述,我印象最深的就是一些高分同学,只是机械性地背诵知识点,然后不停刷题,只是为了拿一个高分,然而他们对学科的一些基本概念都没有掌握牢靠.高分,并不代表学的就好.学得好不仅仅要牢靠掌握 ...
- 随笔记录 MBR扇区故障系统备份与还原 2019.8.7
系统备份: [root@localhost ~]# mkdir /abc [root@localhost ~]# mount /dev/sdb1 /abc [root@localhost ~]# dd ...
- Java 基础 - final vs static
总结 共同点,都可以修饰类,方法,属性.而不同点: static 含义:表示静态或全局,被修饰的属性和方法属于类,可以用类名.静态属性 / 方法名 访问 static 方法:只能被static方法覆盖 ...
- 【JZOJ3424】粉刷匠
description 赫克托是一个魁梧的粉刷匠,而且非常喜欢思考= = 现在,神庙里有N根排列成一直线的石柱,从1到N标号,长老要求用油漆将这些石柱重新粉刷一遍.赫克托有K桶颜色各不相同的油漆,第i ...
- BIO 详解
调用者主动等待调用的结果 简介 早期的jdk中,采用BIO通信模式: 通常有一个acceptor(消费者) 去负责监听客户端的连接. 它接收到客户端的连接请求之后为每个客户端创建一个线程进行链路处理, ...