exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None)

使用方式为

tf.train.exponential_decay( )

在 Tensorflow 中,exponential_decay()是应用于学习率的指数衰减函数(实现指数衰减学习率)。

在训练模型时,通常建议随着训练的进行逐步降低学习率。该函数需要`global_step`值来计算衰减的学习速率。

该函数返回衰减后的学习率。该函数的计算方程式如下

参数:

  • learning_rate - 初始学习率
  • global_step - 用于衰减计算的全局步骤。 一定不为负数。喂入一次 BACTH_SIZE 计为一次 global_step
  • decay_steps - 衰减速度,一定不能为负数,每间隔decay_steps次更新一次learning_rate值
  • decay_rate - 衰减系数,衰减速率,其具体意义参看函数计算方程(对应α^t中的α)。
  • staircase - 若 ‘ True ’ ,则学习率衰减呈 ‘ 离散间隔 ’ (discrete intervals),具体地讲,`global_step / decay_steps`是整数除法,衰减学习率( the decayed learning rate )遵循阶梯函数;若为 ’ False ‘ ,则更新学习率的值是一个连续的过程,每步都会更新学习率。

返回值:

  • 与初始学习率 ‘ learning_rate ’ 相同的标量 ’ Tensor ‘ 。

优点:

  • 训练伊始可以使用较大学习率,以快速得到比较优的解。
  • 后期通过逐步衰减后的学习率进行迭代训练,以使模型在训练后期更加稳定。

示例:

import tensorflow as tf
import matplotlib.pyplot as plt

learning_rate = 0.1
decay_rate = 0.96
global_steps = 1000
decay_steps = 100

global_step = tf.Variable(0, trainable = Fasle)
c = tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=True)
d = tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False)

T_C = []
F_D = []

with tf.Session() as sess:
    for i in range(global_steps):
        T_c = sess.run(c, feed_dict={global_step: i})
        T_C.append(T_c)
        F_d = sess.run(d, feed_dict={global_step: i})
        F_D.append(F_d)

plt.figure(1)
plt.plot(range(global_steps), F_D, 'r-')
plt.plot(range(global_steps), T_C, 'b-')

plt.show()

运行

备注:

(1)

台阶形状的蓝色线是 staircase = True

线条形状的红色线是 staircase = Fasle

(2)

初始学习率 learning_rate 为0.1,总训练次数 global_setps 为 1000 次;staircase=True时,每隔 decay_steps = 100 次更新一次 学习率 learning_rate,而staircase=True时,每一步均会更新一次学习率 learning_rate ,

(3)

训练过程中,decay_rate的数值保持步不变。

TensorFlow 中的 tf.train.exponential_decay() 指数衰减法的更多相关文章

  1. tensorflow之tf.train.exponential_decay()指数衰减法

    exponential_decay(learning_rate,  global_steps, decay_steps, decay_rate, staircase=False, name=None) ...

  2. Tensorflow中的tf.argmax()函数

    转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html 官方API定义 tf.argmax(input, axis=None, name=None ...

  3. tensorflow中使用tf.variable_scope和tf.get_variable的ValueError

    ValueError: Variable conv1/weights1 already exists, disallowed. Did you mean to set reuse=True in Va ...

  4. [转载]tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定

    tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...

  5. tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope

    tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享. tf通过 tf.get_variable() 可以建立或者获取一个共享的变量. tf.get ...

  6. tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定

    tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...

  7. tf.Session()函数的参数应用(tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/dcrmg/article/details ...

  8. tensorflow中的tf.app.run()的使用

    指明函数的入口,即从哪里执行函数. 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如test(),则你应该这样写入口tf.app.run(test()) 如果你的代码中的入口函数叫 ...

  9. tensorflow中协调器 tf.train.Coordinator 和入队线程启动器 tf.train.start_queue_runners

    TensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行.在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止 ...

随机推荐

  1. 推荐一款jQuery ColorPicked 颜色拾取器插件

    先看实现的效果图, 本文底部有完整demo 不想看我墨迹的可以跳过了^_^. 官网地址:http://www.eyecon.ro/colorpicker/#about 代码SVN 地址:https:/ ...

  2. iOS中文API之UITouch详解

    UITouch 对象用于位置. 大小. 运动和一根手指在屏幕上为某一特定事件的力度.触摸的力度是从开始在 iOS 9 支持 3D 的触摸的设备上可用.你可以通过UIEvent对象传递给响应者对象访问. ...

  3. ImportError DLL load failed: %1 不是有效的 Win32 应用程序

    操作系统:win7 64位,安装mysqldb 后提示:ImportError DLL load failed: %1 不是有效的 Win32 应用程序,是由于安装的32位的 MySQL-Python ...

  4. 比特币 Bitcoin 是什么,我勒个去,哈耶克果然超前——货币的非国有化,容我思量一下【转载+整理】

    原文地址 比特币矿业史(上):故事的开始,CPU 时代 比特币矿业史(中):群众的觉醒 ,GPU 时代 比特币矿业史(下):巨头的诞生 ,ASIC 时代 本文内容 引子 0 序 1 故事的开始 : C ...

  5. LintCode: Count and Say

    C++ class Solution { public: /** * @param n the nth * @return the nth sequence */ string countAndSay ...

  6. excel自定义数据验证

    1. 判断必须为5位或者9位的数字 2. 自定义限制级别和提示消息

  7. MySQL auto_increment初始值设置

    http://blog.csdn.net/u011439289/article/details/48055917 DROP TABLE IF EXISTS zan1; CREATE TABLE zan ...

  8. java.lang.ClassNotFoundException: SparkPi$$anonfun$1

    出现这个错误可能有两种情况,Jar文件没有传上去,或者Build Path里面包含的Jar文件和Spark的运行环境有冲突. 对于第一种情况,需要在SparkConf语句后面加上Jar文件的路径: v ...

  9. @Service注解的使用

    首先,在applicationContext.xml文件中加一行: <context:component-scan base-package="com.hzhi.clas"/ ...

  10. ScrollView嵌套EditText联带滑动的解决的方法

    本篇文章的相关内容需结合上文:从ScrollView嵌套EditText的滑动事件冲突分析触摸事件的分发机制以及TextView的简要实现和冲突的解决的方法 在说完了怎样解决ScrollView嵌套E ...