多任务学习模型之ESMM介绍与实现
简介:本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
多任务学习背景
目前工业中使用的推荐算法已不只局限在单目标(ctr)任务上,还需要关注后续的转换链路,如是否评论、收藏、加购、购买、观看时长等目标。
本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
论文介绍
CVR预估面临两个关键问题:
1. Sample Selection Bias (SSB)
转化是在点击之后才“有可能”发生的动作,传统CVR模型通常以点击数据为训练集,其中点击未转化为负例,点击并转化为正例。但是训练好的模型实际使用时,则是对整个空间的样本进行预估,而非只对点击样本进行预估。即训练数据与实际要预测的数据来自不同分布,这个偏差对模型的泛化能力构成了很大挑战,导致模型上线后,线上业务效果往往一般。
2. Data Sparsity (DS)
CVR预估任务的使用的训练数据(即点击样本)远小于CTR预估训练使用的曝光样本。仅使用数量较小的样本进行训练,会导致深度模型拟合困难。
一些策略可以缓解这两个问题,例如从曝光集中对unclicked样本抽样做负例缓解SSB,对转化样本过采样缓解DS等。但无论哪种方法,都没有从实质上解决上面任一个问题。
由于点击=>转化,本身是两个强相关的连续行为,作者希望在模型结构中显示考虑这种“行为链关系”,从而可以在整个空间上进行训练及预测。这涉及到CTR与CVR两个任务,因此使用多任务学习(MTL)是一个自然的选择,论文的关键亮点正在于“如何搭建”这个MTL。
首先需要重点区分下,CVR预估任务与CTCVR预估任务。
- CVR = 转化数/点击数。是预测“假设item被点击,那么它被转化”的概率。CVR预估任务,与CTR没有绝对的关系。一个item的ctr高,cvr不一定同样会高,如标题党文章的浏览时长往往较低。这也是不能直接使用全部样本训练CVR模型的原因,因为无法确定那些曝光未点击的样本,假设他们被点击了,是否会被转化。如果直接使用0作为它们的label,会很大程度上误导CVR模型的学习。
- CTCVR = 转换数/曝光数。是预测“item被点击,然后被转化”的概率。
其中x,y,z分别表示曝光,点击,转换。注意到,在全部样本空间中,CTR对应的label为click,而CTCVR对应的label为click & conversion,这两个任务是可以使用全部样本的。因此,ESMM通过学习CTR,CTCVR两个任务,再根据上式隐式地学习CVR任务。具体结构如下:
网络结构上有两点值得强调:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
具体地,反映在目标函数中:

代码实现
基于EasyRec推荐算法框架,我们实现了ESMM算法,具体实现可移步至github:EasyRec-ESMM。
EasyRec介绍:EasyRec是阿里云计算平台机器学习PAI团队开源的大规模分布式推荐算法框架,EasyRec 正如其名字一样,简单易用,集成了诸多优秀前沿的推荐系统论文思想,并且有在实际工业落地中取得优良效果的特征工程方法,集成训练、评估、部署,与阿里云产品无缝衔接,可以借助 EasyRec 在短时间内搭建起一套前沿的推荐系统。作为阿里云的拳头产品,现已稳定服务于数百个企业客户。
模型前馈网络:
def build_predict_graph(self):
"""Forward function. Returns:
self._prediction_dict: Prediction result of two tasks.
"""
# 此处从Concatenate后的tensor(all_fea)开始,省略其生成逻辑 cvr_tower_name = self._cvr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._cvr_tower_cfg.dnn,
self._l2_reg,
name=cvr_tower_name,
is_training=self._is_training)
cvr_tower_output = dnn_model(all_fea)
cvr_tower_output = tf.layers.dense(
inputs=cvr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % cvr_tower_name) ctr_tower_name = self._ctr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._ctr_tower_cfg.dnn,
self._l2_reg,
name=ctr_tower_name,
is_training=self._is_training)
ctr_tower_output = dnn_model(all_fea)
ctr_tower_output = tf.layers.dense(
inputs=ctr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % ctr_tower_name) tower_outputs = {
cvr_tower_name: cvr_tower_output,
ctr_tower_name: ctr_tower_output
}
self._add_to_prediction_dict(tower_outputs)
return self._prediction_dict
loss计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_loss_graph(self):
"""Build loss graph. Returns:
self._loss_dict: Weighted loss of ctr and cvr.
"""
cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name] ctcvr_label = tf.cast(
self._labels[cvr_label_name] * self._labels[ctr_label_name],
tf.float32)
cvr_loss = tf.keras.backend.binary_crossentropy(
ctcvr_label, self._prediction_dict['probs_ctcvr'])
cvr_loss = tf.reduce_sum(cvr_losses, name="ctcvr_loss") # The weight defaults to 1.
self._loss_dict['weighted_cross_entropy_loss_%s' %
cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss ctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.cast(self._labels[ctr_label_name], tf.float32),
logits=self._prediction_dict['logits_%s' % ctr_tower_name]
), name="ctr_loss") self._loss_dict['weighted_cross_entropy_loss_%s' %
ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
return self._loss_dict
note: 这里loss是 weighted_cross_entropy_loss_ctr + weighted_cross_entropy_loss_cvr, EasyRec框架会自动对self._loss_dict中的内容进行加和。
metric计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_metric_graph(self, eval_config):
"""Build metric graph. Args:
eval_config: Evaluation configuration. Returns:
metric_dict: Calculate AUC of ctr, cvr and ctrvr.
"""
metric_dict = {} cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name]
for metric in self._cvr_tower_cfg.metrics_set:
# CTCVR metric
ctcvr_label_name = cvr_label_name + '_ctcvr'
cvr_dtype = self._labels[cvr_label_name].dtype
self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
self._labels[ctr_label_name], cvr_dtype)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=ctcvr_label_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_ctcvr')) # CVR metric
cvr_label_masked_name = cvr_label_name + '_masked'
ctr_mask = self._labels[ctr_label_name] > 0
self._labels[cvr_label_masked_name] = tf.boolean_mask(
self._labels[cvr_label_name], ctr_mask)
pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
self._prediction_dict[pred_name], ctr_mask)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=cvr_label_masked_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_%s_masked' % cvr_tower_name)) for metric in self._ctr_tower_cfg.metrics_set:
# CTR metric
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._ctr_tower_cfg.loss_type,
label_name=ctr_label_name,
num_class=self._ctr_tower_cfg.num_class,
suffix='_%s' % ctr_tower_name))
return metric_dict
实验及不足
我们基于开源AliCCP数据,进行了大量实验,实验部分请期待下一篇文章。实验发现,ESMM的跷跷板现象较为明显,CTR与CVR任务的效果较难同时提升。
参考文献
- Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate
- 阿里CVR预估模型之ESMM
- EasyRec-ESMM使用介绍多任务学习模型之ESMM介绍与实现
本文为阿里云原创内容,未经允许不得转载。
多任务学习模型之ESMM介绍与实现的更多相关文章
- 【论文笔记】多任务学习(Multi-Task Learning)
1. 前言 多任务学习(Multi-task learning)是和单任务学习(single-task learning)相对的一种机器学习方法.在机器学习领域,标准的算法理论是一次学习一个任务,也就 ...
- [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)
译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...
- 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning
使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...
- 推荐中的多任务学习-ESMM
本文将介绍阿里发表在 SIGIR'18 的论文ESMM<Entire Space Multi-Task Model: An Effective Approach for Estimating Po ...
- 牛亚男:基于多Domain多任务学习框架和Transformer,搭建快精排模型
导读: 本文主要介绍了快手的精排模型实践,包括快手的推荐系统,以及结合快手业务展开的各种模型实战和探索,全文围绕以下几大方面展开: 快手推荐系统 CTR模型--PPNet 多domain多任务学习框架 ...
- 多任务学习(MTL)在转化率预估上的应用
今天主要和大家聊聊多任务学习在转化率预估上的应用. 多任务学习(Multi-task learning,MTL)是机器学习中的一个重要领域,其目标是利用多个学习任务中所包含的有用信息来帮助每个任务学习 ...
- 推荐中的多任务学习-YouTube视频推荐
本文将介绍Google发表在RecSys'19 的论文<Recommending What Video to Watch Next: A Multitask Ranking System> ...
- 分布式多任务学习论文阅读(四):去偏lasso实现高效通信
1.难点-如何实现高效的通信 我们考虑下列的多任务优化问题: \[ \underset{\textbf{W}}{\min} \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1 ...
- 【NLP】蓦然回首:谈谈学习模型的评估系列文章(一)
统计角度窥视模型概念 作者:白宁超 2016年7月18日17:18:43 摘要:写本文的初衷源于基于HMM模型序列标注的一个实验,实验完成之后,迫切想知道采用的序列标注模型的好坏,有哪些指标可以度量. ...
- Stanford机器学习笔记-6. 学习模型的评估和选择
6. 学习模型的评估与选择 Content 6. 学习模型的评估与选择 6.1 如何调试学习算法 6.2 评估假设函数(Evaluating a hypothesis) 6.3 模型选择与训练/验证/ ...
随机推荐
- Java递归实现全排列改进(一)---利用HashSet实现去重
import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.Iter ...
- thttpd 2.27(最新)移植指南(官方安装脚本好多坑,我只想说)
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- 简单使用vim编辑器的用法
vim的使用笔记可以涵盖很多内容,以下是一些基本操作和常见命令的简要总结: 启动Vim 打开或创建文件:vim filename 基本模式切换 正常模式(Normal Mode):启动时默认进入此模式 ...
- 【UE虚幻引擎】干货!UE修改分辨率的3种方法
虚幻引擎作为一款实时3D创作工具,在游戏.建筑.影视动画.虚拟仿真等领域受到全球各行各业创作者广泛欢迎,在UE中获取和设置分辨率也是3D创作开发工作中的常用功能.本文介绍了在虚幻引擎中修改分辨率的3种 ...
- App磁盘沙盒工具实践
目录介绍 01.磁盘沙盒的概述 1.1 项目背景说明 1.2 沙盒作用 1.3 设计目标 02.Android存储概念 2.1 存储划分介绍 2.2 机身内部存储 2.3 机身外部存储 2.4 SD卡 ...
- PHP 数据库表单创建方法记录(储存三方接口数据必用)
最近项目在对接第三方接口数据,这里分享下我用来偷懒的一个PHP方法: /** * 数据库表单创建方法 * @return string * @throws \Exception */ public f ...
- 在 PostgreSQL 中,解决图片二进制数据,由于bytea_output参数问题导致显示不正常的问题。
在 PostgreSQL 中,bytea_output 参数控制在查询结果中 bytea 类型的显示格式.默认情况下,bytea_output 的值为 hex,这意味着在查询结果中,bytea 类型的 ...
- KingbaseES数据库使用kdb_database_link扩展常见问题
KingbaseES数据库使用kdb_database_link扩展常见问题 kdb_database_link主要功能是为了满足@link语法的适配,让用户应用的代码能够适用于更宽泛的产品而无需在移 ...
- java实战字符串+栈5:解码字符
题目: 有形如 (重复字符串)<重复次数n> 的片段,解码后相当于n个重复字符串连续拼接在一起,求展开后的字符串. 求解: public static String zipString( ...
- 鸿蒙HarmonyOS实战-ArkUI组件(Swiper)
一.Swiper 1.概述 Swiper可以实现手机.平板等移动端设备上的图片轮播效果,支持无缝轮播.自动播放.响应式布局等功能.Swiper轮播图具有使用简单.样式可定制.功能丰富.兼容性好等优点, ...