转载请注明处处:

http://www.cnblogs.com/darkknightzh/p/7608709.html

参考网址:

https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

https://github.com/kratzert/finetune_alexnet_with_tensorflow/blob/master/finetune.py#L109

https://github.com/davidsandberg/facenet

得到正常训练时的train_op时,使用tf.trainable_variables(),想要finetune,使用下面这句话(参考网址1,其实网址1和2都是同一个人的。。。):

fine_tune_var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]

或者

fine_tune_var_list = [i for i in tf.trainable_variables() if 'fc1/weights' in i.name]

可以得到需要finetune的参数。如果不知道参数名字,可以先print出来:

for var in tf.trainable_variables():
print(var)

然后找到需要finetune的层。也可以通过其他方法,只要能找到就行。。。

将fine_tune_var_list传给train_op,之后正常训练,便可以对网络进行finetune了,如下(见参考网址2):

with tf.name_scope("train"):
# Get gradients of all trainable variables
gradients = tf.gradients(loss, var_list)
gradients = list(zip(gradients, var_list)) # Create optimizer and apply gradient descent to the trainable variables
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_op = optimizer.apply_gradients(grads_and_vars=gradients)

当然,也可以使用自己其他的代码。但是,我这边直接使用上面代码,保存模型时,后缀都是0,把train_op 那句改成train_op = optimizer.apply_gradients(grads_and_vars=gradients , global_step=global_step)就可以了。当然,我出现的问题,和自己的代码有关系。

也可以用下面的代码(从参考网址3中提取):

opt = tf.train.GradientDescentOptimizer(learning_rate)
train_op = get_train_op(grad, opt, global_step, args.moving_average_decay, finetune_params)

其中,

def get_train_op(grads, opt, global_step, moving_average_decay, train_var):

    # Apply gradients.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) # Track the moving averages of trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
variables_averages_op = variable_averages.apply(train_var) with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train') return train_op

注意的是,参考网址2中的代码,再对参数更新的时候,没有使用滑动平均。上面的代码,使用了滑动平均。

(原)tensorflow中finetune某些层的更多相关文章

  1. TensorFlow中max pooling层各参数的意义

    官方教程中没有解释pooling层各参数的意义,找了很久终于找到,在tensorflow/python/ops/gen_nn_ops.py中有写: def _max_pool(input, ksize ...

  2. (原)torch中微调某层参数

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...

  3. tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图

    tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...

  4. (原)torch和caffe中的BatchNorm层

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6015990.html BatchNorm具体网上搜索. caffe中batchNorm层是通过Batc ...

  5. (原)tensorflow中函数执行完毕,显存不自动释放

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7608916.html 参考网址: https://stackoverflow.com/question ...

  6. 第十四节,TensorFlow中的反卷积,反池化操作以及gradients的使用

    反卷积是指,通过测量输出和已知输入重构未知输入的过程.在神经网络中,反卷积过程并不具备学习的能力,仅仅是用于可视化一个已经训练好的卷积神经网络,没有学习训练的过程.反卷积有着许多特别的应用,一般可以用 ...

  7. tensorflow中slim模块api介绍

    tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35   http://blog.csdn.net/guvcolie/article/details/77686 ...

  8. CNN中的卷积核及TensorFlow中卷积的各种实现

    声明: 1. 我和每一个应该看这篇博文的人一样,都是初学者,都是小菜鸟,我发布博文只是希望加深学习印象并与大家讨论. 2. 我不确定的地方用了"应该"二字 首先,通俗说一下,CNN ...

  9. TensorFlow中的通信机制——Rendezvous(二)gRPC传输

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传 ...

随机推荐

  1. 移动前端调试工具-Weinre真机调试

    之前做移动前端调试页面的时候就是简单的使用Chrome模拟器调试,能满足基本基本的需求,后来发现了基于Web Inspector(Webkit)的远程调试工具Weinre,可以在PC端直接调试运行在移 ...

  2. Intent 简介 结构 传递数据 常见Action 常量 MD

    Markdown版本笔记 我的GitHub首页 我的博客 我的微信 我的邮箱 MyAndroidBlogs baiqiantao baiqiantao bqt20094 baiqiantao@sina ...

  3. MyBatis动态SQL foreach标签实现批量插入

    需求:查出给定id的记录: <select id="getEmpsByConditionForeach" resultType="com.test.beans.Em ...

  4. jquery.cookie 使用方法

    一个轻量级的cookie 插件,可以读取.写入.删除 cookie. jquery.cookie.js 的配置 首先包含jQuery的库文件,在后面包含 jquery.cookie.js 的库文件. ...

  5. 论述Android通过HttpURLConnection与HttpClient联网代理网关设置

    Android联网主要使用HttpURLConneciton和HttpClient进行联网,在手机联网的时候,我们优先选择wifi网络,其次在选择移动网络,这里所述移动网络主要指cmwap. 大家都知 ...

  6. 大数据开发实战:Hive表DDL和DML

    1.Hive 表 DDL 1.1.创建表 Hive中创建表的完整语法如下: CREATE [EXTERNAL] TABLE [IF NOT EXISTS]  table_name [ (col_nam ...

  7. Spring Boot集成持久化Quartz定时任务管理和界面展示

    本文是对之前的一篇文章Spring+SpringMVC+mybatis+Quartz整合代码部分做的一个修改和补充, 其中最大的变化就是后台框架变成了Spring Boot. 本工程所用到的技术或工具 ...

  8. JS 判断浏览器是否安装Flash 兼容IE、firefox

    /** * @Author: HTL * @Email: Huangyuan413026@163.com * @DateTime: 2016-06-02 11:37:05 * @Description ...

  9. redis清除数据/xargs使用

    redis清除数据/xargs使用 redis比memcache好的地方之一,如果memcache,恐怕就得关掉重启了. 1 使用cli FLUSHDB 清除一个数据库,FLUSHALL清除整个red ...

  10. Docker: 如何将node.js的项目部署到docker的swarm上面去

    前提条件: Docker创建虚机和swarm 如何用Docker建立一个Node.js的开发环境 正文: 将如何用Docker建立一个Node.js的开发环境文中创建的nodehello image发 ...