tensorflow语法【shape、tf.trainable_variables()、Optimizer.minimize()】
相关文章:
【一】tensorflow安装、常用python镜像源、tensorflow 深度学习强化学习教学
【二】tensorflow调试报错、tensorflow 深度学习强化学习教学
trick1---实现tensorflow和pytorch迁移环境教学
张量shape参数理解
shape参数的个数应为维度数,每一个参数的值代表该维度上的长度
shape=(100,784)
代表该张量有两个维度,第一个维度长度为100,第二个维度长度为784,二维数组100行784列
shape=(2,)
代表该张量有一个维度,第一个维度长度为2,一维数组1行2列
第几个维度的长度,就是左数第几个中括号组之间的元素总数量
# 例:
[[[1,2,3],[4,5,6]]]
# 第一个维度中只有一个元素[[1,2,3][4,5,6]],所以第一个维度长度为1
# 第二个维度中有两个元素[1,2,3][4,5,6],所以第二个维度长度为2
# 第三个维度中有三个元素“1,2,3”或“4,5,6”,所以第三个维度长度为3
# 那么它的shape参数就是[1,2,3]
tf.trainable_variables(), tf.global_variables()的使用
tf.trainable_variables():
这个函数可以查看可训练的变量,参数trainable,其默认为True
__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
...
)
对于一些我们不需要训练的变量,将trainable设置为False,这时tf.trainable_variables() 就不会打印这些变量。
举个简单的例子,在下图中共定义了4个变量,分别是一个权重矩阵,一个偏置向量,一个学习率和计步器,其中前两项是需要训练的而后两项则不需要。
w1 = tf. Variable (tf. randon_normal ([256, 2000]),'w1' )
b1 = tf.get_ variable('b1', [2000])
learning_ rate = tf. Variable(0.5, trainable=False)
global_ step = tf. Variable(0, trainable=False)
trainable_ params = tf. trainable_ variables()
trainable_ params
[<tf. Variable’Variable:0' shape= (256,2000) dtype=float32_ ref>,
<tf. Variable’ b1:0”shape= (2000,) dtype=float32_ ref>]
另一个问题就是,如果变量定义在scope域中,是否会有不同。实际上,tf.trainable_variables()是可以通过参数选定域名的,如下图所示:
vith tf. variable_ scope(' var' ):
w2 = tf.get. variable('w2' , [3, 3])
w3 = tf.get. variable(' w3',[3, 3])
我们重新声明了两个新变量,其中w2是在‘var’中的,如果我们直接使用tf.trainable_variables(),结果如下
trainable. params = tf.trainable.variables ()
trainable_ params
[<tf. Variable’ vrar/w2:0’shape=(3, 3) dtype=float32_ ref>,
<tf. Variable’w3:0' shape=(3, 3) dtype=float32_ ref>]
如果我们只希望查看‘var’域中的变量,我们可以通过加入scope参数的方式实现:
scope_ parans = tf. trainable_ variables (scope-' var' )
scope par ains
[<tf. Variable ’var/w2:0' shape=(3, 3) dtype=float32_ ref>]
tf.global_variables()
如果我希望查看全部变量,包括我的学习率等信息,可以通过tf.global_variables()来实现。效果如下:
global parans = tf. global variables()
global_ params
[<tf. Variable,Variable:0' shape=(256, 2000) dtype=float32_ ref>,
<tf. Variable ' b1:0' shape= (2000,) dtype-float32_ ref>,
<tf. Variable。Variable_ 1:0’shape=0 dtype=float32_ ref>,
<tf. Variable' Variable_ 2:0’ shape=() dtype=int32_ ref>]
这时候打印出来了4个变量,其中后两个即为trainable=False的学习率和计步器。与tf.trainable_variables()一样,tf.global_variables()也可以通过scope的参数来选定域中的变量。
Optimizer.minimize()与Optimizer.compute_gradients()和Optimizer.apply_gradients()的用法
Optimizer.minimize()
minimize()就是compute_gradients()和apply_gradients()这两个方法的简单组合,minimize()的源码如下:
def minimize(self, loss, global_step=None, var_list=None,
gate_gradients=GATE_OP, aggregation_method=None,
colocate_gradients_with_ops=False, name=None,
grad_loss=None):
grads_and_vars = self.compute_gradients(
loss, var_list=var_list, gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
grad_loss=grad_loss)
vars_with_grad = [v for g, v in grads_and_vars if g is not None]
if not vars_with_grad:
raise ValueError(
"No gradients provided for any variable, check your graph for ops"
" that do not support gradients, between variables %s and loss %s." %
([str(v) for _, v in grads_and_vars], loss))
return self.apply_gradients(grads_and_vars, global_step=global_step,
name=name)
主要的参数说明:
loss: `Tensor` ,需要优化的损失;
var_list: 需要更新的变量(tf.Varialble)组成的列表或者元组,默认值为`GraphKeys.TRAINABLE_VARIABLES`,即tf.trainable_variables()
注意:
1、Optimizer.minimize(loss, var_list)中,计算loss所涉及的变量(假设为var(loss))包含在var_list中,也就是var_list中含有多余的变量,并不 影响程序的运行,而且优化过程中不改变var_list里多出变量的值;
2、若var_list中的变量个数少于var(loss),则优化过程中只会更新var_list中的那些变量的值,var(loss)里多出的变量值 并不会改变,相当于固定了网络的某一部分的参数值。
compute_gradients()和apply_gradients()
compute_gradients(self, loss, var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
里面参数的定义与minimizer()函数里面的一致,var_list的默认值也一样。需要特殊说明的是,如果var_list里所包含的变量多于var(loss),则程序会报错。其返回值是(gradient, variable)对所组成的列表,返回的数据格式也都是“tf.Tensor”。我们可以通过变量名称的管理来过滤出里面的部分变量,以及对应的梯度。
apply_gradients()的源码如下:
apply_gradients(self, grads_and_vars, global_step=None, name=None)
grads_and_vars的格式就是compute_gradients()所返回的(gradient, variable)对,当然数据类型也是“tf.Tensor”,作用是,更新grads_and_vars中variable的梯度,不在里面的变量的梯度不变。
tensorflow语法【shape、tf.trainable_variables()、Optimizer.minimize()】的更多相关文章
- tf.trainable_variables()
https://blog.csdn.net/shwan_ma/article/details/78879620 一般来说,打印tensorflow变量的函数有两个:tf.trainable_varia ...
- tensorflow 生成随机数 tf.random_normal 和 tf.random_uniform 和 tf.truncated_normal 和 tf.random_shuffle
____tz_zs tf.random_normal 从正态分布中输出随机值. . <span style="font-size:16px;">random_norma ...
- tensorflow 基本函数(1.tf.split, 2.tf.concat,3.tf.squeeze, 4.tf.less_equal, 5.tf.where, 6.tf.gather, 7.tf.cast, 8.tf.expand_dims, 9.tf.argmax, 10.tf.reshape, 11.tf.stack, 12tf.less, 13.tf.boolean_mask
1. tf.split(3, group, input) # 拆分函数 3 表示的是在第三个维度上, group表示拆分的次数, input 表示输入的值 import tensorflow ...
- Tensorflow 学习笔记 -----tf.where
TensorFlow函数:tf.where 在之前版本对应函数tf.select 官方解释: tf.where(input, name=None)` Returns locations of true ...
- 【TensorFlow基础】tf.add 和 tf.nn.bias_add 的区别
1. tf.add(x, y, name) Args: x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, ...
- Tensorflow中的tf.argmax()函数
转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html 官方API定义 tf.argmax(input, axis=None, name=None ...
- tensorflow中使用tf.variable_scope和tf.get_variable的ValueError
ValueError: Variable conv1/weights1 already exists, disallowed. Did you mean to set reuse=True in Va ...
- tf.trainable_variables和tf.all_variables的对比
tf.trainable_variables返回的是可以用来训练的变量列表 tf.all_variables返回的是所有变量的列表
- TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用
一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...
- tensorflow笔记:使用tf来实现word2vec
(一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 (四) tensorflow笔 ...
随机推荐
- Web 3.0 会是互联网的下一个时代吗?
2000 年初,只读互联网 Web 1.0 被 Web 2.0 所取代.在 Web 2.0 时代,用户摆脱了只读的困扰,可以在平台上进行互动并创作内容.而 Web 3.0 的到来,除了加密货币和区块链 ...
- 【库函数】QT 中QString字符串的操作
QString是QT提供的字符串类,相应的也就提供了很多很方便对字符串的处理方法.这里把这些对字符串的操作做一个整理和总结. 1. 将一个字符串追加到另一个字符串的末尾 QString str1 = ...
- 【CJsonObject】C++ JSON 解析器使用教程
能选封装的尽量不使用底层的 一.CJsonObject 简介 CJsonObject 是 Bwar 基于 cJSON 全新开发一个 C++ 版的 JSON 库. CJsonObject 的最大优势是轻 ...
- #2089: 不要62 (数位dp模板题,附带详细解释)
题目链接 题意:问区间[n,m]中,不含数字4,也不含数字串"62"的所有数的个数. 思路:可以转化成求区间[0,x] 第一次接触数位dp,参考了这几篇博客. 不要62(数位dp) ...
- AtCoder Regular Contest 116 (A~F补题记录)
补题链接:Here 第一次打 ARC,被数学题虐惨了 赛后部分数学证明学习自 ACwisher A - Odd vs Even \(T(1≤T≤2×10^5)\)组测试数据,每次询问一个正整数 \(N ...
- HTML+CSS小实战案例 (照片墙特效、代码展示)
预览图: HMTL代码部分 <!DOCTYPE html> <html lang="en"> <head> <meta charset=& ...
- Kubernetes: client-go 源码剖析(二)
kubernetes:client-go 系列文章: Kubernetes: client-go 源码剖析(一) Kubernetes: client-go 源码剖析(二) 2.3 运行 inform ...
- docker 原理之 user namespace(下)
1. user namespace user namespace 主要隔离了安全相关的标识符和属性,包括用户 ID,用户组 ID,key 和 capabilities 等.同样一个用户 id 在不同 ...
- C#对象二进制序列化优化:位域技术实现极限压缩
目录 1. 引言 2. 优化过程 2.1. 进程对象定义与初步分析 2.2. 排除Json序列化 2.3. 使用BinaryWriter进行二进制序列化 2.4. 数据类型调整 2.5. 再次数据类型 ...
- 应用程序使用统计信息 – .NET CORE(C#) WPF界面设计
应用程序使用统计信息 - .NET CORE(C#) WPF界面设计 首发文章地址:https://dotnet9.com/10546.html 关键功能点 抽屉式菜单 圆形进度条 Demo演示: 1 ...