pytorch之 sava_reload_model
import torch
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 2 ways to save the net
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x) # plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
) # copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) # plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show() # save net1
save() # restore entire net (may slow)
restore_net() # restore only the net parameters
restore_params()
pytorch之 sava_reload_model的更多相关文章
- Ubutnu16.04安装pytorch
1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...
- 解决运行pytorch程序多线程问题
当我使用pycharm运行 (https://github.com/Joyce94/cnn-text-classification-pytorch ) pytorch程序的时候,在Linux服务器 ...
- 基于pytorch实现word2vec
一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- pytorch实现VAE
一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...
- PyTorch教程之Training a classifier
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...
- PyTorch教程之Neural Networks
我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...
- PyTorch教程之Autograd
在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...
- Linux安装pytorch的具体过程以及其中出现问题的解决办法
1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...
随机推荐
- 【Java编程思想阅读笔记】Java数据存储位置
Java数据存储位置 P46页有感 一.前置知识 栈是由系统自动分配的,Java程序员对栈没有直接的操作权限, 堆是所有线程共享的内存区域,栈 是每个线程独享的. 堆是由程序员自己申请的,在使用new ...
- 搭建自己的Online Judge
前言 很多人对于做题有点厌烦,但是,如果让你出题给别人做那么可能会很有意思.可是,出题只能出在一些别人的OJ上,甚至只能在自己的Word文档里出.今天我教大家一个厉害点的,叫做搭建自己的Online ...
- 【UEFI】---基于UEFI编程的基本思路
最近基于UEF在写代码的时候,发现由于粗心总是出现很多问题,而且都是一些小问题.虽然UEFI玩了挺久,但是也没梳理一下思路.借此机会整理一下: UEFI对复杂的BIOS代码做了很好的封装和模块化. ...
- 【ARM】---STM32位带操作总结---浅显易懂
正在准备做毕业设计,配置LED_Config()的时候,又看到了位带操作的宏定义,我又嘀咕了,什么是位带操作,一年前在使用位带操作的时候,就查阅过好多资料,Core-M3也看过,但是对于博主这种“低能 ...
- .net 解析嵌套JSON
JSON格式文件如下:我们是要取出msgJsoncontent里面GeneralReportInfo下serviceData中的totalUseValue数据 { ", "mess ...
- numpy nan和inf
一.nan和inf的简介 nan 不是一个数字 读取本地文件为flaot的时候,有缺失 inf(infinity): 无穷尽 inf: 正无穷 -inf: 负无穷 数据类型:float # 注意: 要 ...
- Pandas中merge和join的区别
可以说merge包含了join的操作,merge支持通过列或索引连表,而join只支持通过索引连表,只是简化了merge的索引连表的参数 示例 定义一个left的DataFrame left=pd.D ...
- [bzoj2326] [洛谷P3216] [HNOI2011] 数学作业
想法 最初的想法就是记录当前 \(%m\) 值为cur,到下一个数时 \(cur=cur \times 10^x + i\) n这么大,那就矩阵乘法呗. 矩阵乘法使用的要点就是有一个转移矩阵会不停的用 ...
- idea实现svn拉分支和合并分支的教程
原文地址:https://blog.csdn.net/qq_27471405/article/details/78498260 今天测试了一下svn拉分支和合并分支的教程,决定分享给大家 拉分支教程: ...
- 编写python程序读入1到100之间的整数,然后计算每个数出现的次数,输入0表示结束输人,输入数据不包括0。如果数出现的大现如果大于1,输出时使用复数times
#-*- coding:UTF-8 -*- #环境:python3 print("Enter the numbers between 1 and 100:") enterList= ...