exponential_decay(learning_rate,  global_steps, decay_steps, decay_rate, staircase=False, name=None)

使用方式:

tf.tf.train.exponential_decay()

例子:

tf.train.exponential_decay(self.config.e_lr, self.e_global_steps,self.config.decay_steps, self.config.decay_rate, staircase=True)

在 Tensorflow 中,exponential_decay()是应用于学习率的指数衰减函数(实现指数衰减学习率)。

在训练模型时,通常建议随着训练的进行逐步降低学习率。该函数需要`global_step`值来计算衰减的学习速率。

该函数返回衰减后的学习率。该函数的计算方程式如下

参数:

  • learning_rate - 初始学习率
  • global_step - 用于衰减计算的全局步骤。 一定不为负数。喂入一次 BACTH_SIZE 计为一次 global_step
  • decay_steps - 衰减速度,一定不能为负数,每间隔decay_steps次更新一次learning_rate值
  • decay_rate - 衰减系数,衰减速率,其具体意义参看函数计算方程(对应α^t中的α)。
  • staircase - 若 ‘ True ’ ,则学习率衰减呈 ‘ 离散间隔 ’ (discrete intervals),具体地讲,`global_step / decay_steps`是整数除法,衰减学习率( the decayed learning rate )遵循阶梯函数;若为 ’ False ‘ ,则更新学习率的值是一个连续的过程,每步都会更新学习率。

返回值:

  • 与初始学习率 ‘ learning_rate ’ 相同的标量 ’ Tensor ‘ 。

优点:

  • 训练伊始可以使用较大学习率,以快速得到比较优的解。
  • 后期通过逐步衰减后的学习率进行迭代训练,以使模型在训练后期更加稳定。

示例代码:

import tensorflow as tf
import matplotlib.pyplot as plt

learning_rate = 0.1
decay_rate = 0.96
global_steps = 1000
decay_steps = 100

global_step = tf.Variable(0, trainable = Fasle)
c = tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=True)
d = tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False)

T_C = []
F_D = []

with tf.Session() as sess:
for i in range(global_steps):
T_c = sess.run(c, feed_dict={global_step: i})
T_C.append(T_c)
F_d = sess.run(d, feed_dict={global_step: i})
F_D.append(F_d)

plt.figure(1)
plt.plot(range(global_steps), F_D, 'r-')
plt.plot(range(global_steps), T_C, 'b-')

plt.show()

实操:

运行结果:

备注:

(1)

台阶形状的蓝色线是 staircase = True

线条形状的红色线是 staircase = Fasle

(2)

初始学习率 learning_rate 为0.1,总训练次数 global_setps 为 1000 次;staircase=True时,每隔 decay_steps = 100 次更新一次 学习率 learning_rate,而staircase=True时,每一步均会更新一次学习率 learning_rate ,

(3)

训练过程中,decay_rate的数值保持步不变。

参考文献:https://www.cnblogs.com/gengyi/p/9898960.html

tensorflow之tf.train.exponential_decay()指数衰减法的更多相关文章

  1. TensorFlow 中的 tf.train.exponential_decay() 指数衰减法

    exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None) 使 ...

  2. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  3. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  4. tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...

  5. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  6. 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数

    原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...

  7. [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model

    在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...

  8. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  9. tensorflow API _ 3 (tf.train.polynomial_decay)

    学习率的三种调整方式:固定的,指数的,多项式的 def _configure_learning_rate(num_samples_per_epoch, global_step): "&quo ...

随机推荐

  1. @atcoder - AGC035F@ Two Histograms

    目录 @description@ @solution@ @accepted code@ @details@ @description@ 给定一个 N*M 的方格,我们通过以下步骤往里面填数: (1)将 ...

  2. BKDRhash

     哈希: 字符串(数字同理): 例如有100000个字符串,现在要插入一些字符串,插入前比较是否已经存在避免含有重复数据 用暴力计较的话会比较慢,在某字符串插入时,最好的情况是在第一个位置就遇见该字符 ...

  3. oracle函数 INITCAP(c1)

    [功能]返回字符串并将字符串的第一个字母变为大写,其它字母小写; [参数]c1字符型表达式 [返回]字符型 [示例] SQL> select initcap('smith abc aBC') u ...

  4. HZOJ 方程的解

    乍一看还以为是道水题,没想到这玩意这么难搞. 看题显然是exgcd,然而exgcd求的是一个解而不是解的个数(考试的时候不记得通解的式子然后挂了). 对于40%的数据,直接枚举计数即可. 对于另为20 ...

  5. hdu 4347 The Closest M Points(KD树)

    Problem - 4347 一道KNN的题.直接用kd树加上一个暴力更新就撸过去了.写的时候有一个错误就是搜索一边子树的时候返回有当前层数会被改变了,然后就直接判断搜索另一边子树,搞到wa了半天. ...

  6. VS Code导入已存在的Vue.js工程

    打开vscode------->文件--------->打开文件夹--------->选择工程文件夹-------->确定 查看---->终端或者使用"Ctrl ...

  7. Java容易搞错的知识点

    一.关于Switch 代码: Java代码 1         public class TestSwitch { 2             public static void main(Stri ...

  8. hdu 5734 Acperience(2016多校第二场)

    Acperience Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)Total ...

  9. HTML让文字在图片上显示的几种方法

    第一种方式是image 作为背景图片,即:background:url("......."); 第二种方式是将img块与文字块(文字块采用span标签显示)放在同一个div 中,然 ...

  10. 如何让索引只能被一个SQL使用

    有个徒弟问我,要创建一个索引,去优化一个SQL,但是创建了索引之后其他 SQL 也要用 这个索引,其他SQL慢死了,要优化的SQL又快.遇到这种问题咋搞? 一般遇到这种问题还是很少的.处理的方法很多. ...