各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. Android Telephony分析(三) ---- RILJ详解

    前言 本文主要讲解RILJ工作原理,以便更好地分析代码,分析业务的流程.这里说的RILJ指的是RIL.java (frameworks\opt\telephony\src\java\com\andro ...

  2. 46张PPT弄懂JVM、GC算法和性能调优!

    来源:cnblogs.com/cyfonly/p/5807121.html 本PPT从JVM体系结构概述.GC算法.Hotspot内存管理.Hotspot垃圾回收器.调优和监控工具六大方面进行讲述. ...

  3. SonarQube代码质量扫描持续集成

    1.安装JDK和配置JAVA_HOME和CLASSPATH 2.安装mysql数据库 3.创建数据库和用户 mysql -u root -p mysql> CREATE DATABASE son ...

  4. swt java 内嵌ActiveX控件

    这里用的是SWT/JFace开发application中SWT自带的org.eclipse.swt.ole.win32 包可以支持内嵌OLE和ActiveX. 具体用法如下: //创建一个OleFra ...

  5. SSH的两种登录方式以及配置

    前言 SSH简介 Secure Shell(SSH) 是由 IETF(The Internet Engineering Task Force) 制定的建立在应用层基础上的安全网络协议.它是专为远程登录 ...

  6. load data local infile

    发财 基本语法:load data [low_priority] [local] infile '文件名称' [replace替换策略 | ignore忽略策略]into table 表名称[fiel ...

  7. Go语言简介以及安装

    http://www.runoob.com/go/go-tutorial.html Go 是一个开源的编程语言,它能让构造简单.可靠且高效的软件变得容易. Go是从2007年末由Robert Grie ...

  8. 单独编译和使用webrtc音频降噪模块(附完整源码+测试音频文件)

    单独编译和使用webrtc音频增益模块(附完整源码+测试音频文件) 单独编译和使用webrtc音频回声消除模块(附完整源码+测试音频文件) webrtc的音频处理模块分为降噪ns,回音消除aec,回声 ...

  9. thinkphp 三元运算

    模板可以支持三元运算符,例如: {$status?'正常':'错误'} {$info['status']?$info['msg']:$info['error']} 注意:三元运算符中暂时不支持点语法. ...

  10. 最大流拆点——poj3281

    /* 因为牛的容量为1,把牛拆点 按照s->f->cow->cow->d->t建图 */ #include<iostream> #include<cst ...