学习率是深度学习训练中至关重要的参数,很多时候一个合适的学习率才能发挥出模型的较大潜力。所以学习率调整策略同样至关重要,这篇博客介绍一下Pytorch中常见的学习率调整方法。

  1. import torch
  2. import numpy as np
  3. from torch.optim import SGD
  4. from torch.optim import lr_scheduler
  5. from torch.nn.parameter import Parameter
  6. model = [Parameter(torch.randn(2, 2, requires_grad=True))]
  7. optimizer = SGD(model, lr=0.1)

以上是一段通用代码,这里将基础学习率设置为0.1。接下来仅仅展示学习率调节器的代码,以及对应的学习率曲线。

1. StepLR

这是最简单常用的学习率调整方法,每过step_size轮,将此前的学习率乘以gamma。

  1. scheduler=lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

2. MultiStepLR

MultiStepLR同样也是一个非常常见的学习率调整策略,它会在每个milestone时,将此前学习率乘以gamma。

  1. scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)

3. ExponentialLR

ExponentialLR是指数型下降的学习率调节器,每一轮会将学习率乘以gamma,所以这里千万注意gamma不要设置的太小,不然几轮之后学习率就会降到0。

  1. scheduler=lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

4. LinearLR

LinearLR是线性学习率,给定起始factor和最终的factor,LinearLR会在中间阶段做线性插值,比如学习率为0.1,起始factor为1,最终的factor为0.1,那么第0次迭代,学习率将为0.1,最终轮学习率为0.01。下面设置的总轮数total_iters为80,所以超过80时,学习率恒为0.01。

  1. scheduler=lr_scheduler.LinearLR(optimizer,start_factor=1,end_factor=0.1,total_iters=80)

5. CyclicLR

  1. scheduler=lr_scheduler.CyclicLR(optimizer,base_lr=0.1,max_lr=0.2,step_size_up=30,step_size_down=10)

CyclicLR的参数要更多一些,它的曲线看起来就像是不断的上坡与下坡,base_lr为谷底的学习率,max_lr为顶峰的学习率,step_size_up是从谷底到顶峰需要的轮数,step_size_down时从顶峰到谷底的轮数。至于为啥这样设置,可以参见论文,简单来说最佳学习率会在base_lr和max_lr,CyclicLR不是一味衰减而是出现增大的过程是为了避免陷入鞍点。

  1. scheduler=lr_scheduler.CyclicLR(optimizer,base_lr=0.1,max_lr=0.2,step_size_up=30,step_size_down=10)

6. OneCycleLR

OneCycleLR顾名思义就像是CyclicLR的一周期版本,它也有多个参数,max_lr就是最大学习率,pct_start是学习率上升部分所占比例,一开始的学习率为max_lr/div_factor,最终的学习率为max_lr/final_div_factor,总的迭代次数为total_steps。

  1. scheduler=lr_scheduler.OneCycleLR(optimizer,max_lr=0.1,pct_start=0.5,total_steps=120,div_factor=10,final_div_factor=10)

7. CosineAnnealingLR

CosineAnnealingLR是余弦退火学习率,T_max是周期的一半,最大学习率在optimizer中指定,最小学习率为eta_min。这里同样能够帮助逃离鞍点。值得注意的是最大学习率不宜太大,否则loss可能出现和学习率相似周期的上下剧烈波动。

  1. scheduler=lr_scheduler.CosineAnnealingLR(optimizer,T_max=20,eta_min=0.05)

7. CosineAnnealingWarmRestarts

这里相对负责一些,公式如下,其中T_0是第一个周期,会从optimizer中的学习率下降至eta_min,之后的每个周期变成了前一周期乘以T_mult。

\(eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)\)

  1. scheduler=lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=0.01)

8. LambdaLR

LambdaLR其实没有固定的学习率曲线,名字中的lambda指的是可以将学习率自定义为一个有关epoch的lambda函数,比如下面我们定义了一个指数函数,实现了ExponentialLR的功能。

  1. scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda epoch:0.9**epoch)

9.SequentialLR

SequentialLR可以将多个学习率调整策略按照顺序串联起来,在milestone时切换到下一个学习率调整策略。下面就是将一个指数衰减的学习率和线性衰减的学习率结合起来。

  1. scheduler=lr_scheduler.SequentialLR(optimizer,schedulers=[lr_scheduler.ExponentialLR(optimizer, gamma=0.9),lr_scheduler.LinearLR(optimizer,start_factor=1,end_factor=0.1,total_iters=80)],milestones=[50])

10.ChainedScheduler

ChainedScheduler和SequentialLR类似,也是按照顺序调用多个串联起来的学习率调整策略,不同的是ChainedScheduler里面的学习率变化是连续的。

  1. scheduler=lr_scheduler.ChainedScheduler([lr_scheduler.LinearLR(optimizer,start_factor=1,end_factor=0.5,total_iters=10),lr_scheduler.ExponentialLR(optimizer, gamma=0.95)])

11.ConstantLR

ConstantLRConstantLR非常简单,在total_iters轮内将optimizer里面指定的学习率乘以factor,total_iters轮外恢复原学习率。

  1. scheduler=lr_scheduler.ConstantLRConstantLR(optimizer,factor=0.5,total_iters=80)

12.ReduceLROnPlateau

ReduceLROnPlateau参数非常多,其功能是自适应调节学习率,它在step的时候会观察验证集上的loss或者准确率情况,loss当然是越低越好,准确率则是越高越好,所以使用loss作为step的参数时,mode为min,使用准确率作为参数时,mode为max。factor是每次学习率下降的比例,新的学习率等于老的学习率乘以factor。patience是能够容忍的次数,当patience次后,网络性能仍未提升,则会降低学习率。threshold是测量最佳值的阈值,一般只关注相对大的性能提升。min_lr是最小学习率,eps指最小的学习率变化,当新旧学习率差别小于eps时,维持学习率不变。

因为参数相对复杂,这里可以看一份完整的代码实操

  1. scheduler=lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=5,threshold=1e-4,threshold_mode='abs',cooldown=0,min_lr=0.001,eps=1e-8)
  2. scheduler.step(val_score)

史上最全学习率调整策略lr_scheduler的更多相关文章

  1. 优秀后端架构师必会知识:史上最全MySQL大表优化方案总结

    本文原作者“ manong”,原创发表于segmentfault,原文链接:segmentfault.com/a/1190000006158186 1.引言   MySQL作为开源技术的代表作之一,是 ...

  2. 了解iOS消息推送一文就够:史上最全iOS Push技术详解

    本文作者:陈裕发, 腾讯系统测试工程师,由腾讯WeTest整理发表. 1.引言 开发iOS系统中的Push推送,通常有以下3种情况: 1)在线Push:比如QQ.微信等IM界面处于前台时,聊天消息和指 ...

  3. 移动端IM开发者必读(二):史上最全移动弱网络优化方法总结

    1.前言 本文接上篇<移动端IM开发者必读(一):通俗易懂,理解移动网络的“弱”和“慢”>,关于移动网络的主要特性,在上篇中已进行过详细地阐述,本文将针对上篇中提到的特性,结合我们的实践经 ...

  4. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

  5. 史上最全的maven的pom.xml文件详解(转载)

    此文出处:史上最全的maven的pom.xml文件详解——阿豪聊干货 <project xmlns="http://maven.apache.org/POM/4.0.0" x ...

  6. JVM史上最全实践优化没有之一

    JVM史上最全优化没有之一 1.jvm的运行参数 1.1 三种参数类型 1.1.1 -server与-clinet参数 2.1 -X参数 2.1.1 -Xint.-Xcomp.-Xmixed 3.1 ...

  7. Springcloud 配置 | 史上最全,一文全懂

    Springcloud 高并发 配置 (一文全懂) 疯狂创客圈 Java 高并发[ 亿级流量聊天室实战]实战系列之15 [博客园总入口 ] 前言 疯狂创客圈(笔者尼恩创建的高并发研习社群)Spring ...

  8. Tomcat8史上最全优化实践

    Tomcat8史上最全优化实践 1.Tomcat8优化 1.1.Tomcat配置优化 1.1.1.部署安装tomcat8 1.1.2 禁用AJP连接 1.1.3.执行器(线程池) 1.1.4 3种运行 ...

  9. 史上最全的音视频SDK包分享给大家

    史上最全的音视频SDK包分享给大家 概述一下SDK功能: 项目 详情视频通信  支持多种分辨率的视频通信语音通信  提供语音通信,可支持高清宽带语音动态创建房间  可以根据需要,随时创建房间H5 支持 ...

随机推荐

  1. 云平台短信验证码通知短信java/php/.net开发实现

    一.本文目的 大部分平台都有一个接入发送短信验证码.通知短信的需求.虽然市场上大部分平台的接口都只是一个非常普通的HTTP-GET请求,但终归有需要学习和借鉴使用的朋友. 本文的初衷是主要提供学习便利 ...

  2. [linux tips] puppet client ssl 证书过期

    问题: [root@control-01 .ssh]# puppet agent -tv Warning: Unable to fetch my node definition, but the ag ...

  3. 运行npm install命令的时候会发生什么?

    摘要:我们日常在下载第三方依赖的时候,都会用到一个命令npm install,那么你知道,在运行这个命令的时候都会发生什么吗? 本文分享自华为云社区<运行npm install命令的时候会发生什 ...

  4. IDEA打包javaFX及踩坑解决

    开门见山的说,先打包,再说坑. File-->Project Structure --> Artifacts-->(此处点加号)JAR-->From modules with ...

  5. 【学习笔记】CDQ分治(等待填坑)

    因为我对CDQ分治理解不深,所以这篇博客只是我现在的浅显理解有任何不对的,希望大佬指出. 首先就是CDQ分治适用的题型: (1)带修改,但修改互相独立 (2)必须允许离线 (3)解决数据结构的题,能把 ...

  6. 树莓派开发笔记(十三):入手研华ADVANTECH工控树莓派UNO-220套件(二):安装rtc等驱动

    前言   前面运行了系统,本篇是安装对应套装的驱动,使rtc等外设生效,树莓派本身是不带rtc外设的.   UNO-220-P4N1AE 驱动下载     官方下载:https://www.advan ...

  7. day02 真正的高并发还得看IO多路复用

    教程说明 C++高性能网络服务保姆级教程 首发地址 day02 真正的高并发还得看IO多路复用 本节目的 使用epoll实现一个高并发的服务器 从单进程讲起 上节从一个基础的socket服务说起我们实 ...

  8. Json序列化与反序列化导致多线程运行速度和单线程运行速度一致问题

    紧跟上篇文章 十个进程开启十个bash后一致写入命令执行完毕之后产生了很多很多的文件,博主需要对这些文件同意处理,也就是说对几十万个文件进行处理,想了又想,单线程处理那么多数据肯定不行,于是乎想到了使 ...

  9. 1903021121—刘明伟—Java第四周作业—java分支语句学习

    项目 内容 课程班级博客链接 19信计班(本) 作业要求链接 第四周作业 要求 每道题要有题目,代码(使用插入代码,不会插入代码的自己查资料解决,不要直接截图代码!!),截图(只截运行结果). 扩展阅 ...

  10. mysql5.7介绍和安装

    环境准备: 1.关闭防火墙和selinux systemctl stop firewalldsystemctl stop SElinux 2. 如果安装过mariadb需要停止且卸载服务 system ...