pytorch之 optimizer comparison
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim
# torch.manual_seed(1) # reproducible LR = 0.01
BATCH_SIZE = 32
EPOCH = 12 # fake dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size())) # plot dataset
plt.scatter(x.numpy(), y.numpy())
plt.show() # put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,) # default network
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(1, 20) # hidden layer
self.predict = torch.nn.Linear(20, 1) # output layer def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x if __name__ == '__main__':
# different nets
net_SGD = Net()
net_Momentum = Net()
net_RMSprop = Net()
net_Adam = Net()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam] # different optimizers
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam] loss_func = torch.nn.MSELoss()
losses_his = [[], [], [], []] # record loss # training
for epoch in range(EPOCH):
print('Epoch: ', epoch)
for step, (b_x, b_y) in enumerate(loader): # for each training step
for net, opt, l_his in zip(nets, optimizers, losses_his):
output = net(b_x) # get output for every net
loss = loss_func(output, b_y) # compute loss for every net
opt.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
opt.step() # apply gradients
l_his.append(loss.data.numpy()) # loss recoder labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(losses_his):
plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
pytorch之 optimizer comparison的更多相关文章
- pytorch 7 optimizer 优化器 加速训练
import torch import torch.utils.data as Data import torch.nn.functional as F import matplotlib.pyplo ...
- [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练
[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 目录 [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 0x00 摘要 0 ...
- [源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer
[源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer 目录 [源码解析] PyTorch 分布式(14) - ...
- pytorch adam 源码 关于优化函数的调整 optimizer 调参 重点
关于优化函数的调整拆下包:https://ptorch.com/docs/1/optim class torch.optim.Optimizer(params, defaults)所有优化的基类. 参 ...
- 莫烦pytorch学习笔记(七)——Optimizer优化器
各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...
- Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)
1.优化器算法简述 首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更 ...
- pytorch bert 源码解读
https://daiwk.github.io/posts/nlp-bert.html 目录 概述 BERT 模型架构 Input Representation Pre-training Tasks ...
- Comparison of B-Tree and Hash Indexes
Understanding the B-tree and hash data structures can help predict how different queries perform on ...
- 基于pytorch的CNN、LSTM神经网络模型调参小结
(Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...
随机推荐
- 一款精美的Toast第三方库的简单使用
以前一直用的安卓原生Toast,个人感觉Toast这东西,没必要花功夫,知道看到了Toasty这东西,立刻被圈粉了,真的非常好看. 项目地址 我们都知道,安卓原生Toast的用法是 Toast.mak ...
- 七牛云上传视频并截取第一帧为图片(js实现)
本文出自APICloud官方论坛, 感谢论坛版主 东冥羽的分享. 七牛云上传视频并截取第一帧作为视频的封面图. 使用js上传,模块videoPlayer截取第一帧(有专门的截图模块,但是我使用的有点问 ...
- APICloud发布低代码开发平台
云原生的出现,致使传统IT模式正在集中向云架构.云开发转型,其中在企业业务的互联网化.数字化进程中尤为突出,并衍生出“敏捷开发”.“快速迭代”的刚性需求.面对双模IT,如何打造全新的IT团队与模式?并 ...
- numpy 数组的计算
一.数组和数的计算 数组和数计算,数组中的每个元素和数进行计算 1.加 import numpy as np arr1 = np.arange(12).reshape(3, 4) print(arr1 ...
- 异数OS谈发展国产操作系统的问题
异数OS谈发展国产操作系统的问题 为什么写本文 最近中兴被美制裁的问题以及红芯使用开源技术宣称国产自主技术引发了舆论不少对国产CPU以及国产操作系统自主技术的讨论,为什么我们国家有BAT,有原子弹,能 ...
- 嵩天老师python网课爬虫实例1的问题和解决方法
一,AttributeError: 'NoneType' object has no attribute 'children', 网页'tbody'没有子类 很明显,报错的意思是说tbody下面没有c ...
- 啥叫ORM
名字: object / relation map 对象关系映射 定义: 通过(描述对象和数据库之间映射的)元数据把对象自动转为关系数据 一般都是作为中间件 优缺: 优点是自动化,屏蔽了SQL语句,而 ...
- django cms 5月第一弹
官方文档: ##http://django-cms.readthedocs.io/en/latest/index.html #截图 #生存的项目结构
- 基于swoole+Redis的消息实时推送通知
swoole+Redis将实时数据的推送 一 实现功能 设计师订单如果设计师未抢单,超时(5分钟)设计订单时时给设计师派送, 设计师公众号中收到派单信息 设计发布者收到派单成功信息 环境 centos ...
- Iptables和Firewall-selinux
一.Iptables防火墙 ---------- **三表五链:**三表: filter过滤表 nat转换表 mangle表五链: PREROUTING--->在进行路由选择前处理数据包 INP ...