1. lr_scheduler综述

torch.optim.lr_scheduler模块提供了一些根据epoch训练次数来调整学习率(learning rate)的方法。一般情况下我们会设置随着epoch的增大而逐渐减小学习率从而达到更好的训练效果。

学习率的调整应该放在optimizer更新之后,下面是一个参考:

from torch.optim.lr_scheduler import LinearLR
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = LinearLR(optimizer=optimizer,start_factor=1.0/3,end_factor=1.0,total_iters=15) for epoch in range(20):
  for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
  scheduler.step()

注意: 在PyTorch 1.1.0之前的版本,学习率的调整应该被放在optimizer更新之前的。如果我们在 1.1.0 及之后的版本仍然将学习率的调整(即 scheduler.step())放在 optimizer’s update(即 optimizer.step())之前,那么 learning rate schedule 的第一个值将会被跳过。所以如果某个代码是在 1.1.0 之前的版本下开发,但现在移植到 1.1.0及之后的版本运行,发现效果变差,需要检查一下是否将scheduler.step()放在了optimizer.step()之前。

2. optimizer综述

torch.optim 是一个实现各种优化算法的包。

要使用torch.optim,必须构造一个optimizer对象,该对象将保持当前状态,并将根据计算出的梯度更新参数。

要构造一个optimizer,我们必须给它一个包含要优化的参数(所有参数应该是Variable类型)的迭代器。然后,可以指定特定于optimizer的选项,如学习率、权重衰减等。

例如:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)

optimizer也支持指定每个参数的选项。要做到这一点,不是传递一个Variable的迭代,而是传递一个可迭代的dict。每个变量都将定义一个单独的parameter group,并且应该包含一个 params 键,包含一个属于它的参数列表。其他键应该与优化器接受的关键字参数匹配,并将用作该组的优化选项。

我们仍然可以将选项作为关键字参数传递。它们将作为默认值,在没有覆盖它们的组中使用。如果只希望改变单个选项,同时保持参数组之间所有其他选项的一致性时,这种方法非常有用。

例如,想要指定每层的学习速度时:

optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

则:model.base 的参数将使用默认的学习率1e-2,model.classifier 的参数将使用学习率1e-3,并且所有参数的momentum为0.9。

所有的优化器都实现了一个 step()方法来更新参数:optimizer.step()。当使用如backward()等方法计算出梯度后,就可以调用step()更新参数。

例如:

for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step() torch.optim.optimizer(params, defaults)

参数:

  • params(iterable) - 可迭代的torch.Tensor或dict,用来指定需要优化的张量。
  • defaults(dict) - dict,包含优化选项的默认值(当参数组没有指定它们时生效)。

方法:

  • Optimizer.add_param_group - 添加一个参数组到优化器的参数组
  • Optimizer.load_state_dict - 加载优化器状态
  • Optimizer.state_dict - 以字典形式返回优化器的状态
  • Optimizer.step - 执行单个优化步骤(参数更新)
  • Optimizer.zero_grad - 所有需优化张量的梯度清零

3. linearLR综述

  • 在epoch数达到total_iters数值之前,使用线性改变乘法因子衰减学习率。
  • 计算公式和pytorch计算代码如下:

def _get_closed_form_lr(self):
return [base_lr * (self.start_factor +
(self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
for base_lr in self.base_lrs]
  • pytorch调用及相关参数:
torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.3333333333333333, end_factor=1.0, total_iters=5, last_epoch=- 1, verbose=False)
 

参数

  • optimizer(Optimizer) -包装优化器。

  • start_factor(float) -我们在第一个 epoch 中乘以学习率(base_lr)的数字。在接下来的 epoch 中,乘法因子向 end_factor 变化。默认值:1./3。

  • end_factor(float) -我们在线性变化过程结束时乘以学习率(base_lr)的数字。默认值:1.0。

  • total_iters(int) -乘法因子达到1的迭代次数。默认值:5。

  • last_epoch(int) -最后一个纪元的索引。默认值:-1。

  • verbose(bool) -如果 True ,每次更新都会向标准输出打印一条消息。默认值:False

 

通过线性改变小的乘法因子来衰减每个参数组的学习率,直到 epoch 的数量达到预定义的里程碑:total_iters。

请注意,这种衰减可能与此调度程序外部对学习率的其他更改同时发生。当last_epoch=-1 时,设置初始 lr 为 lr。

  • 举例说明:
lr_scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
base_lr=0.05
# epoch == 0→lr = base_lr * start_factor = 0.05 * 0.5=0.025;
# epoch == 1→lr = 0.05 * (0.5 + 0.5 * 0.25) = 0.3125;
......
# epoch ≥ 4→lr = base_lr * end_factor = 0.05(当epoch数等于total_iters时,min(self.total_iters, self.last_epoch) / self.total_iters = 1)

关于学习率-----linearLR的更多相关文章

  1. 深度学习训练过程中的学习率衰减策略及pytorch实现

    学习率是深度学习中的一个重要超参数,选择合适的学习率能够帮助模型更好地收敛. 本文主要介绍深度学习训练过程中的6种学习率衰减策略以及相应的Pytorch实现. 1. StepLR 按固定的训练epoc ...

  2. 史上最全学习率调整策略lr_scheduler

    学习率是深度学习训练中至关重要的参数,很多时候一个合适的学习率才能发挥出模型的较大潜力.所以学习率调整策略同样至关重要,这篇博客介绍一下Pytorch中常见的学习率调整方法. import torch ...

  3. 自适应学习率调整:AdaDelta

    Reference:ADADELTA: An Adaptive Learning Rate Method 超参数 超参数(Hyper-Parameter)是困扰神经网络训练的问题之一,因为这些参数不可 ...

  4. 使用CNN(convolutional neural nets)关键的一点是检测到的面部教程(四):学习率,学习潜能,dropout

    第七部分 让 学习率 和 学习潜能 随时间的变化 光训练就花了一个小时的时间.等结果并非一个令人心情愉快的事情.这一部分.我们将讨论将两个技巧结合让网络训练的更快! 直觉上的解决的方法是,開始训练时取 ...

  5. 机器学习:Python实现lms中的学习率的退火算法

    ''' 算法:lms学习率的退火算法 解决的问题:学习率不变化,收敛速度较慢的情况 思路:由初始解和控制参数初值开始,对当前解重复进行"产生新解-->计算目标函数差--> 接受或 ...

  6. 学习率 Learning Rate

    本文从梯度学习算法的角度中看学习率对于学习算法性能的影响,以及介绍如何调整学习率的一般经验和技巧. 在机器学习中,监督式学习(Supervised Learning)通过定义一个模型,并根据训练集上的 ...

  7. Python ---------- Tensorflow (二)学习率

    假设最小化函数 y = x2 , 选择初始点 x0= 5 1. 学习率为1的时候,x在5和-5之间震荡. #学习率为1 import tensorflow as tf training_steps = ...

  8. 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)

    无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...

  9. TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)

    在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...

  10. pytorch识别CIFAR10:训练ResNet-34(自定义transform,动态调整学习率,准确率提升到94.33%)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 前面通过数据增强,ResNet-34残差网络识别CIFAR10,准确率达到了92.6. 这里对训练过程 ...

随机推荐

  1. jquery的radio的change事件

    一.用的jquery的radio的change事件:当元素的值发生改变时,会发生 change 事件,radio选择不同name值选项的时候恰巧是值发生改变 表单单选框 <input type= ...

  2. Netty源码—7.ByteBuf原理二

    大纲 9.Netty的内存规格 10.缓存数据结构 11.命中缓存的分配流程 12.Netty里有关内存分配的重要概念 13.Page级别的内存分配 14.SubPage级别的内存分配 15.Byte ...

  3. Netty源码—9.性能优化和设计模式

    大纲 1.Netty的两大性能优化工具 2.FastThreadLocal的实现之构造方法 3.FastThreadLocal的实现之get()方法 4.FastThreadLocal的实现之set( ...

  4. spring项目使用EMQX,使用@Autowired注入失败报错空指针问题记录

    目录 java客户端使用MQTT订阅消息大致流程 MQTTConnect部分代码 MQTTListener部分代码 问题分析 问题原因 解决方法 总结 参考 java客户端使用MQTT订阅消息大致流程 ...

  5. Spring 的 resolveBeforeInstantiation 方法作用详解

    一.定义 resolveBeforeInstantiation 是 Spring 框架中 AbstractAutowireCapableBeanFactory 类的核心方法之一,它在 Bean 的实例 ...

  6. 深入理解Java虚拟机-JAVA内存模型与线程

    Java内存模型(JMM) JMM 的核心概念 主内存与工作内存: 主内存(Main Memory)是所有线程共享的内存区域,存放着所有变量的值 每个线程都有自己的 工作内存(Working Memo ...

  7. RMQ学习笔记

    RMQ学习笔记 前言:这个算法无论是从适配性还是长度来说都很有实力... 关于 RMQ RMQ 是英文 Range Maximum/Minimum Query 的缩写,表示区间最大(最小)值. 详细信 ...

  8. DIY钢铁侠方舟反应堆第二期—第一代电路板展示

    经历一个周的时间,终于把方舟反应堆的电路画了出来,简单画了一个USB口加LED灯的电路,先简单测试一下 原理图展示 PCB展示 实物如下 这里出了一点意外,LED被发错了,本来计划的是蓝灯,但是发来的 ...

  9. hadoop部署安装(六)hive

    5.配置hive 5.1 hive下载地址 http://mirror.bit.edu.cn/apache/hive/ 解压缩 [root@master ~]# tar xf apache-hive- ...

  10. 网鼎杯-phpweb

    找了一些php读取文件的函数尝试读取源码,试了一个readfile就成功了 <?php $disable_fun = array("exec","shell_exe ...