Tf中的SGDOptimizer学习【转载】
转自:https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
1.tf.train.GradientDescentOptimizer
其中有函数:

1.1apply_gradients
apply_gradients(
grads_and_vars,
global_step=None,
name=None
)
Apply gradients to variables.
This is the second part of minimize(). It returns an Operation that applies gradients.
将梯度应用到变量上。它是minimize函数的第二部分。
1.2compute_gradients
compute_gradients(
loss,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None
)
Compute gradients of loss for the variables in var_list.
This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". Note that "gradient" can be a Tensor, an IndexedSlices, or None if there is no gradient for the given variable.
计算var-list的梯度,它是minimize函数的第一部分,返回的是一个list,对应每个变量都有梯度。准备使用apply_gradient函数更新。
下面重点来了:
参数:
loss: A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.var_list: Optional list or tuple oftf.Variableto update to minimizeloss. Defaults to the list of variables collected in the graph under the keyGraphKeys.TRAINABLE_VARIABLES.
loss就是损失函数,没啥了。
这个第二个参数变量列表通常是不传入的,那么计算谁的梯度呢?上面说,默认的参数列表是计算图中的 GraphKeys.TRAINABLE_VARIABLES.
去看这个的API发现:
tf.GraphKeys
The following standard keys are defined:
找到TRAINABLE_VARIABLES是:
TRAINABLE_VARIABLES: the subset ofVariableobjects that will be trained by an optimizer. Seetf.trainable_variablesfor more details.
然后再去看:
tf.trainable_variables
tf.trainable_variables(scope=None)
Returns all variables created with trainable=True.
When passed trainable=True, the Variable() constructor automatically adds new variables to the graph collectionGraphKeys.TRAINABLE_VARIABLES.
This convenience function returns the contents of that collection.
Returns:
A list of Variable objects.
然后再去看一下tf.Variable函数:
tf.Variable
__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)
并且:
trainable: IfTrue, the default, also adds the variable to the graph collectionGraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by theOptimizerclasses.
默认为真,并且加入可训练变量集中,所以:
在word2vec实现中,
with tf.device('/cpu:0'):
# Look up embeddings for inputs.
with tf.name_scope('embeddings'):
embeddings = tf.Variable(
tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_inputs)
定义的embeddings应该是可以更新的。怎么更新?:
with tf.name_scope('loss'):
loss = tf.reduce_mean(
tf.nn.nce_loss(
weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size))
# Add the loss value as a scalar to summary.
tf.summary.scalar('loss', loss)
# Construct the SGD optimizer using a learning rate of 1.0.
with tf.name_scope('optimizer'):
optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
使用SGD随机梯度下降,在minimize损失函数中,应该是会对所有的可训练变量求导,对的,没错一定是这样,所以nec_weights,nce_biases,embeddings都是可更新变量。
都是通过先计算损失函数,求导然后更新变量,在迭代数据计算损失函数,求导更新,
这样来更新的。
Tf中的SGDOptimizer学习【转载】的更多相关文章
- R中双表操作学习[转载]
转自:https://www.jianshu.com/p/a7af4f6e50c3 1.原始数据 以上是原有的一个,再生成一个新的: > gene_exp_tidy2 <- data.fr ...
- Tf中的NCE-loss实现学习【转载】
转自:http://www.jianshu.com/p/fab82fa53e16 1.tf中的nce_loss的API def nce_loss(weights, biases, inputs, la ...
- tf中计算图 执行流程学习【转载】
转自:https://blog.csdn.net/dcrmg/article/details/79028003 https://blog.csdn.net/qian99/article/details ...
- Java多线程学习(转载)
Java多线程学习(转载) 时间:2015-03-14 13:53:14 阅读:137413 评论:4 收藏:3 [点我收藏+] 转载 :http://blog ...
- CNCC2017中的深度学习与跨媒体智能
CNCC2017中的深度学习与跨媒体智能 转载请注明作者:梦里茶 目录 机器学习与跨媒体智能 传统方法与深度学习 图像分割 小数据集下的深度学习 语音前沿技术 生成模型 基于贝叶斯的视觉信息编解码 珠 ...
- SqlServer中的merge操作(转载)
SqlServer中的merge操作(转载) 今天在一个存储过程中看见了merge这个关键字,第一个想法是,这个是配置管理中的概念吗,把相邻两次的更改合并到一起.后来在technet上搜索发现别有 ...
- PHP中的Libevent学习
wangbin@2012,1,3 目录 Libevent在php中的应用学习 1. Libevent介绍 2. 为什么要学习libevent 3. Php libeven ...
- spring中context:property-placeholder/元素 转载
spring中context:property-placeholder/元素 转载 1.有些参数在某些阶段中是常量 比如 :a.在开发阶段我们连接数据库时的连接url,username,passwo ...
- JS中childNodes深入学习
原文:JS中childNodes深入学习 <html xmlns="http://www.w3.org/1999/xhtml"> <head> <ti ...
随机推荐
- mysqldump命令的安装
author:headsen chen date:2019-03-14 11:31:00 安装:yum -y install mysql-client / apt-get install mys ...
- ssl---阿里云的public.crt和chain.crt的证书怎么弄
由于项目需要,网站需要https服务,服务器是阿里云的,装的是宝塔的面板,下面是详细的配置ssl证书的方法: 如何在阿里云的后台申请ssl证书就不说了,下载下来的证书有三个:.key chain. ...
- css---计算页面的的宽度和长度
我们在写前端页面的时候,会遇到这样的情况,就是一个div设置宽度100%,设置左右边距10像素,这样的布局,在里面嵌套的div的宽度设置100%,这样写的话,里面的宽度是和外面的宽度一致的,同样是10 ...
- js dom 观察者属性 MutationObserver
MDN上说的很清楚 MutationObserver给开发者们提供了一种能在某个范围内的DOM树发生变化时作出适当反应的能力.该API设计用来替换掉在DOM3事件规范中引入的Mutation事件 co ...
- dts的pci模块中bus-range和ranges
bus-range = <2 3>; 该设备(一般为RC)下的pci总线号范围 ranges = <0x2000000 0x0 0xc0000000 0 0xc00000 ...
- SQL 2017 远程连接被拒绝
1.防火墙端口 2.数据库要能帐号登录 可是还是不行 打开:SQL Server 2017 配置管理器->SQL Server 服务 ->SQLServer(你的实例名)-> 右键- ...
- stm32 硬件错误
进入该模式,程序死机. 一般来说都是内存错误 1. 数组越界,装入数据溢出, 2. 堆和栈设置不当,这里面硬件的堆和栈在汇编文件中,如果有freertos等,重点检查,任务堆栈使用情况,一般任务堆栈溢 ...
- STM FLASH在线编程 升级
注意字节到 stm flash 顺序是反的 例如 12 34 56 78 世纪写入内存 应该是 78 56 34 12
- ArcEngine二次开发,TOCControl控件上使用contextMenuStrip
右键菜单,在二次开发中很实用,以前没用过,最近通过一本书了解到,一直想找这么一个控件来用. 一般的控件,将contextMenuStrip控件拖到所依托的控件上,然后输入自己想要的几个功能. 在所依 ...
- 分区实践 注意分区名 p2018-01 p2018-02 被解释为同一分区名
# https://dev.mysql.com/doc/refman/5.6/en/partitioning-columns-range.html'''CREATE TABLE employees ( ...