关于使用SGD时如何选择初始的学习率(这里SGD是指带动量的SGD,momentum=0.9):

训练一个epoch,把学习率从一个较小的值(10-8)上升到一个较大的值(10),画出学习率(取log)和经过平滑后的loss的曲线,根据曲线来选择合适的初始学习率。

从上图可以看出学习率和loss之间的关系,最曲线的最低点的学习率已经有了使loss上升的趋势,曲线的最低点不选。最低点左边的点都是可供选择的点,但是选择太小的学习率会导致收敛的速度过慢,所以根据上图我们可以选择0.01(10-2)为初始的学习率。

关于学习率的调整策略,在使用SGD时不建议使用指数型连续下降的调节方法,建议使用阶梯式调节学习率的方法。每隔一定数量的epoch学习率调节为之前的0.1倍(根据自己实际任务调节每个阶段迭代epoch的数量)。

如果不想使用上述方法,这里提供几个经验值供选择,fine-tune模型初始学习率可设置为0.01,从头开始训练模型学习率可设置为0.1(仅供参考)。

供参考的寻找初始学习率的pytorch代码(根据自己的任务进行修改):

def find_lr(init_value = 1e-8, final_value=10., beta = 0.98):
num = len(train_loader)-1
mult = (final_value / init_value) ** (1/num)
lr = init_value
optimizer.param_groups[0]['lr'] = lr
avg_loss = 0.
best_loss = 0.
batch_num = 0
losses = []
log_lrs = []
for data in train_loader:
batch_num += 1
#As before, get the loss for this mini-batch of inputs/outputs
inputs,labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
#Compute the smoothed loss
avg_loss = beta * avg_loss + (1-beta) *loss.data[0]
smoothed_loss = avg_loss / (1 - beta**batch_num)
#Stop if the loss is exploding
if batch_num > 1 and smoothed_loss > 4 * best_loss:
return log_lrs, losses
#Record the best loss
if smoothed_loss < best_loss or batch_num==1:
best_loss = smoothed_loss
#Store the values
losses.append(smoothed_loss)
log_lrs.append(math.log10(lr))
#Do the SGD step
loss.backward()
optimizer.step()
#Update the lr for the next step
lr *= mult
optimizer.param_groups[0]['lr'] = lr
return log_lrs, losses
参考论文《Cyclical Learning Rates for Training Neural Networks》
和博客https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

sgd学习率选择问题的更多相关文章

  1. Rich feature hierarchies for accurate object detection and semantic segmentation(理解)

    0 - 背景 该论文是2014年CVPR的经典论文,其提出的模型称为R-CNN(Regions with Convolutional Neural Network Features),曾经是物体检测领 ...

  2. R-CNN阅读笔记

    论文地址:<Rich feature hierarchies for accurate object detection and semantic segmentation> 论文包含两个 ...

  3. 转-------基于R-CNN的物体检测

    基于R-CNN的物体检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187029 作者:hjimce 一.相关理论 本篇博文主要讲解2014 ...

  4. 深度学习笔记之基于R-CNN的物体检测

    不多说,直接上干货! 基于R-CNN的物体检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187029 作者:hjimce 一.相关理论 本 ...

  5. 【神经网络与深度学习】【计算机视觉】RCNN- 将CNN引入目标检测的开山之作

    转自:https://zhuanlan.zhihu.com/p/23006190?refer=xiaoleimlnote 前面一直在写传统机器学习.从本篇开始写一写 深度学习的内容. 可能需要一定的神 ...

  6. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  7. TensorFlow入门:线性回归

    随机.mini-batch.batch(见最后解释) 在每个 epoch 送入单个数据点.这被称为随机梯度下降(stochastic gradient descent).我们也可以在每个 epoch ...

  8. 小匠_碣 第三周期打卡 Task06~Task08

    Task06:批量归一化和残差网络:凸优化:梯度下降 批量归一化和残差网络 对输入的标准化(浅层模型) 处理后的任意一个特征在数据集中所有样本上的均值为0.标准差为1. 标准化处理输入数据使各个特征的 ...

  9. 一天搞懂深度学习-训练深度神经网络(DNN)的要点

    前言 这是<一天搞懂深度学习>的第二部分 一.选择合适的损失函数 典型的损失函数有平方误差损失函数和交叉熵损失函数. 交叉熵损失函数: 选择不同的损失函数会有不同的训练效果 二.mini- ...

随机推荐

  1. 解决org.apache.jasper.JasperException: org.apache.jasper.JasperException: XML parsing error on file org.apache.tomcat.util.scan.MergedWebXml

    1.解决办法整个项目建立时采用utf-8编码,包括代码.jsp.配置文件 2.并用最新的tomcat7.0.75 相关链接: http://ask.csdn.net/questions/223650

  2. python语法(四)— 文件操作

    前面几天学习了一写python的基础语法,也学习了分支if,循环while和for.由于之前已经做过几年的开发了,所以我们知道,许多数据来源并不是靠键盘输入到程序中去的,而是通过数据库和文件来获取到的 ...

  3. Codeforces.744B.Hongcow's Game(交互 按位统计)

    题目链接 \(Description\) 一个\(n\times n\)的非负整数矩阵\(A\),保证\(A_{i,i}=0\).现在你要对每个\(i\)求\(\min_{j\neq i}A_{i,j ...

  4. hdu 5203 && BC Round #37 1002

    代码参考自:xyz111 题意: 众所周知,萌萌哒六花不擅长数学,所以勇太给了她一些数学问题做练习,其中有一道是这样的:勇太有一根长度为n的木棍,这个木棍是由n个长度为1的小木棍拼接而成,当然由于时间 ...

  5. 【BZOJ】3751: [NOIP2014]解方程【秦九韶公式】【大整数取模技巧】

    3751: [NOIP2014]解方程 Time Limit: 10 Sec  Memory Limit: 128 MBSubmit: 4856  Solved: 983[Submit][Status ...

  6. 喵哈哈村的魔法考试 Round #4 (Div.2) 题解

    有任何疑问,可以加我QQ:475517977进行讨论. A 喵哈哈村的嘟嘟熊魔法(1) 题解 这道题我们只要倒着来做就可以了,因为交换杯子是可逆的,我们倒着去模拟一遍就好了. 有个函数叫做swap(a ...

  7. Codeforces Round #375 (Div. 2) D. Lakes in Berland 贪心

    D. Lakes in Berland 题目连接: http://codeforces.com/contest/723/problem/D Description The map of Berland ...

  8. px 与 dp, sp换算公式?(转)

    PPI = Pixels per inch,每英寸上的像素数,即 "像素密度" xhdpi: 2.0 hdpi: 1.5 mdpi: 1.0 (baseline) ldpi: 0. ...

  9. socket listen参数中的backlog

    服务器监听时,在每次处理一个客户端的连接时是需要一定时间的,这个时间非常的短(也许只有1ms 或者还不到),但这个时间还是存在的.而这个backlog 存在的意义就是:在这段时间里面除了第一个连接请求 ...

  10. Why I don’t read books

    http://dougmccune.com/blog/2007/03/23/why-i-dont-read-books/ I’ve never read a programming book. I r ...