各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. Mysql命令增加、修改、删除表字段

    alter add 命令用来增加表的字段: alter add命令格式:alter table 表名 add字段 类型 其他:如下所示: ) comment '单位' alter drop 命令删除表 ...

  2. Vue的项目搭建及请求生命周期

    目录 Vue的项目搭建及请求生命周期 Vue-CLI的项目搭建 环境搭建 项目创建 pycharm运行Vue项目 Vue项目的大体结构 Vue的请求生命周期 两个小用法 Vue的项目搭建及请求生命周期 ...

  3. jmeter在windows环境下系统参数设置

    在windows环境下搭建jmeter的压测实验环境,需要对操作系统默认的一些个参数进行设置,以提高并发能力.特别是作为压力机的时候. Socket 编程时,单机最多可以建立多少个 TCP 连接,受到 ...

  4. Activiti学习笔记4 — 流程实例化

    1.创建流程引擎对象 private ProcessEngine processEngine = ProcessEngines.getDefaultProcessEngine(); 2.启动流程 流程 ...

  5. JS事件 鼠标单击事件( onclick )通常与按钮一起使用。onclick是鼠标单击事件,当在网页上单击鼠标时,就会发生该事件。同时onclick事件调用的程序块就会被执行

    鼠标单击事件( onclick ) onclick是鼠标单击事件,当在网页上单击鼠标时,就会发生该事件.同时onclick事件调用的程序块就会被执行,通常与按钮一起使用. 比如,我们单击按钮时,触发  ...

  6. Python学习笔记(五)——异常处理

    Python 异常总结 异常名称 解释 AssertionError 断言语句(assert)失败:当assert关键字后边的条件为假时,程序将抛出该异常,一般用于在代码中置入检查点 OSError ...

  7. ubuntu 权限不够,解决办法,无法获得锁 /var/lib/dpkg/lock - open (11: 资源暂时不可用)

    终端执行  sudo passwd root输入root 新密码执行命令  nano /usr/share/lightdm/lightdm.conf.d/50-ubuntu.conf末行添加   gr ...

  8. Android开发 View的UI刷新Invalidate和postInvalidate

    Invalidate 正常刷新 /** * 使整个视图无效.如果视图可见, * {@link #onDraw(android.graphics.Canvas)} 调用此方法后将在后续的UI刷新里调用o ...

  9. virtualbox manager命令小记

    virtualbox 控制虚拟机 VBoxManage list runningvms 列出运行的虚拟机 (返回名称和UUID): VBoxManage list runningvms Stop ru ...

  10. hdu多校第三场 1006 (hdu6608) Fansblog Miller-Rabin素性检测

    题意: 给你一个1e9-1e14的质数P,让你找出这个质数的前一个质数Q,然后计算Q!mod P 题解: 1e14的数据范围pass掉一切素数筛法,考虑Miller-Rabin算法. 米勒拉宾算法是一 ...