在TensorFlow的优化器中, 都要设置学习率。学习率是在精度和速度之间找到一个平衡:

学习率太大,训练的速度会有提升,但是结果的精度不够,而且还可能导致不能收敛出现震荡的情况。

学习率太小,精度会有所提升,但是训练的速度慢,耗费较多的时间。

因而我们可以使用退化学习率,又称为衰减学习率。它的作用是在训练的过程中,对学习率的值进行衰减,训练到达一定程度后,使用小的学习率来提高精度。

在TensorFlow中的方法如下:tf.train.exponential_decay(),该方法的参数如下:

learning_rate, 初始的学习率的值

global_step, 迭代步数变量

decay_steps, 带迭代多少次进行衰减

decay_rate, 迭代decay_steps次衰减的值

staircase=False, 默认为False,为True则不衰减

例如

tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=1000, decay_rate=0.9)表示没经过1000次的迭代,学习率变为原来的0.9。

增大批次处理样本的数量也可以起到退化学习率的作用。

下面我们写了一个例子,每迭代10次,则较小为原来的0.5,代码如下:

import tensorflow as tf
import numpy as np global_step = tf.Variable(0, trainable=False)
initial_learning_rate = 0.1 learning_rate = tf.train.exponential_decay(initial_learning_rate,
global_step=global_step,
decay_steps=10,
decay_rate=0.5) opt = tf.train.GradientDescentOptimizer(learning_rate)
add_global = global_step.assign_add(1) with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(learning_rate)) for i in range(50):
g, rate = sess.run([add_global, learning_rate])
print(g, rate)

下面是程序的结果,我们发现没10次就变为原来的一般:

随后,又在MNIST上面进行了测试,发现使用学习率衰减使得准确率有较好的提升。代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) global_step = tf.Variable(0, trainable=False)
initial_learning_rate = 0.1 learning_rate = tf.train.exponential_decay(initial_learning_rate,
global_step=global_step,
decay_steps=1000,
decay_rate=0.9) opt = tf.train.GradientDescentOptimizer(learning_rate)
add_global = global_step.assign_add(1) optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 50
batch_size = 100 display_step = 1 with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c, add, rate = sess.run([optimizer, cost, add_global, learning_rate], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost, 'add_global=', add, 'rate=', rate)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))

在使用衰减学习率我们最后的精度达到0.8897,在使用固定的学习率时,精度只有0.8586。

TensorFlow——学习率衰减的使用方法的更多相关文章

  1. TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)

    在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...

  2. Tensorflow实现学习率衰减

    Tensorflow实现学习率衰减 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Deeplearning AI Andrew Ng Tensorflow1.2 API 学习率衰减 ...

  3. Adam和学习率衰减(learning learning decay)

    目录 梯度下降法更新参数 Adam 更新参数 Adam + 学习率衰减 Adam 衰减的学习率 References 本文先介绍一般的梯度下降法是如何更新参数的,然后介绍 Adam 如何更新参数,以及 ...

  4. 改善深层神经网络_优化算法_mini-batch梯度下降、指数加权平均、动量梯度下降、RMSprop、Adam优化、学习率衰减

    1.mini-batch梯度下降 在前面学习向量化时,知道了可以将训练样本横向堆叠,形成一个输入矩阵和对应的输出矩阵: 当数据量不是太大时,这样做当然会充分利用向量化的优点,一次训练中就可以将所有训练 ...

  5. Dropout和学习率衰减

         Dropout 在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很容易产生过拟合的现象.在训练神经网络的时候经常会遇到过拟合的问题,过拟合具体表现在:模型在训练数据上 ...

  6. [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)

    一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...

  7. 权重衰减(weight decay)与学习率衰减(learning rate decay)

    本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...

  8. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  9. 吴恩达深度学习笔记(五) —— 优化算法:Mini-Batch GD、Momentum、RMSprop、Adam、学习率衰减

    主要内容: 一.Mini-Batch Gradient descent 二.Momentum 四.RMSprop 五.Adam 六.优化算法性能比较 七.学习率衰减 一.Mini-Batch Grad ...

随机推荐

  1. 2018-12-25-dot-net-double-数组转-float-数组

    title author date CreateTime categories dot net double 数组转 float 数组 lindexi 2018-12-25 09:27:46 +080 ...

  2. java 基本数据类型的自动拆箱与装箱

    ——>  -128~127之间的特殊性.为什么要这样设计,好处? ——>  享元模式(Flyweight Pattern):享元模式的特点是,复用我们内存中已存在的对象,降低系统创建对象实 ...

  3. ASP.NET MVC 实现页落网资源分享网站+充值管理+后台管理(5)之业务层

    业务层主要负责定义业务逻辑(规则.工作流.数据完整性等),接收来自表示层的数据请求,逻辑判断后,向数据访问层提交请求,并传递数据访问结果,业务逻辑层实际上是一个中间件,起着承上启下的重要作用. 在我们 ...

  4. 云栖深度干货 | 打造“云边一体化”,时序时空数据库TSDB技术原理深度解密

    本文选自云栖大会下一代云数据库分析专场讲师自修的演讲——<TSDB云边一体化时序时空数据库技术揭秘> 自修 —— 阿里云智能数据库产品事业部高级专家       认识TSDB 第一代时序时 ...

  5. CAS5.3 单点登录/登出/springboot/springmvc

    环境: jdk:1.8 cas server:5.3.14 + tomcat 8.5 cas client:3.5.1 客户端1:springmvc 传统web项目(使用web.xml) 客户端2:s ...

  6. Channel 9视频整理【3】

    Will 保哥 微软mvp https://channel9.msdn.com/Niners/Will_Huang 繁体中文视频 Visual Studio 2017 新功能探索 https://ch ...

  7. 学习Java第六周

    1.内存结构 Java程序在运行时,需要在内存中的分配空间为提高运算效率,空间进行了不同区域的划分,因为每一片区域都有特定的处理数据方式和内存管理方式. 栈内存 ·用于存储局部变量,当数据使用完,所占 ...

  8. hibernate_检索策略

    一.概述 检索策略分三大块,类级别检索策略和关联级别检测策略. 类级别检索策略:get.load. 关联级别检索策略:order.getCustomer().getName() 上面这两种应该是看得懂 ...

  9. Mybase desktop7.3破解

    1.Mybase Desktop 7.3 安装包 百度云链接: 链接:https://pan.baidu.com/s/1mWZ2_Qmkf6aAX9CYgrN12A 提取码:vjw7 2.破解包 百度 ...

  10. 初始Redis与简单使用

    初始Redis: redis是一个key-value存储系统.和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list(链表).set(集合).zset(so ...