(原)tensorflow中finetune某些层
转载请注明处处:
http://www.cnblogs.com/darkknightzh/p/7608709.html
参考网址:
https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html
https://github.com/kratzert/finetune_alexnet_with_tensorflow/blob/master/finetune.py#L109
https://github.com/davidsandberg/facenet
得到正常训练时的train_op时,使用tf.trainable_variables(),想要finetune,使用下面这句话(参考网址1,其实网址1和2都是同一个人的。。。):
fine_tune_var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]
或者
fine_tune_var_list = [i for i in tf.trainable_variables() if 'fc1/weights' in i.name]
可以得到需要finetune的参数。如果不知道参数名字,可以先print出来:
for var in tf.trainable_variables():
print(var)
然后找到需要finetune的层。也可以通过其他方法,只要能找到就行。。。
将fine_tune_var_list传给train_op,之后正常训练,便可以对网络进行finetune了,如下(见参考网址2):
with tf.name_scope("train"):
# Get gradients of all trainable variables
gradients = tf.gradients(loss, var_list)
gradients = list(zip(gradients, var_list))
# Create optimizer and apply gradient descent to the trainable variables
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_op = optimizer.apply_gradients(grads_and_vars=gradients)
当然,也可以使用自己其他的代码。但是,我这边直接使用上面代码,保存模型时,后缀都是0,把train_op 那句改成train_op = optimizer.apply_gradients(grads_and_vars=gradients , global_step=global_step)就可以了。当然,我出现的问题,和自己的代码有关系。
也可以用下面的代码(从参考网址3中提取):
opt = tf.train.GradientDescentOptimizer(learning_rate)
train_op = get_train_op(grad, opt, global_step, args.moving_average_decay, finetune_params)
其中,
def get_train_op(grads, opt, global_step, moving_average_decay, train_var):
# Apply gradients.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
# Track the moving averages of trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
variables_averages_op = variable_averages.apply(train_var)
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train')
return train_op
注意的是,参考网址2中的代码,再对参数更新的时候,没有使用滑动平均。上面的代码,使用了滑动平均。
(原)tensorflow中finetune某些层的更多相关文章
- TensorFlow中max pooling层各参数的意义
官方教程中没有解释pooling层各参数的意义,找了很久终于找到,在tensorflow/python/ops/gen_nn_ops.py中有写: def _max_pool(input, ksize ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图
tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...
- (原)torch和caffe中的BatchNorm层
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6015990.html BatchNorm具体网上搜索. caffe中batchNorm层是通过Batc ...
- (原)tensorflow中函数执行完毕,显存不自动释放
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7608916.html 参考网址: https://stackoverflow.com/question ...
- 第十四节,TensorFlow中的反卷积,反池化操作以及gradients的使用
反卷积是指,通过测量输出和已知输入重构未知输入的过程.在神经网络中,反卷积过程并不具备学习的能力,仅仅是用于可视化一个已经训练好的卷积神经网络,没有学习训练的过程.反卷积有着许多特别的应用,一般可以用 ...
- tensorflow中slim模块api介绍
tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35 http://blog.csdn.net/guvcolie/article/details/77686 ...
- CNN中的卷积核及TensorFlow中卷积的各种实现
声明: 1. 我和每一个应该看这篇博文的人一样,都是初学者,都是小菜鸟,我发布博文只是希望加深学习印象并与大家讨论. 2. 我不确定的地方用了"应该"二字 首先,通俗说一下,CNN ...
- TensorFlow中的通信机制——Rendezvous(二)gRPC传输
背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...
随机推荐
- [leetcode]ZigZag Conversion @ Python
原题地址:https://oj.leetcode.com/problems/zigzag-conversion/ 题意: The string "PAYPALISHIRING" i ...
- Windows 安装配置 JIRA
MySQL-5.5.28 JDK1.6.0_21 JIRA功能全面,界面友好,安装简单,配置灵活,权限管理以及可扩展性方面都十分出色. 一.MySQL建库和建账号 1. mysql中创建数据库jira ...
- 【Kafka】Kafka-配置参数详解-参数调优
Kafka-配置参数详解-参数调优 kafka 目录_百度搜索 为什么kafka使用磁盘而不是内存 - CSDN博客 Kafka 配置说明 - 風吹云动 - 博客园 kafka生产服务器配置 - Or ...
- (转)C/C++ 程序设计员应聘常见 面试笔试 试题深入剖析
C/C++ 程序设计员应聘常见 面试笔试 试题深入剖析 http://www.nowcoder.com/discuss/1826?type=2&order=0&pos=23&p ...
- Android 之 WebView开发问题及优化
WebView 在现在的项目中使用的频率应该还是非常高的,WebView 主要用来加载一些容易改变的频繁交互的应用App.目前 HTML5 是一种趋势.在开发中会遇到一些开发问题及优化问题,如下所记. ...
- 【树莓派】使用VNC进行远程控制
之前有进行过VNC以及xrdp连接树莓派,并成功了. 这里看到一篇比较新的,基于mac的连接,文章转载收藏,实践可参考. 这一课里我们将学习如何在树莓派上安装和使用VNC.它可以使你通过图形界面的方式 ...
- 【树莓派】在树莓派的Android系统中安装APK应用
树莓派3 Android TV安装APK应用教程 本文摘自:http://www.mz6.net/news/android/6867.html 树莓派3 Android TV怎样安装软件?对于熟悉AD ...
- hadoop 文件合并
来自:http://blog.csdn.net/dandingyy/article/details/7490046 众所周知,Hadoop对处理单个大文件比处理多个小文件更有效率,另外单个文件也非常占 ...
- Python实践摘录
1:中文编码问题 Python语言默认不识别UTF-8的编码字符串,所以当文件中有中文并且是以UTF-8编码时,需要在python文件头部加一行注释,指明识别utf-8编码. # coding=utf ...
- Java中的List
转载请注明原文地址:http://www.cnblogs.com/ygj0930/p/6538256.html Java中常用的List子类主要有:ArrayList.LinkedList.Vecto ...