迁移学习(Transfer Learning)

如果你要做一个计算机视觉的应用,相比于从头训练权重,或者说从随机初始化权重开始,如果你下载别人已经训练好网络结构的权重,你通常能够进展的相当快,用这个作为预训练,然后转换到你感兴趣的任务上。

计算机视觉的研究社区非常喜欢把许多数据集上传到网上,如果你听说过,比如ImageNet,或者MS_COCO,或者Pascal类型的数据集,这些都是不同数据集的名字,它们都是由大家上传到网络的,并且有大量的计算机视觉研究者已经用这些数据集训练过他们的算法了。

有时候这些训练过程需要花费好几周,并且需要很多的GPU,其它人已经做过了,并且经历了非常痛苦的寻最优过程,这就意味着你可以下载花费了别人好几周甚至几个月而做出来的开源的权重参数,把它当作一个很好的初始化用在你自己的神经网络上。用迁移学习把公共的数据集的知识迁移到你自己的问题上,让我们看一下怎么做。

来个栗子

举个例子,假如说你要建立一个猫咪检测器,用来检测你自己的宠物猫。

比如网络上的Tigger,是一个常见的猫的名字,Misty也是比较常见的猫名字。

假如你的两只猫叫Tigger和Misty,还有一种情况是,两者都不是。所以你现在有一个三分类问题,图片里是Tigger还是Misty,或者都不是,我们忽略两只猫同时出现在一张图片里的情况。现在你可能没有Tigger或者Misty的大量的图片,所以你的训练集会很小,你该怎么办呢?

 

我建议你从网上下载一些神经网络开源的实现,不仅把代码下载下来,也把权重下载下来

有许多训练好的网络,你都可以下载。举个例子,ImageNet数据集,它有1000个不同的类别,因此这个网络会有一个Softmax单元,它可以输出1000个可能类别之一。

 
 

你可以去掉上图中的这个Softmax层,创建你自己的Softmax单元,用来输出Tigger、Misty和neither三个类别

就网络而言,我建议你把所有的层看作是冻结的,你冻结网络中所有层的参数,你只需要训练和你的Softmax层有关的参数。这个Softmax层有三种可能的输出,Tigger、Misty或者都不是。

通过使用其他人预训练的权重,你很可能得到很好的性能,即使只有一个小的数据集。幸运的是,大多数深度学习框架都支持这种操作,事实上,取决于用的框架,它也许会有trainableParameter=0这样的参数,对于这些前面的层,你可能会设置这个参数。

为了不训练这些权重,有时也会有freeze=1这样的参数。不同的深度学习编程框架有不同的方式,允许你指定是否训练特定层的权重。在这个例子中,你只需要训练softmax层的权重,把前面这些层的权重都冻结。

 

另一个技巧,也许对一些情况有用,由于前面的层都冻结了,相当于一个固定的函数,不需要改变。因为你不需要改变它,也不训练它,取输入图像X,然后把它映射到这层(softmax的前一层)的激活函数

所以这个能加速训练的技巧就是,如果我们先计算这一层(紫色箭头标记),计算特征或者激活值,然后把它们存到硬盘里。你所做的就是用这个固定的函数,在这个神经网络的前半部分(softmax层之前的所有层视为一个固定映射),取任意输入图像X,然后计算它的某个特征向量,这样你训练的就是一个很浅的softmax模型,用这个特征向量来做预测。

对你的计算有用的一步就是对你的训练集中所有样本的这一层的激活值进行预计算,然后存储到硬盘里,然后在此之上训练softmax分类器。所以,存储到硬盘或者说预计算方法的优点就是,你不需要每次遍历训练集再重新计算这个激活值了。

因此如果你的任务只有一个很小的数据集,你可以这样做。

要有一个更大的训练集怎么办呢?

根据经验,如果你有一个更大的标定的数据集,也许你有大量的Tigger和Misty的照片,还有两者都不是的,这种情况,你应该冻结更少的层,比如只把上图中前面括起来的这些层冻结,然后训练后面的层。如果你的输出层的类别不同,那么你需要构建自己的输出单元,Tigger、Misty或者两者都不是三个类别。有很多方式可以实现,你可以取后面几层的权重,用作初始化,然后从这里开始梯度下降。

或者你可以直接去掉这几层,换成你自己的隐藏单元和你自己的softmax输出层,这些方法值得一试。但是有一个规律,如果你有越来越多的数据,你需要冻结的层数越少,你能够训练的层数就越多。这个理念就是,如果你有一个更大的数据集,也许有足够多的数据,那么不要单单训练一个softmax单元,而是考虑训练中等大小的网络,包含你最终要用的网络的后面几层。

最后,如果你有大量数据,你应该做的就是用开源的网络和它的权重,把这、所有的权重当作初始化,然后训练整个网络。再次注意,如果这是一个1000节点的softmax,而你只有三个输出,你需要你自己的softmax输出层来输出你要的标签。

如果你有越多的标定的数据,或者越多的Tigger、Misty或者两者都不是的图片,你可以训练越多的层。极端情况下,你可以用下载的权重只作为初始化,用它们来代替随机初始化,接着你可以用梯度下降训练,更新网络所有层的所有权重。

这就是卷积网络训练中的迁移学习,事实上,网上的公开数据集非常庞大,并且你下载的其他人已经训练好几周的权重,已经从数据中学习了很多了,你会发现,对于很多计算机视觉的应用,如果你下载其他人的开源的权重,并用作你问题的初始化,你会做的更好。

在所有不同学科中,在所有深度学习不同的应用中,我认为计算机视觉是一个你经常用到迁移学习的领域,除非你有非常非常大的数据集,你可以从头开始训练所有的东西。总之,迁移学习是非常值得你考虑的,除非你有一个极其大的数据集和非常大的计算量预算来从头训练你的网络。

【47】迁移学习(Transfer Learning)的更多相关文章

  1. 【转载】 迁移学习(Transfer learning),多任务学习(Multitask learning)和端到端学习(End-to-end deep learning)

    --------------------- 作者:bestrivern 来源:CSDN 原文:https://blog.csdn.net/bestrivern/article/details/8700 ...

  2. 迁移学习-Transfer Learning

    迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...

  3. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  4. pytorch例子学习——TRANSFER LEARNING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 以下是两种主要的迁移学习场景 微调convnet : ...

  5. 图像识别 | AI在医学上的应用 | 深度学习 | 迁移学习

    参考:登上<Cell>封面的AI医疗影像诊断系统:机器之心专访UCSD张康教授 Identifying Medical Diagnoses and Treatable Diseases b ...

  6. AI小白必读:深度学习、迁移学习、强化学习别再傻傻分不清

    摘要:诸多关于人工智能的流行词汇萦绕在我们耳边,比如深度学习 (Deep Learning).强化学习 (Reinforcement Learning).迁移学习 (Transfer Learning ...

  7. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

  8. [DeeplearningAI笔记]ML strategy_2_3迁移学习/多任务学习

    机器学习策略-多任务学习 Learninig from multiple tasks 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.7 迁移学习 Transfer Learninig 神 ...

  9. 迁移学习(Transformer),面试看这些就够了!(附代码)

    1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...

随机推荐

  1. Spring初识、新建工程

    1.spring与三层架构的关系: spring负责管理项目中的所有对象,是一个一站式的框架,容器中的对象决定了spring的功能. 2.spring核心架构 Spring框架主要由六个模块组成,在开 ...

  2. BZOJ 3339 Rmq Problem(离线+线段树+mex函数)

    题意: q次询问,问[l,r]子区间的mex值 思路: 对子区间[l,r],当l固定的时候,[l,r]的mex值对r单调不减 对询问按照l离线,对当前的l,都有维护一个线段树,每个叶节点保存[l,r] ...

  3. Codeforces 1087C Connect Three (思维+模拟)

    题意: 网格图选中三个格,让你选中一些格子把这三个格子连起来,使得选中的格子总数最小.最后输出方案 网格范围为1000 思路: 首先两点间连起来最少需要的格子为他们的曼哈顿距离 然后连接方案一定是曼哈 ...

  4. HDU6395 Sequence(矩阵快速幂+数论分块)

    题意: F(1)=A,F(2)=B,F(n)=C*F(n-2)+D*F(n-1)+P/n 给定ABCDPn,求F(n) mod 1e9+7 思路: P/n在一段n里是不变的,可以数论分块,再在每一段里 ...

  5. Docker容器到底是什么?

    Docker是一个开源的应用容器引擎,是近些年最火的技术之一,Docker公司从Docker项目开源之后发家致富把公司商标改为了Docker,收购了fit项目,整合为了docker-compose,前 ...

  6. 【MySQL 线上 BUG 分析】之 多表同字段异常:Column ‘xxx’ in field list is ambiguous

    一.生产出错! 今天早上11点左右,我在工作休息之余,撸了一下猫.突然,工作群响了,老大在里面说:APP出错了! 妈啊,这太吓人了,因为只是说了出错,但是没说错误的信息.所以我赶紧到APP上看看. 这 ...

  7. 《C/C++实现Console下的加载进度条模拟[美观版]》

    前言   有时候我们会遇到在CMD或DOS控制台上出现的加载进度条,虽然不是如网页和软件写的美观.但确确实实也有着自己的特色.而且,一个好看的加载进度条也能增加用户使用控制台程序的体验!所以,拿来研究 ...

  8. Django (一) 基础

    创建项目 创建app     python manager.py startapp app01 修改.添加url from django.conf.urls import url,include fr ...

  9. H5异步加载多图

    异步加载多图(可能没啥用,加载慢)(图片预加载,提前给浏览器缓存图片) 1. 用一个计数变量记录需要加载的图片个数 2. 用new Image()去加载,加载完给此对象的src赋值要加载的url路径( ...

  10. windows下python3使用pip安装scrapy提示安装失败

    我的环境:     python3.6,     win10,      原因:不能成功安装twisted,因为twisted与高版本的python有兼容问题. 解决:1,先下载twisted二进制文 ...