tensorflow冻结变量方法(tensorflow freeze variable)
最近由于项目需要,要对tensorflow构造的模型中部分变量冻结,然后继续训练,因此研究了一下tf中冻结变量的方法,目前找到三种,各有优缺点,记录如下:
1.名词解释
冻结变量,指的是在训练模型时,对某些可训练变量不更新,即仅参与前向loss计算,不参与后向传播,一般用于模型的finetuning等场景。例如:我们在其他数据上训练了一个resnet152模型,然后希望在目前数据上做finetuning,一般来讲,网络的前几层卷积是用来提取底层图像特征的,因此可以对前3个卷积层进行冻结,不改变其weight和bias的数值。
2.方法介绍
目前我找到了三种tf冻结变量的方法,各有优缺点,具体如下:
2.1 trainable=False
一切tf.Variable或tf.Variable的子类,在创建时,都有一个trainable参数,在tf官方文档(https://www.tensorflow.org/api_docs/python/tf/Variable)中有对这个参数的定义,

意思是,如果trainable设置为True,就会把变量添加到GraphKeys.TRAINABLE_VARIABLES集合中,如果是False,则不添加。而在计算梯度进行后向传播时,我们一般会使用一个optimizer,然后调用该optimizer的compute_gradients方法。在compute_gradients中,第二个参数var_list如果不传入,则默认为GraphKeys.TRAINABLE_VARIABLES。

总结下,trainable=False冻结变量的逻辑:trainable=False → 该变量不会放入GraphKeys.TRAINABLE_VARIABLES → 调用optimizer.compute_gradients方法时默认变量列表为GraphKeys.TRAINABLE_VARIABLES,该变量不在其中,因此不参与后向传播,值不进行更新,达到冻结变量效果。
优点:操作简单,只要在你创建变量时设置trainable=False即可
缺点:不知道大家发现没有,我上面的总结中,optimizer.compute_gradients方法默认变量列表是GraphKeys.TRAINABLE_VARIABLES,这句话还意味着,如果我不想用默认变量列表,而使用自定义变量列表,那么即使设置了trainable=False,只要把该变量加入到自定义变量列表中,变量还是会参与后向传播的,值也会更新。另外,tf.layers、tf.contrib.rnn等一些高度封装的API是不支持这个参数的,没法用该方法冻结变量。最后,如果我们在使用Saver保存ckpt时,一般调动tf.trainable_variables()方法只保存可训练参数,这时返回的变量列表,也有上面的问题,即设置了trainable=False的变量不会在里面。
2.2 tf.stop_gradient()
我们还可以通过在某个变量外面包裹一层tf.stop_gradient()函数来达到冻结变量的目的。例如我们想冻结w1,可以写成这样:
w1 = tf.stop_gradient(w1)
在后向传播时,w1的值就不会更新。下面说下优缺点。
优点:操作简单,针对想冻结的变量,添加上面这一行即可,而且相比于上一个方法,设置了tf.stop_gradient()的变量,不会从GraphKeys.TRAINABLE_VARIABLES集合中去除,因此不会影响梯度计算和保存模型
缺点:和上一个方法类似,tf.stop_gradient()的输入是Tensor,tf.layers、tf.contrib.rnn等一些高度封装的API的返回值没法作为参数传入,即不能用该方法冻结
2.3 optimizer.compute_gradients(loss,var_list=no_freeze_vars)
optimizer.compute_gradients在2.1中提到过,其实我们只需要在计算梯度时,指定变量列表,把希望冻结的变量去除,即可完成冻结变量。但这么做有一个前提,我们必须知道所有可训练变量的名字,并根据一些规则去除变量。获取所有可训练变量名字调用tf.trainable_variables()方法即可,但去除变量则需要我们在构建网络的时候,合理利用tf.variable_scope,对不同变量做区分。例如,我们如果想把可训练变量中所有卷积层变量冻结,可以这么写:
trainable_vars = tf.trainable_variables()
freeze_conv_var_list = [t for t in trainable_vars if not t.name.startswith(u'conv')]
grads = opt.compute_gradients(loss, var_list=freeze_conv_var_list)
下面总结下优缺点,
优点:没有2.1和2.2的缺点,是一种适用范围更加广泛的方法
缺点:相对2.1,2.2使用起来比较复杂,需要自己去除冻结变量,并且variable_scope不能随意改动,因为可能使去除变量的过滤操作无效化。例如:如果把原来'cnn' scope改为'vgg',那么上面的代码就无效了
3.总结
tf对于一些常用操作,往往会提供多种方法,但每种方法一般都是有区别的,并且操作原理和后面的逻辑也会有不同,要谨慎使用
tensorflow冻结变量方法(tensorflow freeze variable)的更多相关文章
- Tensorflow模型变量保存
Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...
- 111、TensorFlow 初始化变量
# 显式的初始化时非常有用的 # 因为它可以让你不用重复进行繁重的初始化工作 # 当你重新从checkpoint文件中加载一个模型的时候 # 当随机初始化变量被配置在分布式的配置文件中 # 为了在开始 ...
- 05 Tensorflow中变量的初始化
打开Python Shell,输入import tensorflow as tf,然后可以执行以下代码. 1.创建一个2*3的矩阵,并让所有元素的值为0.(类型为tf.float) a = tf.ze ...
- TF:Tensorflow定义变量+常量,实现输出计数功能—Jason niu
#TF:Tensorflow定义变量+常量,实现输出计数功能 import tensorflow as tf state = tf.Variable(0, name='Parameter_name_c ...
- Java获取系统环境变量(System Environment Variable)和系统属性(System Properties)以及启动参数的方法
系统环境变量(System Environment Variable): 在Linux下使用export $ENV=123指定的值.获取的方式如下: Map<String,String> ...
- [tensorflow in a nutshell] tensorflow简明教程 (第一部分)
原文链接: https://medium.com/@camrongodbout/tensorflow-in-a-nutshell-part-one-basics-3f4403709c9d#.31jv5 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- Tensorflow r1.12及tensorflow serving r1.12 GPU版本编译遇到的问题
1.git clone tensorflow serving 及tensorflow代码 2. ERROR: /root/.cache/bazel/_bazel_root/f71d782da17fd8 ...
随机推荐
- 从零开始搭建一个从Win7环境备份至CentOS7的SVN双机备份环境
★几个关键的事项★ 需要关闭防火墙filewalld跟selinux ,在root用户下操作: [root@localhost svnRepos]# systemctl stop firewalld ...
- BZOJ4964 : 加长的咒语
把$($看作$-1$,$)$看作$1$,设$a$为前缀和,则相当于找两个位置$x,y$使得$a[x]=a[y]$,且$a[x]$是$[x,y]$的区间最大值. 求出询问区间的最大值$o$,然后找到$o ...
- Android 蓝牙4.0 BLE (onServicesDiscovered 返回 status 是 129,133时)
Android ble (Bluetooth Low Energy) 蓝牙4.0,也就是说android 4.3+, API level >= 18,且支持蓝牙4.0的手机才可以使用. BLE是 ...
- 如何绘制UML图?
首先推荐在线绘制UML的网址:https://www.processon.com/,很好用. 在软件开发过程中,开发人员往往需要通过绘制类图来理清业务的实现思路,从而方便代码实现,也便于后期的代码维护 ...
- c++语言的设计和演化---在线函数
开始的c++语言中引入inline函数的目的是处理一些实时的情况,而普通的函数调用的开销无法被接受. 起初是在类的声明中定义inline函数,也只支持成员函数,后来才支持非成员函数:
- Yii2 DetailView小部件
DetailView小部件 Yii 提供了一套数据库小部件 widgets,这些小部件可以用于显示数据 DetailView 小部件用于显示一条记录数据 ListView 和 GridView 可以用 ...
- windows下安装和配置redis
1.windows下安装和配置redis 1.1 下载: 官网(linux下载地址):https://redis.io/ Windows系统下载地址:https://github.com/MSOpen ...
- GPS信号不足情况下,如何用GPRS模块根据基站进行定位
AT+CREG=2 //设置参数,2为返回详细信息,包含基站的地区区域码和基站码 注意:GPRS命令后面都要有回车 AT+CREG? 下面为返回值 ...
- /etc/security/limits.conf 文件说明
/etc/security/limits.conf 是 Linux 资源使用配置文件,用来限制用户对系统资源的使用 语法:<domain> <type> <item& ...
- js顺序播放列表中的音乐
今天一个朋友问我js顺序播放音乐列表中的音乐的问题,我仔细一想,我也没有做过啊,无从下手啊,怎么办?然后我就上网查了一下audio标签,又百度了js如何顺序播放音乐,结果就找到了解决的办法. audi ...