tensorflow动态设置trainable
tensorflow中定义的tf.Variable时,可以通过trainable属性控制这个变量是否可以被优化器更新。但是,tf.Variable的trainable属性是只读的,我们无法动态更改这个只读属性。在定义tf.Variable时,如果指定trainable=True,那么会把这个Variable添加到“可被训练的变量”集合中。
把trainable指定为布尔变量是不管用的,trainable只在定义变量的那一瞬间有用。
# trainable只能是bool值,不能是张量
trainable = tf.Variable(False, dtype=tf.bool)
loss = tf.Variable(3.0, dtype=tf.float32, trainable=trainable)
train_op = tf.train.AdamOptimizer(0.01).minimize(loss)
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
for i in range(100):
_, lo = sess.run([train_op, loss], feed_dict={
trainable: i % 10 < 5
})
print('epoch', i, 'loss', lo)
在定义Variable变量的那一瞬间,如果trainable=true,这个变量就会被添加到可被训练的变量集合中去。当定义optimizer的minimize张量时,minimize张量就会读取可被训练的变量集合并构建张量。此后,即便可被训练的变量集合发生改变,minimize张量也不会再去管哪些变量不能被训练了。
"""
如果optimizer的全部变量都是不可训练的,tensorflow会抛出异常
所以在这里使用两个变量,两个变量轮流变得可调节
:return:
"""
x = tf.Variable(3.0, dtype=tf.float32)
y = tf.Variable(13.0, dtype=tf.float32)
train_op = tf.train.AdamOptimizer(0.01).minimize(tf.abs(y - x))
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
print("trainable_variables is a function")
print(tf.trainable_variables, type(tf.trainable_variables()))
print(tf.trainable_variables())
print("tf.GraphKeys has several string key")
print(tf.GraphKeys.TRAINABLE_VARIABLES, type(tf.GraphKeys.TRAINABLE_VARIABLES))
print("tf.get_collection can get something by tf.GraphKeys")
col = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
print(col, type(col))
print("try remove x from trainable variables")
del col[col.index(x)] # 此处虽然可被训练的变量集合变化了,但是train_op已经定义完了
print(tf.trainable_variables())
print('=======')
for i in range(100):
_, xx, yy = sess.run([train_op, x, y])
print('epoch', i, xx, yy) # 此处x和y都会变化
tf.GraphKeys
tf.GraphKeys中包含了所有默认集合的名称,可以通过查看__dict__发现具体集合。
tf.GraphKeys.GLOBAL_VARIABLES:global_variables被收集在名为tf.GraphKeys.GLOBAL_VARIABLES的colletion中,包含了模型中的通用参数
tf.GraphKeys.TRAINABLE_VARIABLES:tf.Optimizer默认只优化tf.GraphKeys.TRAINABLE_VARIABLES中的变量。
- tf.global_variables() GLOBAL_VARIABLES
存储和读取checkpoints时,使用其中所有变量
跨设备全局变量集合 - tf.trainable_variables() TRAINABLE_VARIABLES
训练时,更新其中所有变量
存储需要训练的模型参数的变量集合 - tf.moving_average_variables() MOVING_AVERAGE_VARIABLES
ExponentialMovingAverage对象会生成此类变量
实用指数移动平均的变量集合 - tf.local_variables() LOCAL_VARIABLES
在global_variables()之外,需要用tf.init_local_variables()初始化
进程内本地变量集合 - tf.model_variables() MODEL_VARIABLES
Key to collect model variables defined by layers.
进程内存储的模型参数的变量集合 - QUEUE_RUNNERS 并非存储variables,存储处理输入的QueueRunner
- SUMMARIES 并非存储variables,存储日志生成相关张量
除了以上的函数外(上表中最后两个集合并非变量集合,为了方便一并放在这里),还可以使用tf.get_collection(集合名)获取集合中的变量,不过这个函数更多与tf.get_collection(集合名)搭配使用,操作自建集合。
Summary被收集在名为tf.GraphKeys.UMMARIES的colletion中,Summary是对网络中Tensor取值进行监测的一种Operation,这些操作在图中是“外围”操作,不影响数据流本身,调用tf.scalar_summary系列函数时,就会向默认的collection中添加一个Operation。
我们也可以自定义变量集合、操作集合,这在正则化参数时非常有用。
x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)
x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)
tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)
losses = tf.get_collection('losses')
loss_total = tf.add_n(losses)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)
我说
tensorflow臃肿庞杂,设计者的设计水平远远比不上keras。
tensorflow臃肿庞杂,做了许多外围操作。比如为变量起名字,把变量添加到集合中,使用summary来监控训练中产生的数据。这些操作都不是核心操作,分清核心操作和扩展操作非常重要。
- 基本操作:如加减乘除、矩阵乘法等运算
- python语言操作:基本上是一些外围操作如collection,summary,dataset等。tf.gfile中定义了一堆文件操作,比python自带的文件操作要高效易用。
- 函数级封装:把经常使用的基本操作定义成一个函数,如softmax、wx_b、cross_entropy等。
- 层级封装:定义一些常见层,如全连接层、卷积层等。
- 模型封装:keras中有Model,Tensorflow不好意思直接拿来用,起了个名叫“Estimator”。
optimizer其实也是一种封装,optimizer其实就是对变量执行assign操作。除了使用反向传播,我们也可以自己定义基于遗传算法的optimizer。
拦截optimizer的梯度更新过程实现动态trainable
optimizer计算梯度的过程是应用梯度的过程是两个步骤。计算梯度张量返回一个grad_and_vars列表,应用梯度需要grad_and_vars列表作为参数。
我们可以建立(loss,exemp)到minize张量的映射。
# 拦截梯度更新过程
class MyOptimizer:
def __init__(self, optimizer: tf.train.Optimizer):
self.optimizer = optimizer
self.operations = dict()
def minimize(self, loss, exemp):
"""
注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出
"""
k = ' '.join(sorted([i.name for i in exemp])) + loss.name
if k not in self.operations:
a = [i for i in tf.trainable_variables() if i not in exemp]
grad_vars = self.optimizer.compute_gradients(loss, a)
op = self.optimizer.apply_gradients(grad_vars)
self.operations[k] = op
return self.operations[k]
x = tf.Variable(3.0, dtype=tf.float32)
y = tf.Variable(31.0, dtype=tf.float32)
loss = tf.abs(x - y)
"""
为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量
"""
optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))
train_op = optimizer.minimize(loss, [])
with tf.Session()as sess:
sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
for i in range(100):
exemp = [x if i % 10 < 5 else y]
_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])
print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)
这种方法的缺点在于使用loss和exemp作为key,如果key太多,定义的张量就会变多,这样会产生很多变量。
尝试优化一下,使用loss作为key。
def __init__(self, optimizer: tf.train.Optimizer):
self.optimizer = optimizer
self.operations = dict()
def minimize(self, loss, exemp):
"""
注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出
"""
if loss.name not in self.operations:
grad_vars = self.optimizer.compute_gradients(loss)
self.operations[loss.name] = grad_vars
grad_vars = self.operations[loss.name]
exemp = set(exemp)
grad_vars = list(filter(lambda x: x[1] not in exemp, grad_vars))
op = self.optimizer.apply_gradients(grad_vars)
return op
x = tf.Variable(3.0, dtype=tf.float32)
y = tf.Variable(31.0, dtype=tf.float32)
loss = tf.abs(x - y)
"""
为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量
"""
optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))
train_op = optimizer.minimize(loss, [])
with tf.Session()as sess:
sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
for i in range(100):
exemp = [x if i % 10 < 5 else y]
_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])
print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)
这种方法其实更差劲,因为apply_gradients依旧会创建许多张量(许多tf.assign_sub张量),而第一种方法反倒没有那么多的张量。
梯度更新的过程其实就是一堆assign操作。
# 拦截梯度更新过程
class MyOptimizer:
def __init__(self, optimizer: tf.train.Optimizer):
self.optimizer = optimizer
self.operations = dict()
def minimize(self, loss, exemp):
"""
注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出
"""
if loss.name not in self.operations:
grad_vars = self.optimizer.compute_gradients(loss)
op = [(variable, tf.assign_sub(variable, self.optimizer._lr * grad)) for grad, variable in grad_vars]
self.operations[loss.name] = op
grad_vars = self.operations[loss.name]
op = [x[1] for x in grad_vars if x[0] not in exemp]
return op
x = tf.Variable(3.0, dtype=tf.float32)
y = tf.Variable(31.0, dtype=tf.float32)
loss = tf.abs(x - y)
"""
为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量
"""
optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))
train_op = optimizer.minimize(loss, [])
with tf.Session()as sess:
sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
for i in range(100):
exemp = [x if i % 10 < 5 else y]
_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])
print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)
参考资料
https://www.cnblogs.com/hellcat/p/9006904.html
tensorflow动态设置trainable的更多相关文章
- android ImageView网络图片加载、动态设置尺寸、圆角..
封装了一个关于ImageView的辅助类,该类可以方便实现网络图片下载的同时,动态设置图片尺寸.圆角.....一系列连贯的操作,无样式表,java代码实现所有功能,使用很方便. package com ...
- 根据屏幕大小动态设置字体rem
1.根据屏幕大小动态设置字体rem var docEl = document.documentElement, //当设备的方向变化(设备横向持或纵向持)此事件被触发.绑定此事件时, //注意现在当浏 ...
- 动态设置和访问cxgrid列的Properties(转)
原文:http://www.cnblogs.com/hnxxcxg/archive/2010/05/24/2940711.html 动态设置和访问cxgrid列的Properties 设置: cxGr ...
- SSRS动态设置文本框属性
SSRS可以通过表达式动态设置文本框所有的属性,比如字体,字号,是否加粗,如下图所示: 汉字和数字英文字母占用的空间不一样,一个汉字占用两个数字和英文字母的空间,VB里有LENB取得字节数,这SSRS ...
- 【Android疑难杂症】GridView动态设置Item的宽高导致第一个Item不响应或显示不正常的问题
前言 这个问题在之前做一个盒子项目时遇到过,最近又遇到了,使用GridView遇到的非常奇葩的问题,这里记录分享一下. 声明 欢迎转载,但请保留文章原始出处:) 博客园:http://www.cnb ...
- 如果动态设置json对象的key
项目中要求动态设置json的key属性,如果按照一般的json设置方法是不行的.假如你把一个key设置为一个变量的话,那么最后js解析出来的就是key为这个变量名而不是这个变量的值. 解决:通过使用 ...
- Android 通过Java代码生成创建界面。动态生成View,动态设置View属性。addRules详解
废话不多说,本文将会层层深入给大家讲解如何动态的生成一个完整的界面. 本文内容: Java代码中动态生成View Java代码中动态设置View的位置,以及其他的属性 LayoutParams详解 一 ...
- easyui表单多重验证,动态设置easyui控件
要实现的功能:在做添加学生信息的时候,利用easyui的验证功能判断 学号是否重复和学号只能为数字 最终效果如下图: 但在做这个的过程中,遇到了一系列的问题: 扩展validatebox的验证方法,最 ...
- Quartz在Spring中动态设置cronExpression (spring设置动态定时任务)
什么是动态定时任务:是由客户制定生成的,服务端只知道该去执行什么任务,但任务的定时是不确定的(是由客户制定). 这样总不能修改配置文件每定制个定时任务就增加一个trigger吧,即便允许客户 ...
随机推荐
- C#线程同步方法汇总
我们在编程的时候,有时会使用多线程来解决问题,比如你的程序需要在 后台处理一大堆数据,但还要使用户界面处于可操作状态:或者你的程序需要访问一些外部资源如数据库或网络文件等.这些情况你都可以创建一个子线 ...
- 8天学通MongoDB——第一天 基础入门(转)
关于mongodb的好处,优点之类的这里就不说了,唯一要讲的一点就是mongodb中有三元素:数据库,集合,文档,其中“集合” 就是对应关系数据库中的“表”,“文档”对应“行”. 一: 下载 上Mon ...
- CSS渐变字体、镂空字体、input框提示信息颜色、给图片加上内阴影、3/4圆
1.渐变字体 主要是看:-webkit-background-clip: text; 该属性 <style> .b1{ width: 500px; height: 200px; font- ...
- Python3 比较两个图片是否类似或相同
Python代码 #coding:utf8 import os from PIL import Image,ImageDraw,ImageFile import numpy import pytess ...
- Openstack中为虚拟机使用CDROM光驱设备
在Libvirt里处理 在nova里处理 实际效果 怎么卸载 在Libvirt里处理 尝试了下面有几种方法,为虚拟机载入光盘文件: 1.使用ide方式挂载: virsh attach-disk {in ...
- PHP优化---opcache的配置说明
[opcache] zend_extension = "G:/PHP/php-5.5.6-Win32-VC11-x64/ext/php_opcache.dll" ; Zend Op ...
- Eclipse Maven项目报错3之找不到配置文件spring-servlet-context.xml
一.具体错误如下图所示 根据文字提示可以看出是这个文件找不到,但是我去项目的这个目录找了,这个文件确实存在,那么是什么问题呢 二.解决问题 原因分析(来自网上) 代码编译的过程,是一个自动生成相应编译 ...
- vc2008中mfc菜单、控件等汉字显示为问号或乱码的解决方法
在vc2008中建立基于mfc的project.在向导的Application type页面中如果在resource language选项中选择"英语(美国)"(图一),那么在pr ...
- MySQL存储引擎与数据类型
1 数据存储引擎 存储引擎的概念是MySQL的一个特性,它指定了表的类型(诸如表怎样存储与索引数据.是否支持事务.外键等),表在计算机中的存储方式. 1.1 MySql支持的数据存储引擎 查看引擎信息 ...
- .geodatabase与gdb的相互转换
.geodatabase长得是gdb的全称,确实它们有一定的关系,但也有区别. 简单认识一下 有人也问过我,gdb外表像个文件夹,是怎么实现的.gdb数据库是ESRI特有的数据库,它是一些数据集定义. ...