各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. 自动化测试工具2-testcomplete

    今天来说说testcomplete的使用 录了一个简单案例视频,网址如下:https://v.qq.com/x/page/f05116a062y.html 第一步是创建一个工程: 输入工程名,和选择工 ...

  2. Python3数据分析与挖掘建模实战✍✍✍

    Python3数据分析与挖掘建模实战 Python数据分析简介 Python入门 运行:cmd下"python hello.py" 基本命令: 第三方库 安装 Windows中 p ...

  3. js 实现多选

    效果: html: <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> < ...

  4. 使用<script>标签在HTML网页中插入JavaScript代码

    新朋友你在哪里(如何插入JS) 我们来看看如何写入JS代码?你只需一步操作,使用<script>标签在HTML网页中插入JavaScript代码.注意, <script>标签要 ...

  5. 【笔记篇】(理论向)快速傅里叶变换(FFT)学习笔记w

    现在真是一碰电脑就很颓废啊... 于是早晨把电脑锁上然后在旁边啃了一节课多的算导, 把FFT的基本原理整明白了.. 但是我并不觉得自己能讲明白... Fast Fourier Transformati ...

  6. python ORM框架:SqlAlchemy

    ORM,对象关系映射,即Object Relational Mapping的简称,通过ORM框架将编程语言中的对象模型与数据库的关系模型建立映射关系,这样做的目的:简化sql语言操作数据库的繁琐过程( ...

  7. struts使用

    下载文件 <action name="download" class="thirdIssueAction" method="getDownloa ...

  8. CSIC_716_20191119【常用模块的用法 subprocess、re、logging、防止自动测试、包的理论】

    subprocess模块 可以通过python代码给操作系统终端发送命令,并可以得到返回结果. import subprocess str = input('>>>请输入命令') # ...

  9. 微信小程序为什么看不到所有的console.log()的日志信息

    记录一个巨傻无比的问题 1.在首页的onLoad()函数里面,加了地理位置的加载,并打印到控制台上,可是今天就是没出现 2.然后纳闷的很久,各种google,发现没有人遇到这个问题 3.再然后,我就看 ...

  10. leetcode-241-为运算表达式设置优先级*

    题目描述: 方法:分治* class Solution: def diffWaysToCompute(self, input: str) -> List[int]: if input.isdig ...