pytorch 动态调整学习率 重点
深度炼丹如同炖排骨一般,需要先大火全局加热,紧接着中火炖出营养,最后转小火收汁。
本文给出炼丹中的 “火候控制器”-- 学习率的几种调节方法,框架基于 pytorch
1. 自定义根据 epoch 改变学习率。
这种方法在开源代码中常见,此处引用 pytorch 官方实例中的代码 adjust_lr
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = args.lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
注释:在调用此函数时需要输入所用的 optimizer 以及对应的 epoch ,并且 args.lr 作为初始化的学习率也需要给出。
使用代码示例:
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
adjust_learning_rate(optimizer,epoch)
train(...)
validate(...)
2. 针对模型的不同层设置不同的学习率
当我们在使用预训练的模型时,需要对分类层进行单独修改并进行初始化,其他层的参数采用预训练的模型参数进行初始化,这个时候我们希望在进行训练过程中,除分类层以外的层只进行微调,不需要过多改变参数,因此需要设置较小的学习率。而改正后的分类层则需要以较大的步子去收敛,学习率往往要设置大一点以 resnet101 为例,分层设置学习率。
model = torchvision.models.resnet101(pretrained=True)
large_lr_layers = list(map(id,model.fc.parameters()))
small_lr_layers = filter(lambda p:id(p) not in large_lr_layers,model.parameters())
optimizer = torch.optim.SGD([
{"params":large_lr_layers},
{"params":small_lr_layers,"lr":1e-4}
],lr = 1e-2,momenum=0.9)
注:large_lr_layers 学习率为 1e-2,small_lr_layers 学习率为 1e-4,两部分参数共用一个 momenum
3. 根据具体需要改变 lr
以前使用 keras 的时候比较喜欢 ReduceLROnPlateau 可以根据 损失或者 准确度的变化来改变 lr。最近发现 pytorch 也实现了这一个功能。
class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
以 acc 为例,当 mode 设置为 “max” 时,如果 acc 在给定 patience 内没有提升,则以 factor 的倍率降低 lr。
使用方法示例:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'max',verbose=1,patience=3)
for epoch in range(10):
train(...)
val_acc = validate(...)
# 降低学习率需要在给出 val_acc 之后
scheduler.step(val_acc)
4. 手动设置 lr 衰减区间
使用方法示例
def adjust_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
for epoch in range(60):
lr = 30e-5
if epoch > 25:
lr = 15e-5
if epoch > 30:
lr = 7.5e-5
if epoch > 35:
lr = 3e-5
if epoch > 40:
lr = 1e-5
adjust_learning_rate(optimizer, lr)
5. 余弦退火
论文: SGDR: Stochastic Gradient Descent with Warm Restarts
使用方法示例
epochs = 60
optimizer = optim.SGD(model.parameters(),lr = config.lr,momentum=0.9,weight_decay=1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max = (epochs // 9) + 1)
for epoch in range(epochs):
scheduler.step(epoch)
目前最常用的也就这么多了,当然也有很多其他类别,详情见 how-to-adjust-learning-rate
参考文献
pytorch 动态调整学习率 重点的更多相关文章
- pytorch识别CIFAR10:训练ResNet-34(自定义transform,动态调整学习率,准确率提升到94.33%)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 前面通过数据增强,ResNet-34残差网络识别CIFAR10,准确率达到了92.6. 这里对训练过程 ...
- pytorch中调整学习率的lr_scheduler机制
有的时候需要我们通过一定机制来调整学习率,这个时候可以借助于torch.optim.lr_scheduler类来进行调整:一般地有下面两种调整策略:(通过两个例子来展示一下) 两种机制:LambdaL ...
- [pytorch笔记] 调整网络学习率
1. 为网络的不同部分指定不同的学习率 class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__init__() self ...
- 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau
原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...
- pytorch中的学习率调整函数
参考:https://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate torch.optim.lr_scheduler提供 ...
- Pytorch调整学习率
每隔一定的epoch调整学习率 def adjust_learning_rate(optimizer, epoch): """Sets the learning rate ...
- 在 Web 级集群中动态调整 Pod 资源限制
作者阿里云容器平台技术专家 王程阿里云容器平台技术专家 张晓宇(衷源) ## 引子 不知道大家有没有过这样的经历,当我们拥有了一套 Kubernetes 集群,然后开始部署应用的时候,我们应该给容器分 ...
- 动态线程池(DynamicTp)之动态调整Tomcat、Jetty、Undertow线程池参数篇
大家好,这篇文章我们来介绍下动态线程池框架(DynamicTp)的adapter模块,上篇文章也大概介绍过了,该模块主要是用来适配一些第三方组件的线程池管理,让第三方组件内置的线程池也能享受到动态参数 ...
- 如何实现可动态调整隐藏header的listview
(转自:http://blog.sina.com.cn/s/blog_70b9730f01014sgm.html) 需求:根据某种需要,可能需要动态调整listview的页眉页脚,譬如将header作 ...
随机推荐
- NOIP模拟 9.09
AK300分 果实计数 (count.pas/.c/.cpp) 时间限制:1s,空间限制32MB 题目描述: 淘淘家有棵奇怪的苹果树,这棵树共有n+1层,标号为0~n.这棵树第0层只有一个节点,为根节 ...
- scala的插值器
Scala 为我们提供了三种字符串插值的方式,分别是 s, f 和 raw.它们都是定义在 StringContext 中的方法. s 字符串插值器 val a = 2println(s"小 ...
- 运行Jmeter时,响应数据中文乱码问题解决办法
需要修改jmeter中的配置,在Jmeter安装目录/bin/jmeter.properties文件中进行修改: sampleresult.default.encoding默认为ISO-8859-1, ...
- 从0开始学习 GitHub 系列之「06.团队合作利器 Branch」
Git 相比于 SVN 最强大的一个地方就在于「分支」,Git 的分支操作简直不要太方便,而实际项目开发中团队合作最依赖的莫过于分支了,关于分支前面的系列也提到过,但是本篇会详细讲述什么是分支.分支的 ...
- 为什么DW的可视化下看到的效果与浏览器的效果有所区别?
可视区不是调用外面浏览器,Dreamweav 可视化区是为用户编辑而设计. 支持最基本的 HTML 与 CSS ,对 CSS 而言,我写入样式时如果你使用最基本的样式时它显示与你浏览器中看的效果相差不 ...
- Python 正则表达式解析HTML
- JavaScript--返回顶部方法:锚链接、行内式js写法、外链式、内嵌式
返回网页顶部方法 一.锚链接 simpleDemo: <!DOCTYPE html> <html lang="en"> <head> <m ...
- span元素和div元素的浮动效果
首先看一段代码: <style> #right {margin: 10px;float:right;color:red;} #left {float:left;color:blue;} & ...
- AS2.2使用CMake方式进行JNI/NDK开发
之前写过一篇比较水的文章Android手机控制电脑撸出HelloWorld 里面用到了JNI/NDK技术. 这篇文章给大家介绍下JNI/NDK开发.采用的是Android Studio2.2开发环境, ...
- 在Eclipse中添加Tomcat
在Eclipse中开发web或开启web服务需要Tomcat的支持,在添加Tomcat之前要清楚你的Eclipse版本,如果你的Eclipse是javvEE版的就可以直接安装Tomcat,如果不是就需 ...