各种优化器的比较

莫烦的对各种优化通俗理解的视频

 import torch

 import torch.utils.data as Data

 import torch.nn.functional as F

 from torch.autograd import Variable

 import matplotlib.pyplot as plt

 # 超参数

 LR = 0.01

 BATCH_SIZE = 

 EPOCH = 

 # 生成假数据

 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据

 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )

 # 0.2 * torch.rand(x.size())增加噪点

 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))

 # 输出数据图

 # plt.scatter(x.numpy(), y.numpy())

 # plt.show()

 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)

 class Net(torch.nn.Module):

     # 初始化

     def __init__(self):

         super(Net, self).__init__()

         self.hidden = torch.nn.Linear(, )

         self.predict = torch.nn.Linear(, )

     # 前向传递

     def forward(self, x):

         x = F.relu(self.hidden(x))

         x = self.predict(x)

         return x

 net_SGD = Net()

 net_Momentum = Net()

 net_RMSProp = Net()

 net_Adam = Net()

 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)

 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)

 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)

 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))

 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

 loss_func = torch.nn.MSELoss()

 loss_his = [[], [], [], []]  # 记录损失

 for epoch in range(EPOCH):

     print(epoch)

     for step, (batch_x, batch_y) in enumerate(loader):

         b_x = Variable(batch_x)

         b_y = Variable(batch_y)

         for net, opt,l_his in zip(nets, optimizers, loss_his):

             output = net(b_x)  # get output for every net

             loss = loss_func(output, b_y)  # compute loss for every net

             opt.zero_grad()  # clear gradients for next train

             loss.backward()  # backpropagation, compute gradients

             opt.step()  # apply gradients

             l_his.append(loss.data.numpy())  # loss recoder

 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']

 for i, l_his in enumerate(loss_his):

     plt.plot(l_his, label=labels[i])

 plt.legend(loc='best')

 plt.xlabel('Steps')

 plt.ylabel('Loss')

 plt.ylim((, 0.2))

 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  2. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. [PyTorch 学习笔记] 4.3 优化器

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...

  5. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  6. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  9. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

随机推荐

  1. linux就该这么学--资料整理--持续更新

    基础命令 服务管理 systemctl redhat7 systemctl start foo.service 启动服务 systemctl restart foo.service 重启服务 syst ...

  2. spring MVC <mvc:annotation-driven>

    研究SpringMvc 3.2的<mvc:annotation-driven>默认干了什么 如果不配置其他参数,大致相当于以下的配置文件(参考自org.springframework.we ...

  3. django 项目分析

    项目要点 一.功能制定 1.用户功能 #.登陆 #.权限组功能 2.数据展示功能 #.列表展示 #.详细信息展示 #.图标展示 3.资源管理功能 #远程管理 #对远程服务器上的进程具有 #开启 #关闭 ...

  4. Algo: maxSubArray vs. maxProduct

    这两个问题类似,都可利用动态规划思想求解. 一.最大连续子序列和 https://leetcode.com/problems/maximum-subarray/description/ https:/ ...

  5. 从别人git下载项目下来然后运行

    1点击clone or  download 2.自由选择 3.拉到你想放的位置,我是放到桌面上的 4. cmd 打开,进入你 的下载到桌面的项目 5. # install dependencies n ...

  6. 阿里云应用上边缘云解决方案助力互联网All in Cloud

    九月末的杭州因为一场云栖大会变得格外火热. 9月25日,吸引全球目光的2019杭州云栖大会如期开幕.20000平米的展区集结数百家企业,为数万名开发者带来了一场前沿科技的饕餮盛宴. 如同往年一样,位于 ...

  7. thinkphp 链接数据库

    ThinkPHP内置了抽象数据库访问层,把不同的数据库操作封装起来,我们只需要使用公共的Db类进行操作,而无需针对不同的数据库写不同的代码和底层实现,Db类会自动调用相应的数据库驱动来处理.目前的数据 ...

  8. (转)Android开发把项目打包成apk

    转:http://blog.csdn.net/luoyin22/article/details/7862742 做完一个Android项目之后,如何才能把项目发布到Internet上供别人使用呢?我们 ...

  9. 什么是存根类 Stub

    转:http://www.cnblogs.com/cy163/archive/2009/08/04/1539077.html 存根类是一个类,它实现了一个接口,但是实现后的每个方法都是空的.      ...

  10. Vuex听说很难?

    Vuex 是什么? Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式.它采用集中式存储管理应用的所有组件状态,并以相应的规则保证状态以一种可预测的方式发生变化.   什么鬼东西 看完这段 ...