tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope
tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享。
tf通过 tf.get_variable() 可以建立或者获取一个共享的变量。 tf.get_variable函数的作用从tf的注释里就可以看出来-- ‘Gets an existing variable with this name or create a new one’。
与 tf.get_variable 函数相对的还有一个 tf.Variable 函数,两者的区别是:
- tf.Variable定义变量的时候会自动检测命名冲突并自行处理,例如已经定义了一个名称是 ‘wg_1’的变量,再使用tf.Variable定义名称是‘wg_1’的变量,会自动把后一个变量的名称更改为‘wg_1_0’,实际相当于创建了两个变量,tf.Variable不可以创建共享变量。
- tf.get_variable定义变量的时候不会自动处理命名冲突,如果遇到重名的变量并且创建该变量时没有设置为共享变量,tf会直接报错。
变量可以共享之后还有一个问题就是当模型很大很复杂的时候,变量和操作的数量也比较庞大,为了方便对这些变量进行管理,维护条理清晰的graph结构,tf建立了一套共享机制,通过 变量作用域(命名空间,variable_scope)实现对变量的共享和管理。例如,cnn的每一层中,均有weights和biases这两个变量,通过tf.variable_scope()为每一卷积层命名,就可以防止变量命名重复。
与 tf.variable_scope相对的还有一个 tf.name_scope 函数,两者的区别是:
- tf.name_scope 主要用于管理一个图(graph)里面的各种操作,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。
- tf.variable_scope 一般与tf.name_scope()配合使用,用于管理一个图(graph)中变量的名字,避免变量之间的命名冲突,tf.variable_scope允许在一个variable_scope下面共享变量。
# coding: utf-8
import tensorflow as tf
# 定义的基本等价
v1 = tf.get_variable("v", shape=[1], initializer= tf.constant_initializer(1.0))
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="v")
with tf.variable_scope("abc"):
v3=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
# 在变量作用域内定义变量,不同变量作用域内的变量命名可以相同
with tf.variable_scope("xyz"):
v4=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
with tf.variable_scope("xyz", reuse=True):
v5 = tf.get_variable("v")
v6 = tf.get_variable("v",[1])
with tf.variable_scope("foo"):
v7 = tf.get_variable("v", [1])
# 通过 tf.get_variable_scope().reuse_variables() 设置以下的变量是共享变量;
# 如果不加,v8的定义会由于重名而报错
tf.get_variable_scope().reuse_variables()
v8 = tf.get_variable("v", [1])
assert v7 is v8
with tf.variable_scope("foo_1") as foo_scope:
v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope):
w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True):
v1 = tf.get_variable("v", [1])
w1 = tf.get_variable("w", [1])
assert v1 is v
assert w1 is w
with tf.variable_scope("foo1"):
with tf.name_scope("bar1"):
v_1 = tf.get_variable("v", [1])
x_1 = 1.0 + v_1
assert v_1.name == "foo1/v:0"
assert x_1.op.name == "foo1/bar1/add"
print v1==v2 # False
print v3==v4 # False 不同变量作用域中
print v3.name # abc/v:0
print v4==v5 # 输出为True
print v5==v6 # True
tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope的更多相关文章
- TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同
tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...
- TensorFlow中的变量命名以及命名空间.
What: 在Tensorflow中, 为了区别不同的变量(例如TensorBoard显示中), 会需要命名空间对不同的变量进行命名. 其中常用的两个函数为: tf.variable_scope, t ...
- 【tf.keras】tf.keras使用tensorflow中定义的optimizer
Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...
- Tensorflow中的name_scope和variable_scope
Tensorflow是一个编程模型,几乎成为了一种编程语言(里面有变量.有操作......). Tensorflow编程分为两个阶段:构图阶段+运行时. Tensorflow构图阶段其实就是在对图进行 ...
- 对tensorflow 中的attention encoder-decoder模型调试分析
#-*-coding:utf8-*- __author = "buyizhiyou" __date = "2017-11-21" import random, ...
- tensorflow中使用tf.variable_scope和tf.get_variable的ValueError
ValueError: Variable conv1/weights1 already exists, disallowed. Did you mean to set reuse=True in Va ...
- TensorFlow中get_variable共享变量调用
import tensorflow as tf with tf.variable_scope('v_scope',reuse=True) as scope1: Weights1 = tf.get_va ...
- TF之RNN:TF的RNN中的常用的两种定义scope的方式get_variable和Variable—Jason niu
# tensorflow中的两种定义scope(命名变量)的方式tf.get_variable和tf.Variable.Tensorflow当中有两种途径生成变量 variable import te ...
- tensorflow中 tf.add_to_collection、 tf.get_collection 和 tf.add_n函数
tf.add_to_collection(name, value) 用来把一个value放入名称是'name'的集合,组成一个列表; tf.get_collection(key, scope=Non ...
随机推荐
- MySQL多个相同结构的表查询并把结果合并放在一起的语句(union all)
union all select *,'1' as category from table1001 where price > 10 union all select *,'2' as cate ...
- Spark之Task原理分析
在Spark中,一个应用程序要想被执行,肯定要经过以下的步骤: 从这个路线得知,最终一个job是依赖于分布在集群不同节点中的task,通过并行或者并发的运行来完成真正的工作.由此可见 ...
- ScyllaDB - 基础部署
基础环境 操作系统: CentOS 7.2: 集群节点(虚拟机):172.16.134.15 ~ 17: 基础准备 安装依赖和卸载 abrt ( abrt 和 coredump 配置冲突 ): sud ...
- 模块讲解----反射 (基于web路由的反射)
一.反射的实际案例: def main(): menu = ''' 1.账户信息 2.还款 3.取款 4.转账 5.账单 ''' menu_dic = { ':account_info, ':repa ...
- mysql key index区别
看似有差不多的作用,加了Key的表与建立了Index的表,都可以进行快速的数据查询.他们之间的区别在于处于不同的层面上. Key即键值,是关系模型理论中的一部份,比如有主键(Primary Key), ...
- EditPlus 4.3.2477 中文版已经发布(11月3日更新)
新的版本修复了之前版本文本库和自动完成功能中的“^!”符号在填充项前面时不能正常工作的问题.
- web实现负载均衡的几种实现方式
摘要: 负载均衡(Load Balance)是集群技术(Cluster)的一种应用.负载均衡可以将工作任务分摊到多个处理单元,从而提高并发处理能力.目前最常见的负载均衡应用是Web负载均衡.根据实现的 ...
- crontab 定时执行脚本出错,但手动执行脚本正常
原因: crontab 没有去读环境变量,需要再脚本中手动引入环境变量,可以用source 也可以用export 写死环境变量. 为了定时监控Linux系统CPU.内存.负载的使用情况,写了个Shel ...
- nodejs v8引擎
Node.js 线程你理解的可能是错的 本文代码运行环境 系统:MacOS High Sierra Node.js:10.3.0 复制代码 Node.js是单线程的,那么Node.js启动后线程数是1 ...
- C++中的动态绑定
C++中基类和派生类遵循类型兼容原则:即可用派生类的对象去初始化基类的对象,可用派生类的对象去初始化基类的引用,可用派生类对象的地址去初始化基类对象指针. C++中动态绑定条件发生需要满足2个条件: ...