迁移学习、fine-tune和局部参数恢复
一、迁移学习
就是把已训练好的模型参数迁移到新的模型来帮助新模型训练。
模型的训练与预测:
深度学习的模型可以划分为 训练 和 预测 两个阶段。
训练 分为两种策略:一种是白手起家从头搭建模型进行训练,一种是通过预训练模型进行训练。
预测 相对简单,直接用已经训练好的模型对数据集进行预测即可。
优点:
1)站在巨人的肩膀上:前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子。
2)训练成本可以很低:如果采用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用CPU都完全无压力,没有深度学习机器也可以做。
3)适用于小数据集:对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠迁移学习。
迁移学习的几种方式:
1、Transfer Learning:冻结预训练模型的全部卷积层,只训练自己定制的全连接层。
2、Extract Feature Vector:先计算出预训练模型的卷积层对所有训练和测试数据的特征向量,然后抛开预训练模型,只训练自己定制的简配版全连接网络。
3、Fine-tune:冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层),训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层。
* 注:Transfer Learning关心的问题是:什么是“知识”以及如何更好地运用之前得到的“知识”,这可以有很多方法和手段,eg:SVM,贝叶斯,CNN等。
而fine-tune只是其中的一种手段,更常用于形容迁移学习的后期微调中。
三种迁移学习方式对比
1、第一种和第二种训练得到的模型本质上并没有什么区别,但是第二种的计算复杂度要远远优于第一种。
2、第三种是对前两种方法的补充,以进一步提升模型性能。要注意的是,这种方法并不一定能真的对模型有所提升。
本质上来讲:这三种迁移学习的方式都是为了让预训练模型能够胜任新数据集的识别工作,能够让预训练模型原本的特征提取能力得到充分的释放和利用。但是,在此基础上如果想让模型能够达到更低的Loss,那么光靠迁移学习是不够的,靠的更多的还是模型的结构以及新数据集的丰富程度。
二、实验:尝试对模型进行微调,以进一步提升模型性能
1、fine-tune的作用:
拿到新数据集,先用预训练模型处理,通常用上面的方法一或方法二测试预训练模型在新数据上的表现,如果表现不错,可以尝试fine-tune,进一步解锁卷积层以继续训练。
但是不要期待质的飞跃,另外,如果由于新数据集与原数据集差别太大导致表现很差,一方面可以考虑从头训练,另一方面也可以考虑解锁比较多层的训练。
2、不同数据集下使用微调
数据集1:数据量少,但数据相似度非常高
在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别,方法一
数据集2:数据量少,数据相似度低
在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。方法三
数据集3:数据量大,数据相似度低
在这种情况下,由于我们有一个大的数据集,我们的神经网络训练将会很有效。但是,由于我们的数据与用于训练我们的预训练模型的数据相比有很大不同。使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scatch)。
数据集4:数据量大,相似度高
这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。
3.微调的注意事项
1)通常的做法是截断预先训练好的网络的最后一层(softmax层),并用与我们自己的问题相关的新的softmax层替换它。
2)使用较小的学习率来训练网络。
3)如果数据集数量过少,我们进来只训练最后一层,如果数据集数量中等,冻结预训练网络的前几层的权重也是一种常见做法。
注:卷积神经网络的核心是:
(1)浅层卷积层提取基础特征,比如边缘,轮廓等基础特征。
(2)深层卷积层提取抽象特征,比如整个脸型。
(3)全连接层根据特征组合进行评分分类。
4、实验操作具体步骤
1、下载预训练模型
2、预处理:按照预训练模型原本的预处理方式对数据进行预处理,使用预训练模型一定要确保让待训练的数据尽可能向原数据集靠拢,这样才能最大程度发挥模型的识图本领。
3、基模型和定制模型:构建和预训练里面完全相同的模型。
4、查看固定和恢复节点名
5、训练过程设置恢复,固定张量的列表
三、代码详情
基模型和定制模型
import slim.nets.resnet_v1 as resnet_v1 # 定义模型,因为给出的只有参数,并没有模型,这里需要指定模型的具体结构
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
# logits就是最后预测值,images就是输入数据,指定num_classes=None是为了使resnet模型最后的输出层禁用
logits, end_points = resnet_v1.resnet_v1_50(inputs=input_images, num_classes=None) # 自定义的输出层
with tf.variable_scope("Logits"):
# 将原始模型的输出数据去掉维度为2和3的维度,最后只剩维度1的batch数和维度4的300*300*3
# 也就是将原来的二三四维度全部压缩到第四维度
net = tf.squeeze(logits, axis=[1, 2])
# 加入一层dropout层
net = slim.dropout(net, keep_prob=0.5, scope='dropout_scope')
# 加入一层全连接层,指定最后输出大小
logits = slim.fully_connected(net, num_outputs=labels_nums, scope='fc')
查看固定和恢复节点名
look_checkpoint.py
import os
from tensorflow.python import pywrap_tensorflow model_dir = os.getcwd() # 获取当前文件工作路径
print(model_dir)#输出当前工作路径
checkpoint_path = r'G:\1-modelused\Siamese_Densenet_Single_Net\output\640model\model3/model_epoch_20.ckpt'#model_dir + "\\ckpt_dir\\model-ckpt-100" print(checkpoint_path)#输出读取的文件路径
# 从checkpoint文件中读取参数
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出变量名称及变量值
for key in var_to_shape_map:
# if key.startswith('DenseNet_121/AuxLogits'): # print(1)
# print(key)
print("tensor_name: ", key)
训练过程设置恢复,固定张量的列表
CKPT_FILE = r'.\pretrain\resnet_v1_50.ckpt'
#不需要从谷歌训练好的模型中加载的参数。这里就是最后的全连接层,因为在新的问题中要重新训练这一层中的参数。
#这里给出的是参数的前缀
CHECKPOINT_EXCLUDE_SCOPES = 'Logits'
## 指定最后的全连接层为可训练的参数,需要训练的网络层参数名称,在fine-tuning的过程中就是最后的全连接层
TRAINABLE_SCOPES = 'Logits' #获取所有需要从谷歌训练好的模型中加载的参数
def get_tuned_variables():
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
variables_to_restore = []
#枚举inception-v3模型中所有的参数,然后判断是否需要从加载列表中移除
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore #获取所有需要训练的变量列表。
def get_trainable_variables():
scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
variables_to_train = []
#枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。
for scope in scopes:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope)
variables_to_train.extend(variables)
return variables_to_train #定义加载Google训练好的Inception-v3模型的Saver
load_fn = slim.assign_from_checkpoint_fn(
CKPT_FILE,
get_tuned_variables(),
ignore_missing_vars=True
) saver = tf.train.Saver(max_to_keep=100)
max_acc = 0.0
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('models/resnet_v1/')
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
# 加载谷歌已经训练好的模型
print('Loading tuned variables from %s' % CKPT_FILE)
load_fn(sess)
迁移学习、fine-tune和局部参数恢复的更多相关文章
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
- [机器学习]Fine Tune
Fine Tune顾名思义,就是微调.在机器学习中,一般用在迁移学习中,通过控制一些layer调节一些layer来达到迁移学习的目的.这样可以利用已有的参数,稍微变化一些,以适应新的学习任务.所以说, ...
- 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(二)
前言 已完成数据预处理工作,具体参照: 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(一) 设置配置文件 新建目录face_faster_rcn ...
- caffe简易上手指南(三)—— 使用模型进行fine tune
之前的教程我们说了如何使用caffe训练自己的模型,下面我们来说一下如何fine tune. 所谓fine tune就是用别人训练好的模型,加上我们自己的数据,来训练新的模型.fine tune相当于 ...
- [DeeplearningAI笔记]ML strategy_2_3迁移学习/多任务学习
机器学习策略-多任务学习 Learninig from multiple tasks 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.7 迁移学习 Transfer Learninig 神 ...
- 用tensorlayer导入Slim模型迁移学习
上一篇博客[用tensorflow迁移学习猫狗分类]笔者讲到用tensorlayer的[VGG16模型]迁移学习图像分类,那麽问题来了,tensorlayer没提供的模型怎么办呢?别担心,tensor ...
- TensorFlow从1到2(九)迁移学习
迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...
- TensorFlow迁移学习的识别花试验
最近学习了TensorFlow,发现一个模型叫vgg16,然后搭建环境跑了一下,觉得十分神奇,而且准确率十分的高.又上了一节选修课,关于人工智能,老师让做一个关于人工智能的试验,于是觉得vgg16很不 ...
- 第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)
在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用.数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用 ...
随机推荐
- 【转载】 使用宝塔Linux面板功能查看服务器内存使用情况
运维过阿里云服务器或者腾讯云服务器的运维人员都知道,针对占用内存比较高的应用或者服务等,我们需要时刻关注服务器的内存使用率,是否存在内存瓶颈等情况的出现.阿里云和腾讯云官方后台界面的监控数据页面也有相 ...
- JavaWeb第三天--JavaScript
JavaScript 1. JavaScript概述 1.1 JavaScript是什么?有什么作用? HTML:就是用来写网页的.(人的身体) CSS:就是用来美化页面的.(人的衣服) JavaSc ...
- 地产propretie单词propretie财产
中文名:房产财产地产 外文名:property.propretie 释义:财产.所有物等 用法:作名词. 词汇搭配动词+-等 目录 1 英文释义 2 释义例句 3 词汇搭配 4 衍生 英文释义 1. ...
- APS系统如何让企业实现“多赢”?看高博通信是怎么做的
高博通信(上海)有限公司凭籍在超精密产业中的技术积累, 强大的资金优势以及与一流大学的联合,使得其正成为国内超精密电子制造行业的领导者. 雄厚的技术实力和专业的团队赢得了波音,空客公司等国际航空器制造 ...
- Spark-Bench 测试教程
Spark-Bench 教程 本文原始地址:https://sitoi.cn/posts/19752.html 系统环境配置 操作系统:centos7 环境要求:安装 JDK, Hadoop, Spa ...
- java-spring基于redis单机版(redisTemplate)实现的分布式锁+redis消息队列,可用于秒杀,定时器,高并发,抢购
此教程不涉及整合spring整合redis,可另行查阅资料教程. 代码: RedisLock package com.cashloan.analytics.utils; import org.slf4 ...
- linux卸载gitlab
完全卸载gitlab 1.停止gitlab # gitlab-ctl stop 2.卸载gitlab(看是gitlab-ce版本还是gitlab-ee版本) # rpm -e gitl ...
- Python系统运维常用库
1.psutil是一个跨平台库(http://code.google.com/p/psutil/) 能够实现获取系统运行的进程和系统利用率(内存,CPU,磁盘,网络等),主要用于系统监控,分析和系统资 ...
- 用Java的大整数类BigInteger来实现大整数的一些运算
关于BigInteger的构造函数,一般会用到两个: BigInteger(String val); //将指定字符串转换为十进制表示形式: BigInteger(String val,int rad ...
- MySQL 中的默认数据库介绍
MySQL 中的默认数据库介绍:https://dataedo.com/kb/databases/mysql/default-databases-schemas 默认数据库 官方文档 informat ...