莫烦PyTorch学习笔记(五)——模型的存取
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt torch.manual_seed() # fake data
x = torch.unsqueeze(torch.linspace(-,,),dim=)
y = x.pow() + 0.2 * torch.rand(x.size())
x, y = Variable(x,requires_grad=False), Variable(y,requires_grad=False) def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range():
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() plt.figure(,figsize=(,))
plt.subplot()
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少) def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot()
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'r-',lw=)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(, ),
torch.nn.ReLU(),
torch.nn.Linear(, )
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) plt.subplot()
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=)
# 将保存的参数复制到 net3
plt.show() save()
restore_net()
restore_params()

结果和莫烦的不一样,但是找不到问题的所在,,。。。
莫烦PyTorch学习笔记(五)——模型的存取的更多相关文章
- 莫烦PyTorch学习笔记(五)——分类
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- 莫烦PyTorch学习笔记(六)——批处理
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
- 莫烦pytorch学习笔记(二)——variable
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
- 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )
CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- 莫烦PyTorch学习笔记(三)——激励函数
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
- 莫烦pytorch学习笔记(一)——torch or numpy
Q1:什么是神经网络? Q2:torch vs numpy Numpy:NumPy系统是Python的一种开源的数值计算扩展.这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表(neste ...
随机推荐
- 01、requests 基本使用
requests模块的基本使用 基于网络请求的模块. 环境的安装:pip install requests 作用:模拟浏览器发起请求 分析requests的编码流程: 1.指定url 2.发起了请求 ...
- codeforces round#524 D - Olya and magical square /// 大概算是数学规律题?
题目大意: t 个测试用例 (1≤t≤103) 给定n k (1≤n≤10^9,1≤k≤10^18) 表示有一个边长为2^n的正方形格子 每次操作只能将一个格子切割为左上左下右上右下的四等分格子 ...
- axios 基本运用
axios是专门对ajax请求进行封装的一个插件,其返回一个promise对象,用法跟ES6的promise很相似 一.安装axios插件npm install axios 二.引入axios插件 在 ...
- JS函数 函数调用 函数定义好后,是不能自动执行的,需要调用它,直接在需要的位置写函数名。
函数调用 函数定义好后,是不能自动执行的,需要调用它,直接在需要的位置写函数名. 第一种情况:在<script>标签内调用. <script type="text/java ...
- 5、Docker数据管理
为了能够存储持久化数据以及共享容器间的数据,Docker提出了Volume的概念.让我们通过类似mount的方式将宿主机的文件或者目录挂载到容器中. 在容器中管理数据主要有两种方式: 数据卷(Data ...
- 1分钟k线图能反映什么?(转)
对于投资者特别是短线操作者来讲,应该重视1分钟K线图,但是并不是所有的股票都能通过1分钟K线图看出名堂来,比如一些小盘股,盘子较轻,很容易上蹿下跳.仅用1分钟K线图分析其上证指数,很难研判大盘当日的高 ...
- Redis 常用的数据结构
String 字符串 set get 使用场景: 可以用来作为缓存使用(缓存更新策略和缓存雪崩如何处理) List lpop rpop lpush rpush 使用场景: set 无序集合 使用场景: ...
- thinkphp 获取内容
如果需要获取渲染模板的输出内容而不是直接输出,可以使用fetch方法. fetch方法的用法和display基本一致(只是不需要指定输出编码和输出类型): 大理石平台规格 fetch('模板文件') ...
- luoguP1062 数列 [数学]
题目描述 给定一个正整数k(3≤k≤15),把所有k的方幂及所有有限个互不相等的k的方幂之和构成一个递增的序列,例如,当k=3时,这个序列是: 1,3,4,9,10,12,13,… (该序列实际上就是 ...
- 带撤销贪心——cf1148F好题
自己不会做,看了题解懂得 从最高位依次往低位遍历,因为偶数个1是不改变符号的,所以带个贪心即可(可以看成是带撤销的..) 每轮循环用sum记录该位选择1可以减少的值 如果是负数,就不要改成1 如果是正 ...