版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u013745804/article/details/79589514

————————————————

引子

写这篇文章的原因是今天有人问我,DQN中为什么要对q_target进行stop_gradient啊?
        这个函数在TensorFlow中还是很重要的,所以我们利用DQN的代码实例来说明该函数的作用。我要来的两份DQN代码实例见《DQN的两种实现》,下面我们对

其中的关键代码进行分析:

No stop_gradient

这个版本就是人们写得相对较多的版本了,话不多说,直接上代码:

...
self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target') # for calculating loss
...
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))
with tf.variable_scope('train'):
self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
...

上面这一小段代码就是DQN的常规写法了。我们知道,在DQN中会维持两个网络,一个eval net,一个target net。我们对eval net的参数更新是通过MSE + GD来更新的,而MSE的计算将用到target net对下一状态的估值,通常的做法是对eval net设置一个placeholder,也即引入一个输入,用这个placeholder计算loss。

stop_gradient

如果我们使用stop_gradient的话,又是如何解决的呢?

...
with tf.variable_scope('q_target'):
q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, )
self.q_target = tf.stop_gradient(q_target)
...
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))

这段代码中,我们使用tf.stop_gradient对q_target的反传进行截断,得到self.q_target这个op(运行时就是Tensor了),然后利用通过截断反传得到的self.q_target来计算loss,并没有使用feed_dict。

What’s the difference?

这两者究竟有什么内在区别?我们知道,在TensorFlow中,维持着一些opop在被执行之后将变为常量Tensor(指的不是Variable意义的Tensor),这些计算(eval/run)得到的常量Tensor可以看作是我们自己给出的输入数据。

第一种方法中placeholder输入的本身就是计算好了的q_target,也就是说我们通过feed_dict,将对target net进行计算得到的一个q_target Tensor传入placeholder中,当做常量来对待,我们可以把一次计算(eval/run)看作是一次截图,得到当时各个op的值。这样的话,我们对于eval net中loss的反传就不会影响到target net了。

第二种方法中直接拿target net中的q_target这个op来计算eval net中的loss显然是不妥的,因为我们对loss进行反传时将会影响到target net,这不是我们想看到的结果。所以,这里引入stop_gradient来对从loss到target net的反传进行截断,换句话说,通过self.q_target = tf.stop_gradient(q_target),将原本为TensorFlow计算图中的一个op(节点)转为一个常量self.q_target,这时候对于loss的求导反传就不会传到 target net 去了。
        有没有对如何使用tf.stop_gradient这一方法清楚一些呢?

【转载】 关于tf.stop_gradient的使用及理解的更多相关文章

  1. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  2. tf.metrics.sparse_average_precision_at_k 和 tf.metrics.precision_at_k的自己理解

    tensorflow最大的问题就是大家都讲算法,不讲解用法,API文档又全是英文的,看起来好吃力,理解又不到位.当然给数学博士看的话,就没问题的. 最近看了一系列非常不错的文章,做一下记录: http ...

  3. 【转载】PHP运行模式的深入理解

    PHP运行模式的深入理解 作者: 字体:[增加 减小] 类型:转载 时间:2013-06-03我要评论 本篇文章是对PHP运行模式进行了详细的分析介绍,需要的朋友参考下   PHP运行模式有4钟:1) ...

  4. 【转载】 tf.Print() (------------ tensorflow中的print函数)

    原文地址: https://blog.csdn.net/weixin_36670529/article/details/100191674 ------------------------------ ...

  5. 【转载】 tf.ConfigProto和tf.GPUOptions用法总结

    原文地址: https://blog.csdn.net/C_chuxin/article/details/84990176 -------------------------------------- ...

  6. 转载:Java多线程中join方法的理解

    转载自:http://uule.iteye.com/blog/1101994 thread.Join把指定的线程加入到当前线程,可以将两个交替执行的线程合并为顺序执行的线程.比如在线程B中调用了线程A ...

  7. 【转载】C/C++杂记:深入理解数据成员指针、函数成员指针

    原文:C/C++杂记:深入理解数据成员指针.函数成员指针 1. 数据成员指针 对于普通指针变量来说,其值是它所指向的地址,0表示空指针.而对于数据成员指针变量来说,其值是数据成员所在地址相对于对象起始 ...

  8. 【转载】Raft 为什么是更易理解的分布式一致性算法

    一致性问题可以算是分布式领域的一个圣殿级问题了,关于它的研究可以回溯到几十年前. 拜占庭将军问题 Leslie Lamport 在三十多年前发表的论文<拜占庭将军问题>(参考[1]). 拜 ...

  9. 【转载】 tf.cond() ----------------------(tensorflow 条件判断语句 if.......else....... )

    原文地址: https://cloud.tencent.com/developer/article/1486441 ------------------------------------------ ...

  10. 【转载】 tf.train.slice_input_producer()和tf.train.batch()

    原文地址: https://www.jianshu.com/p/8ba9cfc738c2 ------------------------------------------------------- ...

随机推荐

  1. 任意树遍历,可以使用 goto 跳记号标注的

    先顺序进入到最后一个根的根部,完后扫描同级 同级扫描完用 goto跳代码改层数到倒数地二层 之后操作就是倒着往上搜索的,有难度,但是还是能做到的嘛 用 lisit 好像不需要别的,全用 list 连接 ...

  2. .NET5 IIS ASP.NET CORE 部署时 HTTP Error 502.5 - ANCM Out-Of-Process Startup Failure

    .NET5 IIS ASP.NET CORE 部署时 HTTP Error 502.5 - ANCM Out-Of-Process Startup Failure 部署机器只安装了dotnet-hos ...

  3. 如何基于Perl实现批量蛋白名转换为基因名?以做后续GO与KEGG分析

    众所周知,在完成蛋白组学组间差异蛋白筛选后,往往要做GO与KEGG功能富集分析,这就需要我们首先将蛋白名转换为基因名,或者找出基因ID.将蛋白名转化为基因名可能涉及不同的转换工具或数据库,这里有几种常 ...

  4. Cannot set properties of undefined (setting 'dataIndex')""

    前端写桑基图的时候碰到以上bug 原因是: 桑基图中的name值有重复的,把重复的name值去掉就好了

  5. Mirror多人联网发布阿里云

    Mirror多人联网发布阿里云 新建模板小书匠 将mirror网络地址和端口选为你阿里云服务器上开放的公网地址和端口 IP与端口 2. 在阿里云服务器安全组中开放你所制定的端口 开放阿里云端口 3. ...

  6. CodeForces 1935A

    题目链接:Entertainment in MAC 思路 当当前操作次数n为偶数时,若原字符串大于反转字符串则可以将原字符串反转n - 2次,则得到的还是原字符串,此时反转一次,并将其再次反转的字符串 ...

  7. java 8 stream toMap问题

    最近使用java的stream功能有点多,理由有2: 1)少写了不少代码 2)在性能可以接受的范围内 在巨大的collection基础上使用stream,没有什么经验.而非关键业务上,乐于使用stre ...

  8. java多线程-3-使用多线程的时机

    许多人对于计算机的运行原理不了解,甚至根本不了解. 不幸的是,此类中的一部分人也参与了计算机的编码工作.可想而知,编写的效率和结果.听者伤心,闻者流泪. 此类同学的常见的误解: 并发就能加快任务完成 ...

  9. gitlab角色与权限

    用户在项目中的角色 Guest:访客.可以创建issue.发表评论,不能读写版本库.(就是看不了代码-) Reporter:Git项目测试人员.可以克隆代码,不能提交.QA.PM可以赋予这个权限. D ...

  10. python重拾第八天-Socket网络编程

    本节内容 Socket介绍 Socket参数介绍 基本Socket实例 Socket实现多连接处理 通过Socket实现简单SSH 通过Socket实现文件传送 作业:开发一个支持多用户在线的FTP程 ...