pytorch之 sava_reload_model
import torch
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 2 ways to save the net
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x) # plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
) # copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) # plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show() # save net1
save() # restore entire net (may slow)
restore_net() # restore only the net parameters
restore_params()
pytorch之 sava_reload_model的更多相关文章
- Ubutnu16.04安装pytorch
1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...
- 解决运行pytorch程序多线程问题
当我使用pycharm运行 (https://github.com/Joyce94/cnn-text-classification-pytorch ) pytorch程序的时候,在Linux服务器 ...
- 基于pytorch实现word2vec
一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
- pytorch实现VAE
一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...
- PyTorch教程之Training a classifier
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...
- PyTorch教程之Neural Networks
我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...
- PyTorch教程之Autograd
在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...
- Linux安装pytorch的具体过程以及其中出现问题的解决办法
1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...
随机推荐
- Spring Boot 2.X(十九):集成 mybatis-plus 高效开发
前言 之前介绍了 SpringBoot 整合 Mybatis 实现数据库的增删改查操作,分别给出了 xml 和注解两种实现 mapper 接口的方式:虽然注解方式干掉了 xml 文件,但是使用起来并不 ...
- K8S基于ingress-nginx实现灰度发布
之前介绍过使用ambassador实现灰度发布,今天介绍如何使用ingre-nginx实现. 介绍 Ingress-Nginx 是一个K8S ingress工具,支持配置 Ingress Annota ...
- JPA基本注解的使用
一:JPA基本注解 使用: 使用: 使用: 查看表: 二:用table来生成主键 使用: allocationSize:每次增加多少 tablel:指定使用那张表 执行两次main方法后查看表: jp ...
- Mysql梳理-关于索引/引擎与锁
前言 最近突发新型肺炎,本来只有七天的春节假期也因为各种封锁延长到了正月十五,在家实在闲的蛋疼便重新研究了一下Mysql数据库的相关知识,特此总结梳理一下.本文主要围绕以下几点进行: 1.Mysql的 ...
- 龙芯2f 8089D 笔记本 Debian 系统安装配置
版权声明:原创文章,未经博主允许不得转载 正文主要讲述安装社区版Debian6镜像(也有7和8,方法大同小异) 最后简单介绍了网络安装原版Debian 小记 非网络安装,没网也没事,再也不用担心网速度 ...
- 在windows中python安装sit-packages路径位置 在Pycharm中导入opencv不能自动代码补全问题
在Pycharm中导入opencv不能自动代码补全问题 近期学习到计算机视觉库的相关知识,经过几个小时的探讨,终于解决了opencv不能自动补全代码的困惑, 我们使用pycharm安装配置可能会添加多 ...
- BERT模型总结
BERT模型总结 前言 BERT是在Google论文<BERT: Pre-training of Deep Bidirectional Transformers for Language U ...
- http轮询,长轮询
轮询,长轮询 轮询 轮询:客户端定时向服务器发送Ajax请求,服务器接到请求后马上返回响应信息并关闭连接. 优点:后端程序编写比较容易. 缺点:请求中有大半是无用,浪费带宽和服务器资源. 实例:适于小 ...
- redis--->微博小项目
redis 微博小项目 centos6.9+lnmp+redis 写的微博小项目,梳理了redis在项目中kes的设计,redis各种数据结构在不同业务场景下的应用等知识点. 这里用的php框架是自己 ...
- postman的简单介绍及运用
postman下载地址 https://www.getpostman.com/downloads/ postman的工作原理:发送请求给服务器,服务器处理postman发送的数据然后返回给postma ...