转自:https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer

1.tf.train.GradientDescentOptimizer

其中有函数:

1.1apply_gradients

apply_gradients(
grads_and_vars,
global_step=None,
name=None
)

Apply gradients to variables.

This is the second part of minimize(). It returns an Operation that applies gradients.

将梯度应用到变量上。它是minimize函数的第二部分。

1.2compute_gradients

compute_gradients(
loss,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None
)

Compute gradients of loss for the variables in var_list.

This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". Note that "gradient" can be a Tensor, an IndexedSlices, or None if there is no gradient for the given variable.

计算var-list的梯度,它是minimize函数的第一部分,返回的是一个list,对应每个变量都有梯度。准备使用apply_gradient函数更新。

下面重点来了:

参数:

  • loss: A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.
  • var_list: Optional list or tuple of tf.Variable to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

loss就是损失函数,没啥了。

这个第二个参数变量列表通常是不传入的,那么计算谁的梯度呢?上面说,默认的参数列表是计算图中的 GraphKeys.TRAINABLE_VARIABLES.

去看这个的API发现:

tf.GraphKeys

The following standard keys are defined:

找到TRAINABLE_VARIABLES是:

  • TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. Seetf.trainable_variables for more details.

然后再去看:

tf.trainable_variables

tf.trainable_variables(scope=None)

Returns all variables created with trainable=True.

When passed trainable=True, the Variable() constructor automatically adds new variables to the graph collectionGraphKeys.TRAINABLE_VARIABLES.

This convenience function returns the contents of that collection.

Returns:

A list of Variable objects.

然后再去看一下tf.Variable函数:

tf.Variable

__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)

并且:

  • trainable: If True, the default, also adds the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by the Optimizer classes.

默认为真,并且加入可训练变量集中,所以:

在word2vec实现中,

with tf.device('/cpu:0'):
# Look up embeddings for inputs.
with tf.name_scope('embeddings'):
embeddings = tf.Variable(
tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_inputs)

定义的embeddings应该是可以更新的。怎么更新?:

with tf.name_scope('loss'):
loss = tf.reduce_mean(
tf.nn.nce_loss(
weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size)) # Add the loss value as a scalar to summary.
tf.summary.scalar('loss', loss) # Construct the SGD optimizer using a learning rate of 1.0.
with tf.name_scope('optimizer'):
optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

使用SGD随机梯度下降,在minimize损失函数中,应该是会对所有的可训练变量求导,对的,没错一定是这样,所以nec_weights,nce_biases,embeddings都是可更新变量。

都是通过先计算损失函数,求导然后更新变量,在迭代数据计算损失函数,求导更新,

这样来更新的。

Tf中的SGDOptimizer学习【转载】的更多相关文章

  1. R中双表操作学习[转载]

    转自:https://www.jianshu.com/p/a7af4f6e50c3 1.原始数据 以上是原有的一个,再生成一个新的: > gene_exp_tidy2 <- data.fr ...

  2. Tf中的NCE-loss实现学习【转载】

    转自:http://www.jianshu.com/p/fab82fa53e16 1.tf中的nce_loss的API def nce_loss(weights, biases, inputs, la ...

  3. tf中计算图 执行流程学习【转载】

    转自:https://blog.csdn.net/dcrmg/article/details/79028003 https://blog.csdn.net/qian99/article/details ...

  4. Java多线程学习(转载)

    Java多线程学习(转载) 时间:2015-03-14 13:53:14      阅读:137413      评论:4      收藏:3      [点我收藏+] 转载 :http://blog ...

  5. CNCC2017中的深度学习与跨媒体智能

    CNCC2017中的深度学习与跨媒体智能 转载请注明作者:梦里茶 目录 机器学习与跨媒体智能 传统方法与深度学习 图像分割 小数据集下的深度学习 语音前沿技术 生成模型 基于贝叶斯的视觉信息编解码 珠 ...

  6. SqlServer中的merge操作(转载)

    SqlServer中的merge操作(转载)   今天在一个存储过程中看见了merge这个关键字,第一个想法是,这个是配置管理中的概念吗,把相邻两次的更改合并到一起.后来在technet上搜索发现别有 ...

  7. PHP中的Libevent学习

    wangbin@2012,1,3 目录 Libevent在php中的应用学习 1.      Libevent介绍 2.      为什么要学习libevent 3.      Php libeven ...

  8. spring中context:property-placeholder/元素 转载

    spring中context:property-placeholder/元素  转载 1.有些参数在某些阶段中是常量 比如 :a.在开发阶段我们连接数据库时的连接url,username,passwo ...

  9. JS中childNodes深入学习

    原文:JS中childNodes深入学习 <html xmlns="http://www.w3.org/1999/xhtml"> <head> <ti ...

随机推荐

  1. mysqldump命令的安装

    author:headsen   chen date:2019-03-14  11:31:00 安装:yum -y install mysql-client / apt-get install mys ...

  2. ssl---阿里云的public.crt和chain.crt的证书怎么弄

    由于项目需要,网站需要https服务,服务器是阿里云的,装的是宝塔的面板,下面是详细的配置ssl证书的方法: 如何在阿里云的后台申请ssl证书就不说了,下载下来的证书有三个:.key   chain. ...

  3. css---计算页面的的宽度和长度

    我们在写前端页面的时候,会遇到这样的情况,就是一个div设置宽度100%,设置左右边距10像素,这样的布局,在里面嵌套的div的宽度设置100%,这样写的话,里面的宽度是和外面的宽度一致的,同样是10 ...

  4. js dom 观察者属性 MutationObserver

    MDN上说的很清楚 MutationObserver给开发者们提供了一种能在某个范围内的DOM树发生变化时作出适当反应的能力.该API设计用来替换掉在DOM3事件规范中引入的Mutation事件 co ...

  5. dts的pci模块中bus-range和ranges

    bus-range = <2 3>;       该设备(一般为RC)下的pci总线号范围 ranges = <0x2000000 0x0 0xc0000000 0 0xc00000 ...

  6. SQL 2017 远程连接被拒绝

    1.防火墙端口 2.数据库要能帐号登录 可是还是不行 打开:SQL Server 2017 配置管理器->SQL Server 服务 ->SQLServer(你的实例名)-> 右键- ...

  7. stm32 硬件错误

    进入该模式,程序死机. 一般来说都是内存错误 1. 数组越界,装入数据溢出, 2. 堆和栈设置不当,这里面硬件的堆和栈在汇编文件中,如果有freertos等,重点检查,任务堆栈使用情况,一般任务堆栈溢 ...

  8. STM FLASH在线编程 升级

    注意字节到 stm flash 顺序是反的 例如 12 34 56 78 世纪写入内存 应该是 78 56 34 12

  9. ArcEngine二次开发,TOCControl控件上使用contextMenuStrip

    右键菜单,在二次开发中很实用,以前没用过,最近通过一本书了解到,一直想找这么一个控件来用. 一般的控件,将contextMenuStrip控件拖到所依托的控件上,然后输入自己想要的几个功能.  在所依 ...

  10. 分区实践 注意分区名 p2018-01 p2018-02 被解释为同一分区名

    # https://dev.mysql.com/doc/refman/5.6/en/partitioning-columns-range.html'''CREATE TABLE employees ( ...