tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础

基础梯度操作方法:

tf.gradients 
用来计算导数。该函数的定义如下所示

def gradients(ys,
xs,
grad_ys=None,
name="gradients",
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None):

虽然可选参数很多,但是最常使用的还是ys和xs。根据说明得知,ys和xs都可以是一个tensor或者tensor列表。而计算完成以后,该函数会返回一个长为len(xs)的tensor列表,列表中的每个tensor是ys中每个值对xs[i]求导之和。如果用数学公式表示的话,那么 g = tf.gradients(y,x)可以表示成 ,

『cs231n』通过代码理解风格迁移

tf.gradients(loss, model.input_tensor)  # 计算梯度,并非使用optimizer类实现

tf.clip_by_global_norm

修正梯度值,用于控制梯度爆炸的问题。梯度爆炸和梯度弥散的原因一样,都是因为链式法则求导的关系,导致梯度的指数级衰减。为了避免梯度爆炸,需要对梯度进行修剪。 
先来看这个函数的定义:

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

输入参数中:t_list为待修剪的张量, clip_norm 表示修剪比例(clipping ratio).

函数返回2个参数: list_clipped,修剪后的张量,以及global_norm,一个中间计算量。当然如果你之前已经计算出了global_norm值,你可以在use_norm选项直接指定global_norm的值。

那么具体如何计算呢?根据源码中的说明,可以得到

list_clipped[i]=t_list[i] * clip_norm / max(global_norm, clip_norm),

其中 global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))

可以写作

其中, 
Lic和Lig代表t_list[i]和list_clipped[i], 
Nc和Ng代表clip_norm 和 global_norm的值。 
其实也可以看到其实Ng就是t_list的L2模。上式也可以进一步写作

也就是说,当t_list的L2模大于指定的Nc时,就会对t_list做等比例缩放。

这里讲解一下具体应用于优化器的方法,

self._lr = tf.Variable(0.0, trainable=False)  # lr 指的是 learning_rate
tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
config.max_grad_norm) # 梯度下降优化,指定学习速率
optimizer = tf.train.GradientDescentOptimizer(self._lr)
# optimizer = tf.train.AdamOptimizer()
# optimizer = tf.train.GradientDescentOptimizer(0.5)
self._train_op = optimizer.apply_gradients(zip(grads, tvars)) # 将梯度应用于变量
# self._train_op = optimizer.minimize(grads)

优化器类处理法:

『TensorFlow』网络操作API_下

提取梯度,使用梯度优化变量,效果和上面的例子相同,

# 创建一个optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1) # 计算<list of variables>相关的梯度
grads_and_vars = opt.compute_gradients(loss, <list of variables>) # grads_and_vars为tuples (gradient, variable)组成的列表。
#对梯度进行想要的处理,比如cap处理
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars] # 令optimizer运用capped的梯度(gradients)
opt.apply_gradients(capped_grads_and_vars)

『TensorFlow』梯度优化相关的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』分布式训练_其三_多机分布式

    本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...

  3. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  4. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  5. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  6. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  7. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  8. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  9. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

随机推荐

  1. 使用hashlib进行登录校验

    注册登录和密码验证 用户注册时,文件中保存用户名,和密码的密文 登录时,密码与文件中的密文进行比较,如果相同就同意登录 import hashlib # 导入模块 def md5(username,p ...

  2. Cocos 2dx项目lua调用OC出现卡死但不闪退的坑

    最近新上线的一个游戏,发现线上游戏有部分功能在点击的时候出现了没有反应的情况.通过调试源码,发现是原生OC的代码出现了崩溃,但是比较奇怪的是线上的Bugly没有任何记录,这个功能属于高频高能,而且又是 ...

  3. PowerBI新功能: PowerBI多报表共享一个数据集

    在PowerBI里面建模,来来回回摸了一遍之后,肯定不想在另外一个报表的时候重复一次建模--这样不利于复用和维护. 四月份的更新版提供了一个预览功能(如下),可以让多报表(pbix)共享同一个数据集. ...

  4. 20175211 2018-2019-2 《Java程序设计》第四周学习总结

    目录 教材学习内容总结 第五章 子类与继承 教材学习中的问题和解决过程 代码调试中的问题和解决过程 代码托管 上周考试错题总结 其他(感悟.思考等,可选) 学习进度条 参考资料 教材学习内容总结 第五 ...

  5. GTK# on Ubuntu DllMap

    修改配置:/etc/mono/config 新增以下代码 <dllmap dll="libglib-2.0-0.dll" target="libglib-2.0.s ...

  6. js页面路径拼接字符串进行参数传递

    页面路径拼接字符串进行参数传递: 参数传递页面: <style> input,button{ border: 1px solid red; } body { font-size:24px; ...

  7. React 最简单的入门教程

      一看就懂的ReactJs入门教程(精华版)   现在最热门的前端框架有AngularJS.React.Bootstrap等.自从接触了ReactJS,ReactJs的虚拟DOM(Virtual D ...

  8. Mysql BLOB、BLOB与TEXT区别及性能影响、将BLOB类型转换成VARCHAR类型

    在排查公司项目业务逻辑的时候,见到了陌生的字眼,如下图 顺着关键字BLOB搜索,原来是Mysql存储的一种类型,从很多文章下了解到如下信息 了解 MySQL中,BLOB字段用于存储二进制数据,是一个可 ...

  9. Centos下安装cassandra

    一.环境准备 环境 Centos6.5  .安装有Java JDK(https://www.cnblogs.com/wt645631686/p/8267239.html这篇文章里有安装JDK1.8的教 ...

  10. Python if条件判断

    if else 判断交流 1.判断用户名密码对不对. import getpass _username = 'devin' _password = 'abc123' username = input( ...