一箭N雕:多任务深度学习实战
1、多任务学习导引
多任务学习是机器学习中的一个分支,按1997年综述论文Multi-task Learning一文的定义:Multitask Learning (MTL) is an inductive transfer mechanism whose principle goal is to improve generalization performance. MTL improves generalization by leveraging the domain-specific information contained in the training signals of related tasks. It does this by training tasks in parallel while using a shared representation。翻译成中文:多任务学习是一种归纳迁移机制,基本目标是提高泛化性能。多任务学习通过相关任务训练信号中的领域特定信息来提高泛化能力,利用共享表示采用并行训练的方法学习多个任务。
顾名思义,多任务学习是一种同时学习多个任务的机器学习方法,如图1所示,多任务学习同时学习了人类和狗的分类器以及男性和女性的性别分类器。
 进一步的,图2所示为单任务学习和多任务学习的对比。在单任务学习中,每个任务采用单独的数据源,分别学习每个任务单独的模型。而多任务学习中,多个数据源采用共享表示同时学习多个子任务模型。
进一步的,图2所示为单任务学习和多任务学习的对比。在单任务学习中,每个任务采用单独的数据源,分别学习每个任务单独的模型。而多任务学习中,多个数据源采用共享表示同时学习多个子任务模型。
 多任务学习的基本假设是多个任务之间具有相关性,因此能够利用任务之间的相关性互相促进。例如,属性分类中,抹口红和戴耳环有一定的相关性,单独训练的时候是无法利用这些信息,多任务学习则可以利用任务相关性联合提高多个属性分类的精度,详情可参考文章Maryland大学Hand等人的论文Attributes for Improved Attributes: A Multi-Task Network for Attribute Classification。
多任务学习的基本假设是多个任务之间具有相关性,因此能够利用任务之间的相关性互相促进。例如,属性分类中,抹口红和戴耳环有一定的相关性,单独训练的时候是无法利用这些信息,多任务学习则可以利用任务相关性联合提高多个属性分类的精度,详情可参考文章Maryland大学Hand等人的论文Attributes for Improved Attributes: A Multi-Task Network for Attribute Classification。
2、多任务深度学习
近年来,在深度学习技术的推动下计算机视觉领域取得了突飞猛进的进展。本质上说,深度学习是多层的神经网络,对输入进行了层级的非线性表示,来自网络可视化的证据表明,深度网络的层级表示从语义上从底层到高层不断递进。深度网络强大的表示能力,使得多任务深度学习有了施展的空间。图3所示为多任务深度网络结构示意图。Input x表示不同任务的输入数据,绿色部分表示不同任务之间共享的层,紫色表示每个任务特定的层,Task x表示不同任务对应的损失函数层。在多任务深度网络中,低层次语义信息的共享有助于减少计算量,同时共享表示层可以使得几个有共性的任务更好的结合相关性信息,任务特定层则可以单独建模任务特定的信息,实现共享信息和任务特定信息的统一。
 在深度网络中,多任务的语义信息还可以从不同的层次输出,例如GoogLeNet中的两个辅助损失层。另外一个例子比如衣服图像检索系统,颜色这类的信息可以从较浅层的时候就进行输出判断,而衣服的样式风格这类的信息,更接近高层语义,需要从更高的层次进行输出,这里的输出指的是每个任务对应的损失层的前一层。
在深度网络中,多任务的语义信息还可以从不同的层次输出,例如GoogLeNet中的两个辅助损失层。另外一个例子比如衣服图像检索系统,颜色这类的信息可以从较浅层的时候就进行输出判断,而衣服的样式风格这类的信息,更接近高层语义,需要从更高的层次进行输出,这里的输出指的是每个任务对应的损失层的前一层。
3、多任务深度学习应用案例
目前,多任务深度学习已经广泛应用于人脸识别、细粒度车辆分类、面部关键点定位与属性分类等多个领域,以下讲介绍其中的代表性论文。
3.1人脸识别网络 DeepID2
香港中文大学汤晓鸥组发表在NIPS14的论文Deep Learning Face Representation by Joint Identification-Verification,提出了一种联合训练人脸确认损失和人脸分类损失的多任务人脸识别网络DeepID2,网络结构如下图所示:
 DeepID2中共有两个损失函数,分别为人脸分类损失函数,对应于Caffe中的SoftmaxLoss:
DeepID2中共有两个损失函数,分别为人脸分类损失函数,对应于Caffe中的SoftmaxLoss:
 另外一个是人脸确认损失函数,对应于Caffe中的Contrastive Loss:
另外一个是人脸确认损失函数,对应于Caffe中的Contrastive Loss:
 3.2细粒度车辆分类网络
3.2细粒度车辆分类网络
这里介绍一个比较有趣的将SoftmaxLoss和TripletLoss结合在一个网络中进行多任务训练的方法Embedding Label Structures for Fine-Grained Feature Representation,目前文章发表于arXiv。作者将这个网络用于细粒度车辆分类上,提醒注意的是为了计算Tiplet Loss,特征进行了L2范数归一操作,网络结构如下图所示:
 3.3物体检测网络Faster R-CNN
3.3物体检测网络Faster R-CNN
在物体检测网络Faster R-CNN中也有多任务学习的应用。Faster R-CNN的网络结构如下图6所示,包含两个任务,分别为窗口回归和窗口分类,其中RPN模块的卷积层在两个任务之间共享。Faster R-CNN的最新版本支持整体端到端训练,可以同时检测多类物体,是目前最具代表性的目标检测框架,同时也是多任务深度学习的一个典型应用。
 3.4面部关键点定位与属性分类网络TCDCN
3.4面部关键点定位与属性分类网络TCDCN
面部关键点估计和头部姿态以及人脸属性(是否戴眼镜、是否微笑和性别)之间有着紧密的联系,香港中文大学汤晓鸥组发表于ECCV14的工作Facial Landmark Detection by Deep Multi-task Learning利用多任务学习方法联合进行人脸面部关键点定位和属性预测,网络结构如下图7所示。
 4、基于Caffe实现多任务学习的小样例
4、基于Caffe实现多任务学习的小样例
本节在目前广泛使用的深度学习开源框架Caffe的基础上实现多任务深度学习算法所需的多维标签输入。默认的,Caffe中的Data层只支持单维标签,为了支持多维标签,首先修改Caffe中的convert_imageset.cpp以支持多标签:
 这样我们就有了多任务的深度学习的基础部分数据输入。为了向上兼容Caffe框架,本文摒弃了部分开源实现增加Data层标签维度选项并修改Data层代码的做法,直接使用两个Data层将数据读入,即分别读入数据和多维标签,接下来介绍对应的网络结构文件prototxt的修改,注意红色的注释部分。
这样我们就有了多任务的深度学习的基础部分数据输入。为了向上兼容Caffe框架,本文摒弃了部分开源实现增加Data层标签维度选项并修改Data层代码的做法,直接使用两个Data层将数据读入,即分别读入数据和多维标签,接下来介绍对应的网络结构文件prototxt的修改,注意红色的注释部分。

 特别的,slice层对多维的标签进行了切分,为每个任务输出了单独的标签。
特别的,slice层对多维的标签进行了切分,为每个任务输出了单独的标签。
 另外一个值得讨论的是每个任务的权重设置,在本文实践中五个任务设置为等权重loss_weight:0.2。一般的,建议所有任务的权重值相加为1,如果这个数值不设置,可能会导致网络收敛不稳定,这是因为多任务学习中对不同任务的梯度进行累加,导致梯度过大,甚至可能引发参数溢出错误导致网络训练失败。
另外一个值得讨论的是每个任务的权重设置,在本文实践中五个任务设置为等权重loss_weight:0.2。一般的,建议所有任务的权重值相加为1,如果这个数值不设置,可能会导致网络收敛不稳定,这是因为多任务学习中对不同任务的梯度进行累加,导致梯度过大,甚至可能引发参数溢出错误导致网络训练失败。
 本文的完整代码可在作者个人的github主页下载:
本文的完整代码可在作者个人的github主页下载:
CodeSnap/convert_multilabel.cpp at master · HolidayXue/CodeSnap · GitHub
多任务损失函数层的网络结构示意图如下图所示:
 5. 总结
5. 总结
本文回顾了多任务学习的基本概念,并讨论了多任务深度学习的基本思想和应用案例。最后以开源深度学习平台Caffe为例讨论了多任务深度学习的实现,并给出了开源代码。
致谢
本文在投稿之后经历了三轮修改,其中一轮公众号编辑部初审,一轮双盲评审大改和一轮单盲评审小修,两名审稿专家对原文进行了全面仔细的阅读,帮助作者修正了文章的若干理论表述,给出了建设性的提高可读性的修改意见。在此本文作者对全体审稿人表示感谢,并对深度学习大讲堂公众号编辑部耐心细致的审稿服务表示感谢。
作者:薛云峰,(https://github.com/HolidayXue),主要从事视频图像算法的研究,就职于浙江捷尚视觉科技股份有限公司担任深度学习算法研究员。捷尚致力于视频大数据和视频监控智能化,现诚招业内算法和工程技术人才,招聘主页http://www.icarevision.cn/job.php,联系邮箱:hr@icarevision.cn
一箭N雕:多任务深度学习实战的更多相关文章
- 深度学习实战篇-基于RNN的中文分词探索
		深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ... 
- 学习Keras:《Keras快速上手基于Python的深度学习实战》PDF代码+mobi
		有一定Python和TensorFlow基础的人看应该很容易,各领域的应用,但比较广泛,不深刻,讲硬件的部分可以作为入门人的参考. <Keras快速上手基于Python的深度学习实战>系统 ... 
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
		PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ... 
- 『深度应用』NLP机器翻译深度学习实战课程·零(基础概念)
		0.前言 深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内 ... 
- 『深度应用』NLP机器翻译深度学习实战课程·壹(RNN base)
		深度学习用的有一年多了,最近开始NLP自然处理方面的研发.刚好趁着这个机会写一系列NLP机器翻译深度学习实战课程. 本系列课程将从原理讲解与数据处理深入到如何动手实践与应用部署,将包括以下内容:(更新 ... 
- TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN
		前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ... 
- 【神经网络与深度学习】深度学习实战——caffe windows 下训练自己的网络模型
		1.相关准备 1.1 手写数字数据集 这篇博客上有.jpg格式的图片下载,附带标签信息,有需要的自行下载,博客附带百度云盘下载地址(手写数字.jpg 格式):http://blog.csdn.net/ ... 
- Tensorflow 2.0 深度学习实战 —— 详细介绍损失函数、优化器、激活函数、多层感知机的实现原理
		前言 AI 人工智能包含了机器学习与深度学习,在前几篇文章曾经介绍过机器学习的基础知识,包括了监督学习和无监督学习,有兴趣的朋友可以阅读< Python 机器学习实战 >.而深度学习开始只 ... 
- TensorFlow深度学习实战---图像识别与卷积神经网络
		全连接层网络结构:神经网络每两层之间的所有结点都是有边相连的. 卷积神经网络:1.输入层 2.卷积层:将神经网络中的每一个小块进行更加深入地分析从而得到抽象程度更高的特征. 3 池化层:可以认为将一张 ... 
随机推荐
- WampServer Mysql配置
			WAMP:Windows下的Apache+Mysql+Perl/PHP/Python,一组常用来搭建动态网站或者服务器的开源软件.可点击此处下载WampServer,然后,按照提示安装WAMP.需要说 ... 
- 依赖注入(IOC)二
			依赖注入(IOC)二 上一章我们讲了构造注入与设值注入,这一篇我们主要讲接口注入与特性注入. 接口注入 接口注入是将抽象类型的入口以方法定义在一个接口中,如果客户类型需要获得这个方法,就需要以实现这个 ... 
- UVA 216 - Getting in Line
			216 - Getting in Line Computer networking requires that the computers in the network be linked. This ... 
- app wap开发mobile隐藏地址栏的js
			function scrolltol (){ setTimeout ( function () { , ) }, ); } window . onload = function () { if ( d ... 
- [置顶] Nosql笔记(一)——关系型数据库回顾
			Nosql笔记(一)——关系型数据库回顾 在平常的商业应用中,我们所使用的大多都是关系型数据库,诸如SQL Server. MY SQL. Oracle等. 关于关系型数据库中的关键技术: 存储引擎 ... 
- 最近修bug的一点感悟
			写在前面话 项目从13年1月份,现场开发,4月中旬,项目开发接近尾声,三个开发,留两个在现场,我被调回公司,5月份现场一同事离职,只有一个同事在开发,结果PM想让这一个同事承担余下的开发和bug工作, ... 
- Linux系统下搭建DNS服务器——DNS原理总结
			2017-01-07 整理 DNS原理 域名到IP地址的解析过程 IP地址到域名的反向域名解析过程 抓包分析DNS报文和具体解析过程 DNS服务器搭建和配置 这个东东也是今年博主参见校招的时候被很多公 ... 
- C语言之逻辑运算符
			一 逻辑运算符: &&:逻辑与,读作并且 表达式左右两边都为真,那么结果才为真 口诀:一假则假 ||:逻辑或,读作或者 表达式左右两边,有一个为真,那么结果就为真 口诀:一真则真 !: ... 
- XAF-BI.Dashboard模块概述 web/win
			Dashboard模块介绍了在ASP.NET XAF 和 WinForms 应用程序中简单的集成 DevExpress Dashboard控件的方法. 其实不仅仅是控件,利用了现有的XAF数据模型,这 ... 
- (转载)python日期函数
			转载于http://www.cnblogs.com/emanlee/p/4399147.html 所有日期.时间的api都在datetime模块内. 1. 日期输出格式化 datetime => ... 
