import torch
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)

训练网络,并用两种方式保存网络

def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss() for t in range(100): # 训练网络
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step() # plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 2 ways to save the net 保存网络的两种方式,第二种只保存参数更快一点
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters

加载第一种(含所有信息的)网络:torch.load('net.pkl')

def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x) # plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

加载第二种(只含有参数的)网络:net3.load_state_dict(torch.load('net_params.pkl'))

def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
) # copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x) # plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()

运行上面的函数

# save net1  训练网络,并用两种方式保存网络
save() # restore entire net (may slow) 所有信息的网络
restore_net() # restore only the net parameters 只含参数信息的网络,加载前需要重新构造与之前一模一样的网络
restore_params()

三个网络绘制的图片

END

pytorch 7 save_reload 保存和提取神经网络的更多相关文章

  1. pytorch1.0神经网络保存、提取、加载

    pytorch1.0网络保存.提取.加载 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot ...

  2. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

  3. tensorflow学习之路----保存和提取数据

    #保存数据注意他只能保存变量,不能保存神经网络的框架.#保存数据的作用:保存权重有利于下一次的训练,或者可以用这个数据进行识别#np.arange():arange函数用于创建等差数组,使用频率非常高 ...

  4. Keras(六)Autoencoder 自编码 原理及实例 Save&reload 模型的保存和提取

    Autoencoder 自编码 压缩与解压 原来有时神经网络要接受大量的输入信息, 比如输入信息是高清图片时, 输入信息量可能达到上千万, 让神经网络直接从上千万个信息源中学习是一件很吃力的工作. 所 ...

  5. cookie的保存与提取

    爬虫过程中,cookie可以保留用户与服务器之间的交互信息,使服务器与用户相互能够识别.由于HTTP协议是无状态协议,即不能够识别客户端身份,即使客户端多次请求同一个url服务器仍然响应.这种协议导致 ...

  6. Tensorflow学习教程------参数保存和提取重利用

    #coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mni ...

  7. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  8. pytorch实战(7)-----卷积神经网络

    一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...

  9. delphi保存和提取ini文件信息

    procedure TLoginForm.FormShow(Sender: TObject);var ini:TIniFile; name:string;begin //实现动态提取数据库的登录用户名 ...

随机推荐

  1. Orcale-利用闪回恢复数据方法

    一.delete误删 方法1:如果表结构没有改变,直接闪回整个表,具体步骤: --首先需要表闪回权限,开启行移动功能 alter table 表名 enable row movement; --执行闪 ...

  2. Linux磁盘分区--GPT分区

    MBR分区表有一定的局限性,最大支持2.1tb硬盘,单块硬盘最多4个主分区. 这里就要引入GPT分区表,可以支持最大18EB的卷,最多支持128个主分区,所以如果使用大于2tb的卷,就必须使用GTP分 ...

  3. NOIP2018提高组省一冲奖班模测训练(六)

    NOIP2018提高组省一冲奖班模测训练(六) https://www.51nod.com/Contest/ContestDescription.html#!#contestId=80 20分钟AC掉 ...

  4. 深入了解Spring中的容器

    1.创建Bean的3种方式 1.1使用构造器创建bean实例 这是最常见的方式,如果不采用构造注入,bean类需要有默认构造函数.如果采用构造注入,则需要配置xml文件的<constructor ...

  5. 【ACM-ICPC 2018 南京赛区网络预赛 A】An Olympian Math Problem

    [链接] 我是链接,点我呀:) [题意] 在这里输入题意 [题解] 估计试几个就会发现答案总是n-1吧. 队友给的证明 [代码] #include <bits/stdc++.h> #def ...

  6. HDU 2439 The Mussels

    The Mussels Time Limit: 1000ms Memory Limit: 32768KB This problem will be judged on HDU. Original ID ...

  7. 洛谷—— P1204 [USACO1.2]挤牛奶Milking Cows

    https://www.luogu.org/problem/show?pid=1204 题目描述 三个农民每天清晨5点起床,然后去牛棚给3头牛挤奶.第一个农民在300秒(从5点开始计时)给他的牛挤奶, ...

  8. xml解析之----DOM解析

    DOM模型(documentobject model) •DOM解析器在解析XML文档时,会把文档中的全部元素.依照其出现的层次关系.解析成一个个Node对象(节点). •在dom中.节点之间关系例如 ...

  9. jsp不通过form和Ajax提交

    在页面里面我们一般都通过form表单和Ajax向后台提交请求,但是我如今页面没有form表单,也不想通过ajax异步提交. 解决方式例如以下:location.href="${rootPat ...

  10. 2015.03.16,外语,读书笔记-《Word Power Made Easy》 00 “如何最大限度的利用本书”学习笔记

    备注:蓝色表明是自己学习或笔记的部分,红色表明特别的地方,例如自己不理解或需要重点关注的地方.加粗单词表明是要加入生词库学习的词语.单词后面括号中的蓝色部分,是单词的解释和音标. 1.this is ...