搭建好网络后,常使用梯度下降类优化算法进行模型参数求解,模型越复杂我们在训练神经网络的过程上花的时间就越多,为了解决这一问题,我们就需要找一些优化算法来提高训练速度,TF的tf.train模块中提供了丰富的优化算法,这一节对这些优化器做下简单的对比。

Stochastic Gradient Descent(SGD)

最基础的方法就是GD了,将整个数据集放入模型中,不断的迭代得到模型的参数,当然这样的方法计算资源占用的比较大,那么有没有什么好的解决方法呢?就是把整个数据集分成小批(mini-batch),然后再进行上述操作这就是SGD了,这种方法虽然不能反应整体的数据情况,不过能够很大程度上加快了模型的训练速度,并且也不会丢失太多的准确率

参数的迭代公式

\(w:=w-\alpha*dw\)

Momentum

传统的GD可能会让学习过程十分的曲折,这里我们引入了惯性这一分量,在朝着最优点移动的过程中由于惯性走的弯路会变少

\(m=\beta*m-\alpha*dw\)

\(w:=w-m\)

AdaGrad

这个方法主要是在学习率上面动手脚,每个参数的更新都会有不同的学习率

\(s=s+dw^2\)

\(w:=w-\alpha*dw/\sqrt{s}\)

RMSProp

AdaGrad收敛速度快,但不一定是全局最优,为了解决这一点,加入了Momentum部分

\(s=\beta*s+(1-\beta)dw^2\)

\(w:=w-\alpha*dw/\sqrt{s}\)

Adam

adam是目前比较好的方法,它融合了Momentum和RMSProp方法

代码示例

下面部分使用TF来比较一下这些方法的效果

# -*- coding: utf-8 -*-
"""
@author: VasiliShi
"""
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(10,8))
x = np.linspace(-1,1,100)[:,np.newaxis] #<==>x=x.reshape(100,1)
noise = np.random.normal(0,0.1,size = x.shape)
y=np.power(x,2) + x +noise #y=x^2 + x+噪音
plt.scatter(x,y)
plt.show()
learning_rate = 0.01
batch_size = 10 #mini-batch的大小
class Network(object):
def __init__(self,func,**kwarg):
self.x = tf.placeholder(tf.float32,[None,1])
self.y = tf.placeholder(tf.float32,[None,1])
hidden = tf.layers.dense(self.x,20,tf.nn.relu)
output = tf.layers.dense(hidden,1)
self.loss = tf.losses.mean_squared_error(self.y,output)
self.train = func(learning_rate,**kwarg).minimize(self.loss)
SGD = Network(tf.train.GradientDescentOptimizer)
Momentum = Network(tf.train.MomentumOptimizer,momentum=0.5)
AdaGrad = Network(tf.train.AdagradOptimizer)
RMSprop = Network(tf.train.RMSPropOptimizer)
Adam = Network(tf.train.AdamOptimizer)
networks = [SGD,Momentum,AdaGrad,RMSprop,Adam]
record_loss = [[], [], [], [], []] #踩的坑不能使用[[]]*5
plt.figure(2,figsize=(10,8))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for stp in range(200):
index = np.random.randint(0,x.shape[0],batch_size)#模拟batch
batch_x = x[index]
batch_y = y[index]
for net,loss in zip(networks,record_loss):
_,l = sess.run([net.train,net.loss],feed_dict={net.x:batch_x,net.y:batch_y})
loss.append(l)#保存每一batch的loss
labels = ['SGD','Momentum','AdaGrad','RMSprop','Adam']
for i,loss in enumerate(record_loss):
plt.plot(loss,label=labels[i])
plt.legend(loc="best")
plt.xlabel("steps")
plt.ylabel("loss")
plt.show()

下图是batch_size=10的结果

下图是batch_size=30的结果

可以看的出Adam方法收敛速度最快,并且波动最小。

TensorFlow中的优化算法的更多相关文章

  1. Tensorflow 中的优化器解析

    Tensorflow:1.6.0 优化器(reference:https://blog.csdn.net/weixin_40170902/article/details/80092628) I:  t ...

  2. optim.py-使用tensorflow实现一般优化算法

    optim.py Project URL:https://github.com/Codsir/optim.git Based on: tensorflow, numpy, copy, inspect ...

  3. TensorFlow中设置学习率的方式

    目录 1. 指数衰减 2. 分段常数衰减 3. 自然指数衰减 4. 多项式衰减 5. 倒数衰减 6. 余弦衰减 6.1 标准余弦衰减 6.2 重启余弦衰减 6.3 线性余弦噪声 6.4 噪声余弦衰减 ...

  4. 分别使用 Python 和 Math.Net 调用优化算法

    1. Rosenbrock 函数 在数学最优化中,Rosenbrock 函数是一个用来测试最优化算法性能的非凸函数,由Howard Harry Rosenbrock 在 1960 年提出 .也称为 R ...

  5. 梯度优化算法总结以及solver及train.prototxt中相关参数解释

    参考链接:http://sebastianruder.com/optimizing-gradient-descent/ 如果熟悉英文的话,强烈推荐阅读原文,毕竟翻译过程中因为个人理解有限,可能会有谬误 ...

  6. 机器学习中几种优化算法的比较(SGD、Momentum、RMSProp、Adam)

    有关各种优化算法的详细算法流程和公式可以参考[这篇blog],讲解比较清晰,这里说一下自己对他们之间关系的理解. BGD 与 SGD 首先,最简单的 BGD 以整个训练集的梯度和作为更新方向,缺点是速 ...

  7. TensorFlow实现与优化深度神经网络

    TensorFlow实现与优化深度神经网络 转载请注明作者:梦里风林Github工程地址:https://github.com/ahangchen/GDLnotes欢迎star,有问题可以到Issue ...

  8. TensorFlow中的通信机制——Rendezvous(二)gRPC传输

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...

  9. TensorFlow中的并行执行引擎——StreamExecutor框架

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在前一篇文章中,我们梳理了TensorFlow中各种异构Device的添加和注 ...

随机推荐

  1. ssh linux免密登录。。。。生产共钥到另一台主机

    一.第一种方式: 1.ssh-keygen -t rsa -t : 加密方式 默认为rsa 可以省略不写 加密方式选 rsa|dsa 2.将 .pub 文件复制到目标机器的 .ssh 目录, 并 ca ...

  2. linux如何查看端口被哪个进程占用

    1.lsof -i:端口号 2.netstat -tunlp|grep 端口号 都可以查看指定端口被哪个进程占用的情况 工具/原料   linux,windows xshell 方法/步骤     [ ...

  3. 【转】python中的一维卷积conv1d和二维卷积conv2d

    转自:https://blog.csdn.net/qq_26552071/article/details/81178932 二维卷积conv2d 给定4维的输入张量和滤波器张量来进行2维的卷积计算.即 ...

  4. 对 String 字符串的理解

    1.通过构造方法创建的字符串对象和直接赋值方式创建的字符串对象区别? 通过构造方法创建字符串对象是在堆内存. 直接赋值方式创建对象是在方法区的常量池. ==: 基本数据类型,比较的是基本数据类型的值是 ...

  5. python inspect.stack() 的简单使用

    1. #python # -*- encoding: utf-8 -*- #获取函数的名字 import inspect def debug(): callnamer = inspect.stack( ...

  6. 随笔一个dom节点绑定事件

    以下利用jquery说明: js中,给一个dom节点绑定事件再平常不过了.这里说下,如果dom经常发生变化的话,给这个dom绑定事件的情况. 比如代码如下: li的节点,绑定了事件:点击会打出来里头的 ...

  7. 读取Excel的记录并导入SQL数据库

    准备一下,近段时间,需要把Excel的数据导入数据库中. 引用命名空间: using System.Configuration; using System.Data; using System.Dat ...

  8. 将WinForm程序(含多个非托管Dll)合并成一个exe的方法

    原文:将WinForm程序(含多个非托管Dll)合并成一个exe的方法 开发程序的时候经常会引用一些第三方的DLL,然后编译生成的exe文件就不能脱离这些DLL独立运行了. ILMerge能把托管dl ...

  9. C# 实现表单的自动化测试<通过程序控制一个网页>

    学历代表你的过去,能力代表你的现在,学习代表你的将来 十年河东,十年河西,莫欺少年穷 学无止境,精益求精 C# 实现表单的自动化测试,这标题看着就来劲!那么,如何通过C#程序控制一个网页呢? 在此,以 ...

  10. Express中间件,看这篇文章就够了(#^.^#)

    底层:http模块 express目前是最流行的基于Node.js的web开发框架,express框架建立在内置的http模块上, var http = require('http') var app ...