TensorFlow优化器浅析
本文基于tensorflow-v1.15分支,简单分析下TensorFlow中的优化器。
optimizer = tf.train.GradientDescentOptimizer(learning_rate=.05)
train_op = optimizer.minimize(loss)
当我们调用optimizer.minimize()
时,其内部会调用两个方法compute_gradients()
和apply_gradients()
,分别用来计算梯度和使用梯度更新权重,其核心逻辑如下所示。
def minimize(self, loss, global_step=None, var_list=None,
gate_gradients=GATE_OP, aggregation_method=None,
colocate_gradients_with_ops=False, name=None,
grad_loss=None):
grads_and_vars = self.compute_gradients(
loss, var_list=var_list, gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
grad_loss=grad_loss)
vars_with_grad = [v for g, v in grads_and_vars if g is not None]
return self.apply_gradients(grads_and_vars, global_step=global_step, name=name)
如果我们想在模型更新前对梯度搞一些自定义的操作,TensorFlow中推荐的方式是
- 通过
compute_gradients
计算梯度 - 对梯度进行一些自定义操作
- 通过
apply_gradients
将处理后的梯度更新到模型权重
在optimizer.minimize
的第一阶段,我们首先通过compute_gradients
计算出梯度。在compute_gradients函数中,TensorFlow使用了两种计算梯度的方式,分别是针对静态图的tf.gradients
接口和针对动态图的tf.GradientTape
接口,这两个接口内部分别使用了符号微分和自动微分的方式来计算梯度。下面是compute_gradients的核心执行逻辑,代码中省略了部分异常判断的语句。可以看到,如果传入的loss是一个可调用对象,那么就会调用backprop.GradientTape
相关的接口去求解梯度;否则,就会调用gradients.gradients
接口去求解梯度。
from tensorflow.python.eager import backprop
from tensorflow.python.ops import gradients
def compute_gradients(self, loss, var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
if callable(loss):
with backprop.GradientTape() as tape:
if var_list is not None:
tape.watch(var_list)
loss_value = loss()
loss_value = self._scale_loss(loss_value)
if var_list is None:
var_list = tape.watched_variables()
with ops.control_dependencies([loss_value]):
grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
# Scale loss if using a "mean" loss reduction and multiple replicas.
loss = self._scale_loss(loss)
if var_list is None:
var_list = (
variables.trainable_variables() +
ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
else:
var_list = nest.flatten(var_list)
var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
processors = [_get_processor(v) for v in var_list]
var_refs = [p.target() for p in processors]
grads = gradients.gradients(
loss, var_refs, grad_ys=grad_loss,
gate_gradients=(gate_gradients == Optimizer.GATE_OP),
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops)
if gate_gradients == Optimizer.GATE_GRAPH:
grads = control_flow_ops.tuple(grads)
grads_and_vars = list(zip(grads, var_list))
return grads_and_vars
一般来说,loss是一个tensor,因此我们主要关注上述代码的第16-29行。在第16-19行,我们获取需要求解梯度的变量列表。如果没有指定var_list,那么compute_gradient函数会默认获取所有的TRAINABLE_VARIABLES
和TRAINABLE_RESOURCE_VARIABLES
。第20行貌似啥也没做,因为在源代码中2找不到名为_STRAMING_MODEL_PROTS
的变量集合。注意到第23行调用gradients.gradients函数计算梯度,这个函数实现在python/ops/gradient_impl.py文件中,其内部调用了_GradientsHelper
来实现真正的计算。因为_GradientsHelper
这个函数特别长,而且它和gradients函数的参数相同,所以我们这里先介绍几个重要形参的含义。
@tf_export(v1=["gradients"])
def gradients(ys,
xs,
grad_ys=None,
name="gradients",
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None,
unconnected_gradients=UnconnectedGradients.NONE):
with ops.get_default_graph()._mutation_lock():
return gradients_util._GradientsHelper(
ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients,
unconnected_gradients)
ys和xs参数均接收单个tensor或tensor列表,分别对应\(\frac{\partial Y}{\partial X}\)中的\(Y\)和\(X\)。
grad_ys参数接收单个tensor或tensor列表,它的维度必须和ys的维度相同。grad_ys为ys中的每个tensor提供初始值,如果grad_ys为None,那么ys中每个tensor的初始值就被设置为1。
aggregation_method表示梯度聚合的方式,TensorFlow支持的所有聚合方式均定义于tf.AggregationMethods
类中,包括ADD_N
、DEFAULT
、EXPERIMENTAL_N
以及EXPERIMENTAL_ACCUMULATE_N
等方法。
stop_gradients参数接收单个tensor或tensor列表,这些tensor不参与反向传播梯度的计算。注意,tensorflow提供了另一个接口tf.stop_gradients,也可以完成相同的工作。二者的区别在于tf.stop_gradient作用于计算图构建时,而tf.gradients的stop_gradients参数作用于计算图的运行时。
_GradientHelper是构建反向计算图并求解梯度的关键方法,需要仔细阅读。这里暂时给出一个简略的分析。
这个方法会维护两个重要变量:
- 一个队列queue,队列里存放计算图里所有出度为0的Op
- 一个字典grads,字典的键是Op本身,值是该Op每个输出端收到的梯度列表
反向传播求梯度时,每从队列中弹出一个Op,都会把它输出变量的梯度加起来(对应全微分定理)得到out_grads,然后获取对应的梯度计算函数grad_fn。Op本身和out_grads会传递给grad_fn做参数,求出输入的梯度。每当一个Op的梯度被求出来,就会更新所有未经处理的Op的出度和queue。当queue为空时,就表示整个反向计算图处理完毕。
if grad_fn:
in_grads = _MaybeCompile(grad_scope, op, func_call,
lambda: grad_fn(op, *out_grads))
else:
in_grads = _MaybeCompile(grad_scope, op, func_call,
lambda: _SymGrad(op, out_grads, xs))
grad_fn是梯度计算函数,它用来计算给定Op的梯度。在TensorFlow里,每个Op都会定义一个对应的梯度计算函数。例如,下面是平方函数(tf.square)的梯度:
@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
x = op.inputs[0]
# Added control dependencies to prevent 2*x from being computed too early.
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
y = constant_op.constant(2.0, dtype=x.dtype)
return math_ops.multiply(grad, math_ops.multiply(x, y))
apply_gradients是optimizer.minimize的第二阶段,它将梯度更新应用到变量上。根据所使用的学习算法的不同,apply_gradients内部会调用不同的Optimizer实现。下面的代码展示了apply_gradients的核心执行逻辑。
converted_grads_and_vars = []
for g, v in grads_and_vars:
if g is not None:
g = ops.convert_to_tensor_or_indexed_slices(g)
p = _get_processor(v)
converted_grads_and_vars.append((g, v, p))
converted_grads_and_vars = tuple(converted_grads_and_vars)
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
with ops.init_scope():
self._create_slots(var_list)
update_ops = []
with ops.name_scope(name, self._name) as name:
self._prepare()
for grad, var, processor in converted_grads_and_vars:
if grad is None:
continue
else:
scope_name = var.op.name
with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
update_ops.append(processor.update_op(self, grad))
if global_step is None:
apply_updates = self._finish(update_ops, name)
else:
with ops.control_dependencies([self._finish(update_ops, "update")]):
with ops.colocate_with(global_step):
if isinstance(
global_step, resource_variable_ops.BaseResourceVariable):
# TODO(apassos): the implicit read in assign_add is slow; consider
# making it less so.
apply_updates = resource_variable_ops.assign_add_variable_op(
global_step.handle,
ops.convert_to_tensor(1, dtype=global_step.dtype),
name=name)
else:
apply_updates = state_ops.assign_add(global_step, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):
apply_updates = apply_updates.op
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
if apply_updates not in train_op:
train_op.append(apply_updates)
return apply_updates
在1-6行,程序将每个非None的梯度转化成tensor(稠密)或indexedslices(稀疏),根据每个变量存储类型的不同,我们获取到不同的processor(第5行),最终将一个三元组(g, v, p)保存到列表converted_grads_and_vars中,以备后用。这里主要解释下第5行,对于所有可优化的变量OptimizableVariable,根据其类型的不同,我们需要调用不同update_op。
在程序的8-10行,我们首先获取到非None的grad对应的var,然后对相应的var创建slots。_create_slots方法需要Optimzier的子类自己去实现,它的作用是创建学习算法所需要的中间变量。以momentum sgd为例,它的更新公式为:
accumulation = momentum * accumulation + gradient
variable -= learning_rate * accumulation
可以看到,它在更新变量的时候需要用到一个中间变量accumulation。因此,我们需要为每个变量创建一个slot,用来保存这个中间变量,以便在下次进行权重更新时继续使用它:
def _create_slots(self, var_list):
for v in var_list:
self._zeros_slot(v, "momentum", self._name)
在程序的12-20行,我们为每个var添加对应的update_op。第14行调用了_prepare方法,它是用来初始化一些必要的变量(例如学习率、动量),为应用梯度做准备。还是以momentum sgd为例子,它的_prepare函数实现如下:
def _prepare(self):
learning_rate = self._learning_rate
if callable(learning_rate):
learning_rate = learning_rate()
self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
name="learning_rate")
momentum = self._momentum
if callable(momentum):
momentum = momentum()
self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
可以看到,在这个函数中,它将学习率和动量都转化成了tensor(为什么要转成tensor?)。准备工作完成后,我们就可以给每个非None的grad和var添加相应的update_op。注意第18行,我们使用ops.colocate_with(var)把var对应的update_op放置到var所在的设备上。最后,我们调用_finish函数以完成所有的更新。一般来说_finish函数不需要重写,唯一的例外是Adam算法,它在实现时重写了_finish算法。
前面提到,对于所有的可优化的变量,根据其类型的不同,我们会调用不同的update_op。一般来说,不同的Optimizer需要实现的update_op主要包括四种:_apply_dense、_resource_apply_dense、_apply_sparse和_resource_apply_sparse,其中前两种对应稠密更新,后两种对应稀疏更新。以_resource开头的方法是针对variable handle,而不带_resource的update_op则是针对variable的。这里我们还是以momentum sgd为例,介绍一下它的_apply_dense的实现。下面是对应的代码,可以看到它首先获取了中间变量mom,然后直接调用了training_ops中的apply_momentum方法。
def _apply_dense(self, grad, var):
mom = self.get_slot(var, "momentum")
return training_ops.apply_momentum(
var, mom,
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
grad,
math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
use_locking=self._use_locking,
use_nesterov=self._use_nesterov).op
apply_momentum方法是由bazel构建生成的代码,它会调用op_def_lib中的_apply_op_helper函数,将一个名为ApplyMomentum的Op添加到计算图中:
_, _, _op = _op_def_lib._apply_op_helper(
"ApplyMomentum", var=var, accum=accum, lr=lr, grad=grad,
momentum=momentum, use_locking=use_locking,
use_nesterov=use_nesterov, name=name)
根据gen_training_ops.py中的注释,我们可以找到ApplyMomemtum这个Op的注册信息:
REGISTER_OP("ApplyMomentum")
.Input("var: Ref(T)")
.Input("accum: Ref(T)")
.Input("lr: T")
.Input("grad: T")
.Input("momentum: T")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
.Attr("use_nesterov: bool = false")
.SetShapeFn([](InferenceContext* c) {
return ApplyMomentumShapeFn(c, false /* not sparse */);
});
最终,我们可以在kernel目录下找到ApplyMomentum这个Op的实现。针对不同的设备,ApplyMomentum有不同的特化实现。
template <typename T>
struct ApplyMomentum<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstFlat grad,
typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
accum.device(d) = accum * momentum() + grad;
if (use_nesterov) {
var.device(d) -= grad * lr() + accum * momentum() * lr();
} else {
var.device(d) -= accum * lr();
}
}
};
TensorFlow优化器浅析的更多相关文章
- TensorFlow从0到1之TensorFlow优化器(13)
高中数学学过,函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系数.本节将介绍如何使 ...
- TensorFlow优化器及用法
TensorFlow优化器及用法 函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系 ...
- tensorflow优化器-【老鱼学tensorflow】
tensorflow中的优化器主要是各种求解方程的方法,我们知道求解非线性方程有各种方法,比如二分法.牛顿法.割线法等,类似的,tensorflow中的优化器也只是在求解方程时的各种方法. 比较常用的 ...
- DNN网络(三)python下用Tensorflow实现DNN网络以及Adagrad优化器
摘自: https://www.kaggle.com/zoupet/neural-network-model-for-house-prices-tensorflow 一.实现功能简介: 本文摘自Kag ...
- Tensorflow 中的优化器解析
Tensorflow:1.6.0 优化器(reference:https://blog.csdn.net/weixin_40170902/article/details/80092628) I: t ...
- tensorflow的几种优化器
最近自己用CNN跑了下MINIST,准确率很低(迭代过程中),跑了几个epoch,我就直接stop了,感觉哪有问题,随即排查了下,同时查阅了网上其他人的blog,并没有发现什么问题 之后copy了一篇 ...
- 莫烦大大TensorFlow学习笔记(8)----优化器
一.TensorFlow中的优化器 tf.train.GradientDescentOptimizer:梯度下降算法 tf.train.AdadeltaOptimizer tf.train.Adagr ...
- TensorFlow使用记录 (六): 优化器
0. tf.train.Optimizer tensorflow 里提供了丰富的优化器,这些优化器都继承与 Optimizer 这个类.class Optimizer 有一些方法,这里简单介绍下: 0 ...
- Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理
前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ...
随机推荐
- Windows 重装系统,配置 WSL,美化终端,部署 WebDAV 服务器,并备份系统分区
最新博客文章链接 最近发现我 Windows11 上的 WSL 打不开了,一直提示我虚拟化功能没有打开,但我看了下配置,发现虚拟化功能其实是开着的.然后试了各种方法,重装了好几次系统,我一个软件一个软 ...
- Vulnhub系列:Tomato(文件包含getshell)
这个靶机挺有意思,它是通过文件包含漏洞进行的getshell,主要姿势是将含有一句话木马的内容记录到ssh的登录日志中,然后利用文件包含漏洞进行包含,从而拿到shell 0x01 靶机信息 靶机:To ...
- [.Net]使用ABP 数据库迁移migration遇到的坑及解决方案
问题:在使用Update-Database时,突然出现"数据库中已存在名为 'XXX' 的对象". 检查发现__EFMigrationsHistory表中的MigrationI ...
- Solon Web 开发,六、过滤器、处理、拦截器
Solon Web 开发 一.开始 二.开发知识准备 三.打包与运行 四.请求上下文 五.数据访问.事务与缓存应用 六.过滤器.处理.拦截器 七.视图模板与Mvc注解 八.校验.及定制与扩展 九.跨域 ...
- winform控件拖动
示例代码 using System; using System.Collections.Generic; using System.Drawing; using System.Windows.Form ...
- java抽象类概述特点
1 package face_09; 2 /* 3 * 抽象类: 4 * 抽象:笼统,模糊,看不懂!不具体. 5 * 6 * *特点: 7 * 1,方法只有声明没有实现时,该方法就是抽象方法,需要被a ...
- 理解cpu load
三种命令 1. w 2. uptime 3. top CPU负载和CPU利用率的区别 1)CPU利用率:显示的是程序在运行期间实时占用的CPU百分比 2)CPU负载:显示的是一段时间内正在使用和等待使 ...
- linux文件时间详细说明
目录 一:文件时间信息 2 文件时间详细说明 一:文件时间信息 1 文件时间信息分类: 三种时间信息 文件修改时间: mtime 属性修改时间: ctime 文件访问时间: atime 2 查看文件时 ...
- 怎么重载网卡?ip修改 HHS服务器
目录 一:目录结构知识详述 1.网卡配置文件 2,ip修改 3.重载网卡信息 4.关闭网络管理器(因为已经有了network)所有要关闭NetworkManager不然会发生冲突 5.判断SSH服务是 ...
- How to check in Windows if you are using UEFI
You might be wondering if Windows is using UEFI or the legacy BIOS, it's easy to check. Just fire up ...