import torch
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 2 ways to save the net
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x) # plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
) # copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) # plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show() # save net1
save() # restore entire net (may slow)
restore_net() # restore only the net parameters
restore_params()

pytorch之 sava_reload_model的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

随机推荐

  1. VirtualBox扩充磁盘&清空安装包

    1.virtual box 扩充磁盘空间 D:\VirtualBox\VBoxManage.exe modifyhd "E:\virtual box\daisyyun\daisyyun.vd ...

  2. nginx介绍与安装

    1.nginx作用可以配置数十个场景 ​ ​ 2.环境安装 环境确认 ​ 安装环境 yum -y install gcc gcc-c++ autoconf pcre-devel make automa ...

  3. Qt Installer Framework翻译(3-0)

    终端用户使用流程 离线安装和在线安装对终端用户来说是相似的.安装程序将你的应用程序和维护工具一起打包,该工具由包管理器,更新程序和卸载程序组成.用户可以使用维护工具来添加,更新和删除组件.维护工具连接 ...

  4. 高通量计算框架HTCondor(四)——案例准备

    目录 1. 正文 1.1. 任务划分 1.2. 任务程序 2. 相关 1. 正文 1.1. 任务划分 使用高通量计算第一步就是要针对密集运算任务做任务划分.将一个海量的.耗时的.耗资源的任务划分成合适 ...

  5. .net Core Autofac稍微高级一点的方法

    前情摘要 前段时间写了autofac的注入但是每次都需要去修改startup这应该不是大家想要的. 至少不是我想要的. 网上有朋友说可以创建一个基础类来时间. 好了吹牛时间结束我们开始干点正事. 创建 ...

  6. [校内训练20_01_19]ABC

    1.SB题 2.有n个点,m条边,每次加入一条边,你要挑出一些边,使得形成的图每个点度数都为奇数,且最长的边最短. 3.给一个N次多项式,问有多少个质数在任意整数处的点值都是p的倍数,输出它们.$N ...

  7. chrome 安装

    Centos7 yum安装chrome浏览器   跟着这个教程安装的:Centos7安装chrome浏览器 (点击) 1. 配置yum源 在目录 /etc/yum.repos.d/ 下新建文件 goo ...

  8. mysql--->mysql慢查询

    简介 > 开启慢查询日志,可以让MySQL记录下查询超过指定时间的语句,通过定位分析性能的瓶颈,才能更好的优化数据库系统的性能 参数及命令说明 查看慢查询是否开启和日志存储地址 show var ...

  9. 关于基本布局之——Flex布局

    Flex布局 1.Flex为"Flexible Box"的简称,即为弹性布局,可作用于任何容器上.给div这类块状元素元素设置display:flex或者给span这类内联元素设置 ...

  10. x01.auto_input: 自动输入

    单位经常要把 excel 表的数据录入系统中,能够自动录入该多好. 花了几天时间,学习了一下 pandas 操作 excel 数据,利用 pyautogui 完成了一个自动录入的小测试,希望对有此需求 ...