tensorflow中的学习率调整策略
通常为了模型能更好的收敛,随着训练的进行,希望能够减小学习率,以使得模型能够更好地收敛,找到loss最低的那个点.
tensorflow中提供了多种学习率的调整方式.在https://www.tensorflow.org/api_docs/python/tf/compat/v1/train搜索decay.可以看到有多种学习率的衰减策略.
- cosine_decay
- exponential_decay
- inverse_time_decay
- linear_cosine_decay
- natural_exp_decay
- noisy_linear_cosine_decay
- polynomial_decay
本文介绍两种学习率衰减策略,指数衰减和多项式衰减.
tf.compat.v1.train.exponential_decay(
learning_rate,
global_step,
decay_steps,
decay_rate,
staircase=False,
name=None
)
learning_rate 初始学习率
global_step 当前总共训练多少个迭代
decay_steps 每xxx steps后变更一次学习率
decay_rate 用以计算变更后的学习率
staircase: global_step/decay_steps的结果是float型还是向下取整
学习率的计算公式为:decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
我们用一段测试代码来绘制一下学习率的变化情况.
#coding=utf-8
import matplotlib.pyplot as plt
import tensorflow as tf
x=[]
y=[]
N = 200 #总共训练200个迭代
num_epoch = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for num_epoch in range(N):
##初始学习率0.5,每10个迭代更新一次学习率.
learing_rate_decay = tf.train.exponential_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, decay_rate=0.9, staircase=False)
learning_rate = sess.run([learing_rate_decay])
y.append(learning_rate)
#print(y)
x = range(N)
fig = plt.figure()
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.plot(x, y, 'r', linewidth=2)
plt.show()
结果如图:
- 多项式衰减
tf.compat.v1.train.polynomial_decay(
learning_rate,
global_step,
decay_steps,
end_learning_rate=0.0001,
power=1.0,
cycle=False,
name=None
)
设定一个初始学习率,一个终止学习率,然后线性衰减.cycle控制衰减到end_learning_rate后是否保持这个最小学习率不变,还是循环往复. 过小的学习率会导致收敛到局部最优解,循环往复可以一定程度上避免这个问题.
根据cycle是否为true,其计算方式不同,如下:
#coding=utf-8
import matplotlib.pyplot as plt
import tensorflow as tf
x=[]
y=[]
z=[]
N = 200 #总共训练200个迭代
num_epoch = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for num_epoch in range(N):
##初始学习率0.5,每10个迭代更新一次学习率.
learing_rate_decay = tf.train.polynomial_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, end_learning_rate=0.0001, cycle=False)
learning_rate = sess.run([learing_rate_decay])
y.append(learning_rate)
learing_rate_decay2 = tf.train.polynomial_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, end_learning_rate=0.0001, cycle=True)
learning_rate2 = sess.run([learing_rate_decay2])
z.append(learning_rate2)
#print(y)
x = range(N)
fig = plt.figure()
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.plot(x, y, 'r', linewidth=2)
plt.plot(x, z, 'g', linewidth=2)
plt.show()
绘图结果如下:
cycle为false时对应红线,学习率下降到0.0001后不再下降. cycle=true时,下降到0.0001后再突变到一个更大的值,在继续衰减,循环往复.
在代码里,通常通过参数去控制不同的学习率策略,例如
def _configure_learning_rate(num_samples_per_epoch, global_step):
"""Configures the learning rate.
Args:
num_samples_per_epoch: The number of samples in each epoch of training.
global_step: The global_step tensor.
Returns:
A `Tensor` representing the learning rate.
Raises:
ValueError: if
"""
# Note: when num_clones is > 1, this will actually have each clone to go
# over each epoch FLAGS.num_epochs_per_decay times. This is different
# behavior from sync replicas and is expected to produce different results.
decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
FLAGS.batch_size)
if FLAGS.sync_replicas:
decay_steps /= FLAGS.replicas_to_aggregate
if FLAGS.learning_rate_decay_type == 'exponential':
return tf.train.exponential_decay(FLAGS.learning_rate,
global_step,
decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True,
name='exponential_decay_learning_rate')
elif FLAGS.learning_rate_decay_type == 'fixed':
return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
elif FLAGS.learning_rate_decay_type == 'polynomial':
return tf.train.polynomial_decay(FLAGS.learning_rate,
global_step,
decay_steps,
FLAGS.end_learning_rate,
power=1.0,
cycle=False,
name='polynomial_decay_learning_rate')
else:
raise ValueError('learning_rate_decay_type [%s] was not recognized' %
FLAGS.learning_rate_decay_type)
推荐一篇:https://blog.csdn.net/dcrmg/article/details/80017200 对各种学习率衰减策略描述的很详细.并且都有配图,可以很直观地看到各种衰减策略下学习率变换情况.
tensorflow中的学习率调整策略的更多相关文章
- tensorflow中常用学习率更新策略
神经网络训练过程中,根据每batch训练数据前向传播的结果,计算损失函数,再由损失函数根据梯度下降法更新每一个网络参数,在参数更新过程中使用到一个学习率(learning rate),用来定义每次参数 ...
- 【转载】 PyTorch学习之六个学习率调整策略
原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...
- 深度学习训练过程中的学习率衰减策略及pytorch实现
学习率是深度学习中的一个重要超参数,选择合适的学习率能够帮助模型更好地收敛. 本文主要介绍深度学习训练过程中的6种学习率衰减策略以及相应的Pytorch实现. 1. StepLR 按固定的训练epoc ...
- 史上最全学习率调整策略lr_scheduler
学习率是深度学习训练中至关重要的参数,很多时候一个合适的学习率才能发挥出模型的较大潜力.所以学习率调整策略同样至关重要,这篇博客介绍一下Pytorch中常见的学习率调整方法. import torch ...
- 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau
原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...
- PyTorch学习之六个学习率调整策略
PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现.PyTorch提供的学习率调整策略分为三大类,分别是 有序调整:等间隔调整(Step),按需调整学习率(Mul ...
- TensorFlow中设置学习率的方式
目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...
- pytorch中的学习率调整函数
参考:https://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate torch.optim.lr_scheduler提供 ...
- 深度学习---1cycle策略:实践中的学习率设定应该是先增再降
深度学习---1cycle策略:实践中的学习率设定应该是先增再降 本文转载自机器之心Pro,以作为该段时间的学习记录 深度模型中的学习率及其相关参数是最重要也是最难控制的超参数,本文将介绍 Lesli ...
随机推荐
- https协议分析
一:什么是HTTPS https全称是超文本传输安全协议,https利用SSL/TLS加密数据包来进行http通信.https开发的主要目的,是提供对网站服务器的身份认证,保护交换数据的隐私与完整性. ...
- [LUOGU1122] 最大子树和 - 树形动规
题目描述 小明对数学饱有兴趣,并且是个勤奋好学的学生,总是在课后留在教室向老师请教一些问题.一天他早晨骑车去上课,路上见到一个老伯正在修剪花花草草,顿时想到了一个有关修剪花卉的问题.于是当日课后,小明 ...
- OSI七层模型和五层TCP/IP协议
1.查公网ip的方法: windows,打开浏览器,访问百度,搜IP即可 linux:curl ifconfig.me 2.OSI七层模型 ==网络工程师:== 物理层 1层,通信介质的信号到数字信号 ...
- 玩转OneNET物联网平台之MQTT服务⑦ —— 远程控制LED(数量无限制)+ Android App控制 优化第一版
授人以鱼不如授人以渔,目的不是为了教会你具体项目开发,而是学会学习的能力.希望大家分享给你周边需要的朋友或者同学,说不定大神成长之路有博哥的奠基石... QQ技术互动交流群:ESP8266&3 ...
- NetworkManager网络通讯_Example(一)
---恢复内容开始--- 用户手册,范例精讲. 用户手册上给出了一个简单的范例,并指出可以以此为基础进行相开发,再次对范例进行精讲.(NetworkManager对使用unity的轻量级游戏开发有很大 ...
- 网络数据请求request
关于网络数据请求的类很多,httpwebrequest,webrequest,webclient以及httpclient,具体差别在此不在赘述,在应用方面介绍webclient与httpclient则 ...
- 如何判断float值有效
// 一个浮点数是否有效,首先要看其是否是一个数字(_isnan为0),其次还要看其是否超出了表示范围(_finite为0) // 注意_finite是有限的意思 #include <float ...
- github 下载子目录内容 亲测可用!
下载我的LYBTouchID项目的Kit目录内容 (1)在github上点开这个目录,浏览器地址栏可以得到这个地址 https://github.com/Liuyubao/LYBTouchID/tre ...
- 关于RocketMQ消息消费与重平衡的一些问题探讨
其实最好的学习方式就是互相交流,最近也有跟网友讨论了一些关于 RocketMQ 消息拉取与重平衡的问题,我姑且在这里写下我的一些总结. ## 关于 push 模式下的消息循环拉取问题 之前发表了一篇关 ...
- 机器学习笔记(一)· 感知机算法 · 原理篇
这篇学习笔记强调几何直觉,同时也注重感知机算法内部的动机.限于篇幅,这里仅仅讨论了感知机的一般情形.损失函数的引入.工作原理.关于感知机的对偶形式和核感知机,会专门写另外一篇文章.关于感知机的实现代码 ...