1. import torch
  2. import matplotlib.pyplot as plt
  3.  
  4. # torch.manual_seed(1) # reproducible
  5.  
  6. # fake data
  7. x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
  8. y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
  9.  
  10. # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
  11. # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
  12.  
  13. def save():
  14. # save net1
  15. net1 = torch.nn.Sequential(
  16. torch.nn.Linear(1, 10),
  17. torch.nn.ReLU(),
  18. torch.nn.Linear(10, 1)
  19. )
  20. optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
  21. loss_func = torch.nn.MSELoss()
  22.  
  23. for t in range(100):
  24. prediction = net1(x)
  25. loss = loss_func(prediction, y)
  26. optimizer.zero_grad()
  27. loss.backward()
  28. optimizer.step()
  29.  
  30. # plot result
  31. plt.figure(1, figsize=(10, 3))
  32. plt.subplot(131)
  33. plt.title('Net1')
  34. plt.scatter(x.data.numpy(), y.data.numpy())
  35. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  36.  
  37. # 2 ways to save the net
  38. torch.save(net1, 'net.pkl') # save entire net
  39. torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
  40.  
  41. def restore_net():
  42. # restore entire net1 to net2
  43. net2 = torch.load('net.pkl')
  44. prediction = net2(x)
  45.  
  46. # plot result
  47. plt.subplot(132)
  48. plt.title('Net2')
  49. plt.scatter(x.data.numpy(), y.data.numpy())
  50. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  51.  
  52. def restore_params():
  53. # restore only the parameters in net1 to net3
  54. net3 = torch.nn.Sequential(
  55. torch.nn.Linear(1, 10),
  56. torch.nn.ReLU(),
  57. torch.nn.Linear(10, 1)
  58. )
  59.  
  60. # copy net1's parameters into net3
  61. net3.load_state_dict(torch.load('net_params.pkl'))
  62. prediction = net3(x)
  63.  
  64. # plot result
  65. plt.subplot(133)
  66. plt.title('Net3')
  67. plt.scatter(x.data.numpy(), y.data.numpy())
  68. plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  69. plt.show()
  70.  
  71. # save net1
  72. save()
  73.  
  74. # restore entire net (may slow)
  75. restore_net()
  76.  
  77. # restore only the net parameters
  78. restore_params()

pytorch之 sava_reload_model的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

随机推荐

  1. Numpy的介绍与基本使用方法

    1.什么是Numpy numpy官方文档:https://docs.scipy.org/doc/numpy/reference/?v=20190307135750 NumPy是一个功能强大的Pytho ...

  2. python对象的初始化

    效果图: 代码: # 对象的初始化 class Person: # 在类中可以定义一些特殊方法(魔术方法) # 特殊方法都是以__开头,__结尾的方法 前后都是两个下划线 # 特殊方法会在特殊的时刻自 ...

  3. Spring Boot2 系列教程 (十五) | 服务端参数校验之一

    估计很多朋友都认为参数校验是客户端的职责,不关服务端的事.其实这是错误的,学过 Web 安全的都知道,客户端的验证只是第一道关卡.它的参数验证并不是安全的,一旦被有心人抓到可乘之机,他就可以有各种方法 ...

  4. Jenkins 应用

    一.Jenkins Linux shell集成 新建任务 shell-freestyle-job,选择Freestyle project,点击[确定] ​ 添加描述,This is my first ...

  5. Nginx代理服务——常用的配置语法

    可以到官方查看所有代理的配置语法http://nginx.org/en/docs/http/ngx_http_proxy_module.html 缓存区 Syntax:proxy_buffering ...

  6. shell 条件测试

    1.文件相关 -e 判断文件或者文件夹是否存在 -d 判断目录是否存在 -f 判断文件是否存在 -r 判断是否有读权限 -w 判断是否有写权限 -x 判断是否有执行权限 1.1命令行使用 [root@ ...

  7. spring cloud 与 docker 读书笔记 1

    Eureka Server 的高可用

  8. centos7+ docker 实践部署docker及配置direct_lvm

    转载于博客园:http://www.cnblogs.com/Andrew-XinFei/p/6245330.html 前言 Docker现在在后端是那么的火热..尤其当笔者了解了docker是什么.能 ...

  9. docker 简单使用

    1.docker 命令 docker start nginx https://www.w3cschool.cn/docker/windows-docker-install.html // docker ...

  10. Docker(一):理解Docker镜像与容器

    一.镜像的概念 1.广泛镜像概念: 镜像是一种文件存储形式,是冗余的一种类型,一个磁盘上的数据在另一个磁盘上存在完全相同的副本即为镜像. 2.Docker镜像概念: 在Docker中镜像同样是一种完全 ...