【47】迁移学习(Transfer Learning)
迁移学习(Transfer Learning)
如果你要做一个计算机视觉的应用,相比于从头训练权重,或者说从随机初始化权重开始,如果你下载别人已经训练好网络结构的权重,你通常能够进展的相当快,用这个作为预训练,然后转换到你感兴趣的任务上。
计算机视觉的研究社区非常喜欢把许多数据集上传到网上,如果你听说过,比如ImageNet,或者MS_COCO,或者Pascal类型的数据集,这些都是不同数据集的名字,它们都是由大家上传到网络的,并且有大量的计算机视觉研究者已经用这些数据集训练过他们的算法了。
有时候这些训练过程需要花费好几周,并且需要很多的GPU,其它人已经做过了,并且经历了非常痛苦的寻最优过程,这就意味着你可以下载花费了别人好几周甚至几个月而做出来的开源的权重参数,把它当作一个很好的初始化用在你自己的神经网络上。用迁移学习把公共的数据集的知识迁移到你自己的问题上,让我们看一下怎么做。
来个栗子

举个例子,假如说你要建立一个猫咪检测器,用来检测你自己的宠物猫。
比如网络上的Tigger,是一个常见的猫的名字,Misty也是比较常见的猫名字。
假如你的两只猫叫Tigger和Misty,还有一种情况是,两者都不是。所以你现在有一个三分类问题,图片里是Tigger还是Misty,或者都不是,我们忽略两只猫同时出现在一张图片里的情况。现在你可能没有Tigger或者Misty的大量的图片,所以你的训练集会很小,你该怎么办呢?

我建议你从网上下载一些神经网络开源的实现,不仅把代码下载下来,也把权重下载下来。
有许多训练好的网络,你都可以下载。举个例子,ImageNet数据集,它有1000个不同的类别,因此这个网络会有一个Softmax单元,它可以输出1000个可能类别之一。

你可以去掉上图中的这个Softmax层,创建你自己的Softmax单元,用来输出Tigger、Misty和neither三个类别。
就网络而言,我建议你把所有的层看作是冻结的,你冻结网络中所有层的参数,你只需要训练和你的Softmax层有关的参数。这个Softmax层有三种可能的输出,Tigger、Misty或者都不是。
通过使用其他人预训练的权重,你很可能得到很好的性能,即使只有一个小的数据集。幸运的是,大多数深度学习框架都支持这种操作,事实上,取决于用的框架,它也许会有trainableParameter=0这样的参数,对于这些前面的层,你可能会设置这个参数。
为了不训练这些权重,有时也会有freeze=1这样的参数。不同的深度学习编程框架有不同的方式,允许你指定是否训练特定层的权重。在这个例子中,你只需要训练softmax层的权重,把前面这些层的权重都冻结。

另一个技巧,也许对一些情况有用,由于前面的层都冻结了,相当于一个固定的函数,不需要改变。因为你不需要改变它,也不训练它,取输入图像X,然后把它映射到这层(softmax的前一层)的激活函数。
所以这个能加速训练的技巧就是,如果我们先计算这一层(紫色箭头标记),计算特征或者激活值,然后把它们存到硬盘里。你所做的就是用这个固定的函数,在这个神经网络的前半部分(softmax层之前的所有层视为一个固定映射),取任意输入图像X,然后计算它的某个特征向量,这样你训练的就是一个很浅的softmax模型,用这个特征向量来做预测。
对你的计算有用的一步就是对你的训练集中所有样本的这一层的激活值进行预计算,然后存储到硬盘里,然后在此之上训练softmax分类器。所以,存储到硬盘或者说预计算方法的优点就是,你不需要每次遍历训练集再重新计算这个激活值了。

因此如果你的任务只有一个很小的数据集,你可以这样做。
要有一个更大的训练集怎么办呢?
根据经验,如果你有一个更大的标定的数据集,也许你有大量的Tigger和Misty的照片,还有两者都不是的,这种情况,你应该冻结更少的层,比如只把上图中前面括起来的这些层冻结,然后训练后面的层。如果你的输出层的类别不同,那么你需要构建自己的输出单元,Tigger、Misty或者两者都不是三个类别。有很多方式可以实现,你可以取后面几层的权重,用作初始化,然后从这里开始梯度下降。

或者你可以直接去掉这几层,换成你自己的隐藏单元和你自己的softmax输出层,这些方法值得一试。但是有一个规律,如果你有越来越多的数据,你需要冻结的层数越少,你能够训练的层数就越多。这个理念就是,如果你有一个更大的数据集,也许有足够多的数据,那么不要单单训练一个softmax单元,而是考虑训练中等大小的网络,包含你最终要用的网络的后面几层。

最后,如果你有大量数据,你应该做的就是用开源的网络和它的权重,把这、所有的权重当作初始化,然后训练整个网络。再次注意,如果这是一个1000节点的softmax,而你只有三个输出,你需要你自己的softmax输出层来输出你要的标签。
如果你有越多的标定的数据,或者越多的Tigger、Misty或者两者都不是的图片,你可以训练越多的层。极端情况下,你可以用下载的权重只作为初始化,用它们来代替随机初始化,接着你可以用梯度下降训练,更新网络所有层的所有权重。
这就是卷积网络训练中的迁移学习,事实上,网上的公开数据集非常庞大,并且你下载的其他人已经训练好几周的权重,已经从数据中学习了很多了,你会发现,对于很多计算机视觉的应用,如果你下载其他人的开源的权重,并用作你问题的初始化,你会做的更好。
在所有不同学科中,在所有深度学习不同的应用中,我认为计算机视觉是一个你经常用到迁移学习的领域,除非你有非常非常大的数据集,你可以从头开始训练所有的东西。总之,迁移学习是非常值得你考虑的,除非你有一个极其大的数据集和非常大的计算量预算来从头训练你的网络。
【47】迁移学习(Transfer Learning)的更多相关文章
- 【转载】 迁移学习(Transfer learning),多任务学习(Multitask learning)和端到端学习(End-to-end deep learning)
--------------------- 作者:bestrivern 来源:CSDN 原文:https://blog.csdn.net/bestrivern/article/details/8700 ...
- 迁移学习-Transfer Learning
迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...
- 【深度学习系列】迁移学习Transfer Learning
在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...
- pytorch例子学习——TRANSFER LEARNING TUTORIAL
参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 以下是两种主要的迁移学习场景 微调convnet : ...
- 图像识别 | AI在医学上的应用 | 深度学习 | 迁移学习
参考:登上<Cell>封面的AI医疗影像诊断系统:机器之心专访UCSD张康教授 Identifying Medical Diagnoses and Treatable Diseases b ...
- AI小白必读:深度学习、迁移学习、强化学习别再傻傻分不清
摘要:诸多关于人工智能的流行词汇萦绕在我们耳边,比如深度学习 (Deep Learning).强化学习 (Reinforcement Learning).迁移学习 (Transfer Learning ...
- 【迁移学习】2010-A Survey on Transfer Learning
资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...
- [DeeplearningAI笔记]ML strategy_2_3迁移学习/多任务学习
机器学习策略-多任务学习 Learninig from multiple tasks 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.7 迁移学习 Transfer Learninig 神 ...
- 迁移学习(Transformer),面试看这些就够了!(附代码)
1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...
随机推荐
- archlinux+UEFI模式在linux主机下基于KVM-QEMU命令行虚拟机安装笔记
ArchLinux十分精简,并且具有强大的滚动更新.最近在基于ubuntu的宿主机下通过KVM-QEMU虚拟机安装了archlinux,将过程记录下来以供参考. 1.下载启动盘 1.1.下载archl ...
- ubuntu 如何添加alias
公司的nx 上面一般使用gvim 编辑文件.并且为gvim 增加了alias,只要敲 g 就是gvim 的意思,这样编辑一个文件只需要 g xxx.v 就可以了.非常方便. 在自己电脑上安装了ubun ...
- HDU_4570_区间dp
http://acm.hdu.edu.cn/showproblem.php?pid=4570 连题目都看不懂,直接找了题解,copy了过来= =. 一个长度为n的数列,将其分成若干段(每一段的长度要& ...
- Codeforces_540_C
http://codeforces.com/problemset/problem/540/C 简单bfs,注意结束条件. #include<iostream> #include<cs ...
- 【原创】为什么Mongodb索引用B树,而Mysql用B+树?
引言 好久没写文章了,今天回来重操旧业.毕竟现在对后端开发的要求越来越高,大家要做好各种准备. 因此,大家有可能遇到如下问题 为什么Mysql中Innodb的索引结构采取B+树? 回答这个问题时,给自 ...
- HDU 6599 I Love Palindrome String (回文树+hash)
题意 找如下子串的个数: (l,r)是回文串,并且(l,(l+r)/2)也是回文串 思路 本来写了个回文树+dfs+hash,由于用了map所以T了 后来发现既然该子串和该子串的前半部分都是回文串,所 ...
- Netty学习(2):IO模型之NIO初探
NIO 概述 前面说到 BIO 有着创建线程多,阻塞 CPU 等问题,因此为解决 BIO 的问题,NIO 作为同步非阻塞 IO模型,随 JDK1.4 而出生了. 在前面我们反复说过4个概念:同步.异步 ...
- 研发协同平台持续集成之Jenkins实践
导读 研发协同平台有两个核心目标,一是提高研发效率 ,二是提高研发质量,要实现这两个核心目标,实现持续集成是关键之一. 什么是持续集成 在<持续集成>一书中,对持续集成的定义如下:持续集成 ...
- Linux 发行版本简述
在撰写这篇文章前,先向linux创始人 Linus Torvalds 先生致敬,感谢您二十多年前的无私开源! 其次向二十多年来维护更新的开发者们致敬! Lin ...
- [dubbo 源码之 ]1. 服务提供方如何发布服务
服务发布 启动流程 1.ServiceConfig#export 服务提供方在启动部署时,dubbo会调用ServiceConfig#export来激活服务发布流程,如下所示: Java API: ` ...