优化器Optimizer
目前最流行的5种优化器:Momentum(动量优化)、NAG(Nesterov梯度加速)、AdaGrad、RMSProp、Adam,所有的优化算法都是在原始梯度下降算法的基础上增加惯性和环境感知因素进行持续优化
Momentum优化
momentum优化的一个简单思想:考虑物体运动惯性,想象一个保龄球在光滑表面滚下一个平缓的坡度,最开始会很慢,但是会迅速地恢复动力,直到达到最终速度(假设又一定的摩擦力核空气阻力)
momentum优化关注以前的梯度是多少,公式:
\((1)m \leftarrow \beta m + \eta \nabla _\theta J(\theta)\)
\((2)\theta \leftarrow \theta - m\)
超参数\(\beta\)称为动量,其必须设置在0(高摩擦)和1(零摩擦)之间,默认值为0.9
可以很容易地验证当梯度保持一个常量,最终速度(即权重的最大值)就等于梯度乘以学习率乘以\(\frac{1}{1-\beta}\),当\(\beta = 0.9\)时,那么最终速度等于10倍梯度乘以学习率,所有momentum优化最终会比梯度下降快10倍,在不适用批量归一化的深度神经网络中,高层最终常会产生不同尺寸的输入,因此使用momentum优化会很有帮助,同时还会帮助跨过局部最优
由于又动量,优化器可能会超调一点,然后返回,再超调,来回震荡多次后,最后稳定在最小值,这也是系统中要有一些摩擦的原因之一,它可以帮助摆脱震荡,从而加速收敛
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9)
Nesterov梯度加速
公式:
\((1)m \leftarrow \beta m + \eta \nabla _\theta J(\theta + \beta m)\)
\((2)\theta \leftarrow \theta - m\)
与momentum唯一不同的是用\(\theta + \beta m\)来测量梯度,这个小调整有效是因为在通常情况下,动量矢量会指向正确的方向,所以在该方向相对远的地方使用梯度会比在原有地方更准确一些
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9,use_nesterov=True)
AdaGrad
AdaGrad对于简单的二次问题一般表现都不错,但是在训练神经网络时却经常很早就停滞了,学习速率缩小得很多,在到达全局最优前算法就停止了,所以尽管tensorflow又AdagradOptimizer,也不要用它来训练深度神经网络
公式:
\((1)s \leftarrow s + \nabla _\theta J(\theta) \otimes \nabla _\theta J(\theta)\)
\((2)\theta \leftarrow \theta - \eta \nabla _\theta J(\theta) \oslash \sqrt{s+\varepsilon}\)
RMSProp
AdaGrad降速太快而且没有办法收敛到全局最优,RMSProp算法却通过仅积累最近迭代中得梯度(而非从训练开始得梯度)解决这个问题,它通在第一步使用指数衰减开实现
公式:
\((1)s \leftarrow \beta s + (1-\beta)\nabla _\theta J(\theta) \otimes \nabla _\theta J(\theta)\)
\((2)\theta \leftarrow \theta - \eta \nabla _\theta J(\theta) \oslash \sqrt{s+\varepsilon}\)
衰减率\(\eta\)通常为0.9
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,momentum=0.9,decay=0.9,epsilon=0.9)
除去非常简单得问题,这个优化器得表现几乎全部优于AdaGrad,同时表现也基本都优于Momentum优化和NAG,事实上在Adam优化出现之前,它是众多研究者所推荐得优化算法
Adam优化
Adam代表了自适应力矩估计,集合了Momentum优化和RmsProp的想法,类似Momentum优化,它会跟踪过去梯度的指数衰减平均值,同时也类似RMSProp,它会跟踪过去梯度平方的指数衰减平均值,
Adam算法:
\((1)m \leftarrow \beta_1 m + (1-\beta_i) \nabla _\theta J(\theta)\)
\((2)s \leftarrow \beta_2s +(1-\beta_2)\nabla _\theta J(\theta) \otimes \nabla _\theta J(\theta)\)
\((3)m \leftarrow \frac{m}{1-\beta_1^T}\)
\((4)s \leftarrow \frac{s}{1-\beta_2^T}\)
\((5)\theta \leftarrow \theta - \eta m\oslash \sqrt{s+\varepsilon}\)
注:T表示迭代次数(从1开始)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
使用Adam优化器对mnist进行测试
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected,batch_norm
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
tf.reset_default_graph()
n_input = 784
n_hidden1 = 300
n_hidden2 = 100
n_output = 10
X = tf.placeholder(tf.float32,shape=(None,n_input),name='X')
Y = tf.placeholder(tf.int64,shape=(None,10),name='Y')
#归一化参数
is_training = tf.placeholder(tf.bool,shape=(),name='is_training')
bn_params = {'is_training':is_training,'decay':0.99,'updates_collections':None}
with tf.name_scope('dnn'):
with tf.contrib.framework.arg_scope([fully_connected],normalizer_fn=batch_norm,normalizer_params=bn_params):
hidden1 = fully_connected(X,n_hidden1,activation_fn=tf.nn.elu,scope='hidden1')
hidden2 = fully_connected(hidden1,n_hidden2,activation_fn=tf.nn.elu,scope='hidden2')
y_prab = fully_connected(hidden2,n_output,activation_fn=tf.nn.softmax,scope='output')
with tf.name_scope('train'):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y,logits=y_prab))
learning_rate = tf.placeholder(tf.float32,shape=(),name='learning_rate')
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
with tf.name_scope('accuracy'):
prab_bool = tf.equal(tf.argmax(y_prab,1),tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(prab_bool,tf.float32))
with tf.name_scope('tensorboard_mnist'):
file_writer = tf.summary.FileWriter('./tensorboard/',tf.get_default_graph())
accuracy_summary = tf.summary.scalar('accuracy',accuracy)
with tf.name_scope('saver'):
saver = tf.train.Saver()
with tf.name_scope('collection'):
tf.add_to_collection('logits',y_prab)
epoches = 20
batch_size = 100
n_batches = mnist.train.num_examples // batch_size
rate = 0.1
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(epoches):
for batch in range(n_batches):
x_batch,y_batch = mnist.train.next_batch(batch_size)
sess.run(optimizer,feed_dict={X:x_batch,Y:y_batch,learning_rate:rate,is_training:True})
result = sess.run([accuracy,accuracy_summary],feed_dict={X:mnist.test.images,Y:mnist.test.labels,
learning_rate:rate,is_training:False})
file_writer.add_summary(result[1],epoch)
print('epoch:{},accuracy:{}'.format(epoch,result[0]))
saver.save(sess,'./model/model_final.ckpt',global_step=5)
print('stop')
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
epoch:0,accuracy:0.945900022983551
epoch:1,accuracy:0.9574999809265137
epoch:2,accuracy:0.9635000228881836
epoch:3,accuracy:0.9693999886512756
epoch:4,accuracy:0.970300018787384
epoch:5,accuracy:0.9704999923706055
epoch:6,accuracy:0.9758999943733215
epoch:7,accuracy:0.9757999777793884
epoch:8,accuracy:0.9768999814987183
epoch:9,accuracy:0.9783999919891357
epoch:10,accuracy:0.9783999919891357
epoch:11,accuracy:0.9642999768257141
epoch:12,accuracy:0.9779999852180481
epoch:13,accuracy:0.9799000024795532
epoch:14,accuracy:0.9760000109672546
epoch:15,accuracy:0.977400004863739
epoch:16,accuracy:0.9819999933242798
epoch:17,accuracy:0.9781000018119812
epoch:18,accuracy:0.9661999940872192
epoch:19,accuracy:0.9779000282287598
stop
优化器Optimizer的更多相关文章
- 【深度学习】深入理解优化器Optimizer算法(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- scipy优化器optimizer
#optimazer优化器 from scipy.optimize import minimize def rosem(x): return sum(100.0*(x[1:]-x[:-1])**2.0 ...
- 深度学习优化器 optimizer 的选择
网址:https://blog.csdn.net/g11d111/article/details/76639460
- Oracle 课程五之优化器和执行计划
课程目标 完成本课程的学习后,您应该能够: •优化器的作用 •优化器的类型 •优化器的优化步骤 •扫描的基本类型 •表连接的执行计划 •其他运算方式的执行计划 •如何看执行计划顺序 •如何获取执行计划 ...
- Oracle SQL优化器简介
目录 一.Oracle的优化器 1.1 优化器简介 1.2 SQL执行过程 二.优化器优化方式 2.1 优化器的优化方式 2.2 基于规则的优化器 2.3 基于成本的优化器 三.优化器优化模式 3.1 ...
- oracle-sql优化器
优化器optimizer Oracle 执行计划(Explain Plan) 说明 http://langgufu.iteye.com/blog/2158163 explain plan是一个dml语 ...
- Oracle 优化器
http://blog.csdn.net/it_man/article/details/8185370一.优化器基本知识 Oracle在执行一个SQL之前,首先要分析一下语句的执行计划,然后再按执 ...
- 采用梯度下降优化器(Gradient Descent optimizer)结合禁忌搜索(Tabu Search)求解矩阵的全部特征值和特征向量
[前言] 对于矩阵(Matrix)的特征值(Eigens)求解,采用数值分析(Number Analysis)的方法有一些,我熟知的是针对实对称矩阵(Real Symmetric Matrix)的特征 ...
- pytorch1.0进行Optimizer 优化器对比
pytorch1.0进行Optimizer 优化器对比 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, ...
随机推荐
- webservice的测试案例
1.服务器端 服务器接口Test_service.java package com.xiaostudy; /** * @desc 服务器接口 * @author xiaostudy * */ publ ...
- ubuntu16.04 python3.5 opencv的安装与卸载(转载)
转载https://blog.csdn.net/qq_37541097/article/details/79045595 Ubuntu16.04 自带python2.7和python3.5两个版本,默 ...
- u-boot-2015.07 make xxx_config 分析
1.u-boot编译脚本:mk.sh #! /bin/sh export PATH=$PATH:/opt/ti-sdk-am335x-evm-08.00.00.00/linux-devkit/sysr ...
- Dive into Spring framework -- 搭建spring 源码的开发环境
spring是一个类之间依赖的管理容器,大家都知道,但我们中很多人都仅仅停留在使用的层面,但spring本身具有极大的研究价值,所以在使用了几年spring之后,还是想深入的探究一下其根源.记录于此, ...
- 常见HTTP状态(304,)
一.1XX(临时响应) 表示临时响应并需要请求者继续执行操作的状态码. 100(继续) 请求者应当继续提出请求.服务器返回此代码表示:已经收到请求的第一部分,正在等待其余部分. 101(切换协议) 请 ...
- Spark 基于物品的协同过滤算法实现
J由于 Spark MLlib 中协同过滤算法只提供了基于模型的协同过滤算法,在网上也没有找到有很好的实现,所以尝试自己实现基于物品的协同过滤算法(使用余弦相似度距离) 算法介绍 基于物品的协同过滤算 ...
- 三十二 Python分布式爬虫打造搜索引擎Scrapy精讲—scrapy的暂停与重启
scrapy的每一个爬虫,暂停时可以记录暂停状态以及爬取了哪些url,重启时可以从暂停状态开始爬取过的URL不在爬取 实现暂停与重启记录状态 1.首先cd进入到scrapy项目里 2.在scrapy项 ...
- HUST 1010 The Minimum Length (字符串最小循环节)
题意 有一个字符串A,一次次的重写A,会得到一个新的字符串AAAAAAAA.....,现在将这个字符串从中切去一部分得到一个字符串B.例如有一个字符串A="abcdefg".,复制 ...
- nmcli 使用记录---fatt
安装nmcli工具 yum install NetworkManager 使用语法 Usage: nmcli [OPTIONS] OBJECT { COMMAND | help } OBJECT g[ ...
- iOS自动化探索(八)Mac上的Jenkins安装
安装Jenkins 首先检查是否有Jenkins依赖的java环境 java -version 出现java version "1.8.xx"说明已经安装了java Jackeys ...