转自:https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer

1.tf.train.GradientDescentOptimizer

其中有函数:

1.1apply_gradients

apply_gradients(
grads_and_vars,
global_step=None,
name=None
)

Apply gradients to variables.

This is the second part of minimize(). It returns an Operation that applies gradients.

将梯度应用到变量上。它是minimize函数的第二部分。

1.2compute_gradients

compute_gradients(
loss,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None
)

Compute gradients of loss for the variables in var_list.

This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". Note that "gradient" can be a Tensor, an IndexedSlices, or None if there is no gradient for the given variable.

计算var-list的梯度,它是minimize函数的第一部分,返回的是一个list,对应每个变量都有梯度。准备使用apply_gradient函数更新。

下面重点来了:

参数:

  • loss: A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.
  • var_list: Optional list or tuple of tf.Variable to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

loss就是损失函数,没啥了。

这个第二个参数变量列表通常是不传入的,那么计算谁的梯度呢?上面说,默认的参数列表是计算图中的 GraphKeys.TRAINABLE_VARIABLES.

去看这个的API发现:

tf.GraphKeys

The following standard keys are defined:

找到TRAINABLE_VARIABLES是:

  • TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. Seetf.trainable_variables for more details.

然后再去看:

tf.trainable_variables

tf.trainable_variables(scope=None)

Returns all variables created with trainable=True.

When passed trainable=True, the Variable() constructor automatically adds new variables to the graph collectionGraphKeys.TRAINABLE_VARIABLES.

This convenience function returns the contents of that collection.

Returns:

A list of Variable objects.

然后再去看一下tf.Variable函数:

tf.Variable

__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)

并且:

  • trainable: If True, the default, also adds the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by the Optimizer classes.

默认为真,并且加入可训练变量集中,所以:

在word2vec实现中,

with tf.device('/cpu:0'):
# Look up embeddings for inputs.
with tf.name_scope('embeddings'):
embeddings = tf.Variable(
tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_inputs)

定义的embeddings应该是可以更新的。怎么更新?:

with tf.name_scope('loss'):
loss = tf.reduce_mean(
tf.nn.nce_loss(
weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size)) # Add the loss value as a scalar to summary.
tf.summary.scalar('loss', loss) # Construct the SGD optimizer using a learning rate of 1.0.
with tf.name_scope('optimizer'):
optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

使用SGD随机梯度下降,在minimize损失函数中,应该是会对所有的可训练变量求导,对的,没错一定是这样,所以nec_weights,nce_biases,embeddings都是可更新变量。

都是通过先计算损失函数,求导然后更新变量,在迭代数据计算损失函数,求导更新,

这样来更新的。

Tf中的SGDOptimizer学习【转载】的更多相关文章

  1. R中双表操作学习[转载]

    转自:https://www.jianshu.com/p/a7af4f6e50c3 1.原始数据 以上是原有的一个,再生成一个新的: > gene_exp_tidy2 <- data.fr ...

  2. Tf中的NCE-loss实现学习【转载】

    转自:http://www.jianshu.com/p/fab82fa53e16 1.tf中的nce_loss的API def nce_loss(weights, biases, inputs, la ...

  3. tf中计算图 执行流程学习【转载】

    转自:https://blog.csdn.net/dcrmg/article/details/79028003 https://blog.csdn.net/qian99/article/details ...

  4. Java多线程学习(转载)

    Java多线程学习(转载) 时间:2015-03-14 13:53:14      阅读:137413      评论:4      收藏:3      [点我收藏+] 转载 :http://blog ...

  5. CNCC2017中的深度学习与跨媒体智能

    CNCC2017中的深度学习与跨媒体智能 转载请注明作者:梦里茶 目录 机器学习与跨媒体智能 传统方法与深度学习 图像分割 小数据集下的深度学习 语音前沿技术 生成模型 基于贝叶斯的视觉信息编解码 珠 ...

  6. SqlServer中的merge操作(转载)

    SqlServer中的merge操作(转载)   今天在一个存储过程中看见了merge这个关键字,第一个想法是,这个是配置管理中的概念吗,把相邻两次的更改合并到一起.后来在technet上搜索发现别有 ...

  7. PHP中的Libevent学习

    wangbin@2012,1,3 目录 Libevent在php中的应用学习 1.      Libevent介绍 2.      为什么要学习libevent 3.      Php libeven ...

  8. spring中context:property-placeholder/元素 转载

    spring中context:property-placeholder/元素  转载 1.有些参数在某些阶段中是常量 比如 :a.在开发阶段我们连接数据库时的连接url,username,passwo ...

  9. JS中childNodes深入学习

    原文:JS中childNodes深入学习 <html xmlns="http://www.w3.org/1999/xhtml"> <head> <ti ...

随机推荐

  1. I - 迷宫问题

    定义一个二维数组: int maze[5][5] = { 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, ...

  2. webpack 配置

    https://segmentfault.com/a/1190000009454172

  3. Linux关闭IPV6

    Linux关闭IPV6的方法 修改配置文件/etc/sysctl.conf添加以下1行 net.ipv6.conf.all.disable_ipv6 = 1 设置生效 sysctl -p 查看没有IP ...

  4. Missing artifact com.h2database:h2:jar:1.4.197

    之前OK的项目再次打开pom上报错: 一起出现的现象: maven库中这个包和H2数据库的包每次项目右键→maven→update project都会产生.lastupdate文件.原来是以前从mav ...

  5. 1.7Oob 构造方法

    1)构造方法 在创建对象后不用调用会自动执行,如无自定义构造会默认执行没有参数没有,且方法体中没有任何语句的, 2)构造方法在main入口开始后就执行

  6. Eclipse各个版本区别

    1.eclipse下载地址: 最新版:http://www.eclipse.org/downloads/ 历史版:http://archive.eclipse.org/eclipse/download ...

  7. 关于Java程序流程控制的整理(未完善)

  8. Xml文件删除节点总是留有空标签

    ---恢复内容开始--- 在删除Xml文件时,删除成功后还有标签,让我百思不得其解,因为xml文档中留着这空标签会对后续的操作带来很多麻烦,会取出空值,人后导致程序中止. 导致这种情况的原因是删除xm ...

  9. AndroidStudio_RecyclerView

    在这里回顾一下RecyclerView的用法 RecyclerView的用法与Button的用法很类似,只是要增加一个Adapter.java文件和item.xml文件 具体用法: 1.在page1. ...

  10. 最全的MonkeyRunner自动化测试从入门到精通(1)

    一.环境变量的配置 1.JDK环境变量的配置 步骤一:在官网上面下载jdk,JDK官网网址: http://www.oracle.com/technetwork/java/javase/downloa ...