多任务学习模型之ESMM介绍与实现
简介:本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
多任务学习背景
目前工业中使用的推荐算法已不只局限在单目标(ctr)任务上,还需要关注后续的转换链路,如是否评论、收藏、加购、购买、观看时长等目标。
本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
论文介绍
CVR预估面临两个关键问题:
1. Sample Selection Bias (SSB)
转化是在点击之后才“有可能”发生的动作,传统CVR模型通常以点击数据为训练集,其中点击未转化为负例,点击并转化为正例。但是训练好的模型实际使用时,则是对整个空间的样本进行预估,而非只对点击样本进行预估。即训练数据与实际要预测的数据来自不同分布,这个偏差对模型的泛化能力构成了很大挑战,导致模型上线后,线上业务效果往往一般。
2. Data Sparsity (DS)
CVR预估任务的使用的训练数据(即点击样本)远小于CTR预估训练使用的曝光样本。仅使用数量较小的样本进行训练,会导致深度模型拟合困难。
一些策略可以缓解这两个问题,例如从曝光集中对unclicked样本抽样做负例缓解SSB,对转化样本过采样缓解DS等。但无论哪种方法,都没有从实质上解决上面任一个问题。
由于点击=>转化,本身是两个强相关的连续行为,作者希望在模型结构中显示考虑这种“行为链关系”,从而可以在整个空间上进行训练及预测。这涉及到CTR与CVR两个任务,因此使用多任务学习(MTL)是一个自然的选择,论文的关键亮点正在于“如何搭建”这个MTL。
首先需要重点区分下,CVR预估任务与CTCVR预估任务。
- CVR = 转化数/点击数。是预测“假设item被点击,那么它被转化”的概率。CVR预估任务,与CTR没有绝对的关系。一个item的ctr高,cvr不一定同样会高,如标题党文章的浏览时长往往较低。这也是不能直接使用全部样本训练CVR模型的原因,因为无法确定那些曝光未点击的样本,假设他们被点击了,是否会被转化。如果直接使用0作为它们的label,会很大程度上误导CVR模型的学习。
- CTCVR = 转换数/曝光数。是预测“item被点击,然后被转化”的概率。
其中x,y,z分别表示曝光,点击,转换。注意到,在全部样本空间中,CTR对应的label为click,而CTCVR对应的label为click & conversion,这两个任务是可以使用全部样本的。因此,ESMM通过学习CTR,CTCVR两个任务,再根据上式隐式地学习CVR任务。具体结构如下:
网络结构上有两点值得强调:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
具体地,反映在目标函数中:
代码实现
基于EasyRec推荐算法框架,我们实现了ESMM算法,具体实现可移步至github:EasyRec-ESMM。
EasyRec介绍:EasyRec是阿里云计算平台机器学习PAI团队开源的大规模分布式推荐算法框架,EasyRec 正如其名字一样,简单易用,集成了诸多优秀前沿的推荐系统论文思想,并且有在实际工业落地中取得优良效果的特征工程方法,集成训练、评估、部署,与阿里云产品无缝衔接,可以借助 EasyRec 在短时间内搭建起一套前沿的推荐系统。作为阿里云的拳头产品,现已稳定服务于数百个企业客户。
模型前馈网络:
def build_predict_graph(self):
"""Forward function. Returns:
self._prediction_dict: Prediction result of two tasks.
"""
# 此处从Concatenate后的tensor(all_fea)开始,省略其生成逻辑 cvr_tower_name = self._cvr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._cvr_tower_cfg.dnn,
self._l2_reg,
name=cvr_tower_name,
is_training=self._is_training)
cvr_tower_output = dnn_model(all_fea)
cvr_tower_output = tf.layers.dense(
inputs=cvr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % cvr_tower_name) ctr_tower_name = self._ctr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._ctr_tower_cfg.dnn,
self._l2_reg,
name=ctr_tower_name,
is_training=self._is_training)
ctr_tower_output = dnn_model(all_fea)
ctr_tower_output = tf.layers.dense(
inputs=ctr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % ctr_tower_name) tower_outputs = {
cvr_tower_name: cvr_tower_output,
ctr_tower_name: ctr_tower_output
}
self._add_to_prediction_dict(tower_outputs)
return self._prediction_dict
loss计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_loss_graph(self):
"""Build loss graph. Returns:
self._loss_dict: Weighted loss of ctr and cvr.
"""
cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name] ctcvr_label = tf.cast(
self._labels[cvr_label_name] * self._labels[ctr_label_name],
tf.float32)
cvr_loss = tf.keras.backend.binary_crossentropy(
ctcvr_label, self._prediction_dict['probs_ctcvr'])
cvr_loss = tf.reduce_sum(cvr_losses, name="ctcvr_loss") # The weight defaults to 1.
self._loss_dict['weighted_cross_entropy_loss_%s' %
cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss ctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.cast(self._labels[ctr_label_name], tf.float32),
logits=self._prediction_dict['logits_%s' % ctr_tower_name]
), name="ctr_loss") self._loss_dict['weighted_cross_entropy_loss_%s' %
ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
return self._loss_dict
note: 这里loss是 weighted_cross_entropy_loss_ctr + weighted_cross_entropy_loss_cvr, EasyRec框架会自动对self._loss_dict中的内容进行加和。
metric计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_metric_graph(self, eval_config):
"""Build metric graph. Args:
eval_config: Evaluation configuration. Returns:
metric_dict: Calculate AUC of ctr, cvr and ctrvr.
"""
metric_dict = {} cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name]
for metric in self._cvr_tower_cfg.metrics_set:
# CTCVR metric
ctcvr_label_name = cvr_label_name + '_ctcvr'
cvr_dtype = self._labels[cvr_label_name].dtype
self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
self._labels[ctr_label_name], cvr_dtype)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=ctcvr_label_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_ctcvr')) # CVR metric
cvr_label_masked_name = cvr_label_name + '_masked'
ctr_mask = self._labels[ctr_label_name] > 0
self._labels[cvr_label_masked_name] = tf.boolean_mask(
self._labels[cvr_label_name], ctr_mask)
pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
self._prediction_dict[pred_name], ctr_mask)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=cvr_label_masked_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_%s_masked' % cvr_tower_name)) for metric in self._ctr_tower_cfg.metrics_set:
# CTR metric
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._ctr_tower_cfg.loss_type,
label_name=ctr_label_name,
num_class=self._ctr_tower_cfg.num_class,
suffix='_%s' % ctr_tower_name))
return metric_dict
实验及不足
我们基于开源AliCCP数据,进行了大量实验,实验部分请期待下一篇文章。实验发现,ESMM的跷跷板现象较为明显,CTR与CVR任务的效果较难同时提升。
参考文献
- Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate
- 阿里CVR预估模型之ESMM
- EasyRec-ESMM使用介绍多任务学习模型之ESMM介绍与实现
本文为阿里云原创内容,未经允许不得转载。
多任务学习模型之ESMM介绍与实现的更多相关文章
- 【论文笔记】多任务学习(Multi-Task Learning)
1. 前言 多任务学习(Multi-task learning)是和单任务学习(single-task learning)相对的一种机器学习方法.在机器学习领域,标准的算法理论是一次学习一个任务,也就 ...
- [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)
译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...
- 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning
使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...
- 推荐中的多任务学习-ESMM
本文将介绍阿里发表在 SIGIR'18 的论文ESMM<Entire Space Multi-Task Model: An Effective Approach for Estimating Po ...
- 牛亚男:基于多Domain多任务学习框架和Transformer,搭建快精排模型
导读: 本文主要介绍了快手的精排模型实践,包括快手的推荐系统,以及结合快手业务展开的各种模型实战和探索,全文围绕以下几大方面展开: 快手推荐系统 CTR模型--PPNet 多domain多任务学习框架 ...
- 多任务学习(MTL)在转化率预估上的应用
今天主要和大家聊聊多任务学习在转化率预估上的应用. 多任务学习(Multi-task learning,MTL)是机器学习中的一个重要领域,其目标是利用多个学习任务中所包含的有用信息来帮助每个任务学习 ...
- 推荐中的多任务学习-YouTube视频推荐
本文将介绍Google发表在RecSys'19 的论文<Recommending What Video to Watch Next: A Multitask Ranking System> ...
- 分布式多任务学习论文阅读(四):去偏lasso实现高效通信
1.难点-如何实现高效的通信 我们考虑下列的多任务优化问题: \[ \underset{\textbf{W}}{\min} \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1 ...
- 【NLP】蓦然回首:谈谈学习模型的评估系列文章(一)
统计角度窥视模型概念 作者:白宁超 2016年7月18日17:18:43 摘要:写本文的初衷源于基于HMM模型序列标注的一个实验,实验完成之后,迫切想知道采用的序列标注模型的好坏,有哪些指标可以度量. ...
- Stanford机器学习笔记-6. 学习模型的评估和选择
6. 学习模型的评估与选择 Content 6. 学习模型的评估与选择 6.1 如何调试学习算法 6.2 评估假设函数(Evaluating a hypothesis) 6.3 模型选择与训练/验证/ ...
随机推荐
- Miracast技术详解(一):Wi-Fi Display
目录 Miracast概述 Miracast Wi-Fi Direct Wi-Fi Display Sink & Source Android上Wi-Fi Direct的实现 Wi-Fi P2 ...
- 最简洁明了的Linux常用命令
1.ls 命令 查看当前目录下可见的文件.文件夹及其相关权限 常用参数:-l 列表式查看 -al 查看所有,包括隐藏的文件.文件夹 [root@qinshengfei bin]# ls --color ...
- 低代码平台前端的设计与实现(三)设计态画布DesignCanvas的设计与实现
上一篇文章,我们分析并设计了关于构建引擎BuildEngine的切面设计.本文我们将基于BuildEngine所提供的切面处理能力,在CustomCreateElementHandle中通过一些逻辑, ...
- 主nginx和子nginx-------域名-端口-解答
主nginx和子nginx-------域名-端口-解答 想象一下Nginx是一个接待员,每个端口就像接待员的一个电话线,而server_name就像是客户拨打的不同号码. 当你在Nginx配置文件里 ...
- 微信小程序获取手机号流程
小程序中获取手机号前提 小程序需企业认证,才可以获取用户的手机号,个人开发者是不能获取的 哔哔下 官方文档给出需先登录才可获取手机号 传送门 思路为:login登录获取code-->code传给 ...
- 服创杯 【A15】智能信号灯-交通流疏导控制系统【融创软通】数据流图
- 官宣!禅道与极狐(GitLab)达成深度合作,携手推进开源开放DevOps生态发展
近日,禅道与著名编程开源开发平台极狐(GitLab)公司签署战略合作,双方将重点探索适用于中国用户DevOps全生命周期解决方案,并将在开源培训和教育.云服务解决方案等多个领域深度合作,共同助力国内D ...
- 【FAQ】HarmonyOS SDK 闭源开放能力 —Push Kit
1.问题描述 升级到4.0.0.59版本后,通过pushService.getToken获取华为的token时报如下错误:Illegal application identity. 解决方案 Mate ...
- 动态规划(四)——区间dp
区间dp: 就是对于区间的一种动态规划,对于某个区间,它的合并方式可能有很多种,我们需要去枚举所有的方式,通常是去枚举区间的分割点,找到最优的方式(一般是找最少消耗). 通常都是先枚举区间长度,区间长 ...
- 什么是ip协议一
前言 两节结束,为网络底层系列做铺垫. 首先来看一张图: IOS有七层,但是我们可以简化层4层,ip属于传输层,可以说是非常重要,下面简单的做一个介绍. 正文 ip的介绍: 1.ip是tcp/ip 协 ...