Wide & Deep的OneFlow网络训练
Wide & Deep的OneFlow网络训练
HugeCTR是英伟达提供的一种高效的GPU框架,专为点击率(CTR)估计训练而设计。
OneFlow对标HugeCTR搭建了Wide & Deep 学习网络(WDL)。OneFlow-WDL网络实现了模型并行与稀疏更新,在8卡12G TitanV的服务器上实现支持超过4亿的词表大小,而且性能没有损失与小词表性能相当。
本文介绍如何使用OneFlow-WDL网络进行训练,以及一些训练结果及分析。
环境和准备
运行OneFlow-WDL需要有安装好OneFlow的python环境,并安装了scikit-learn。
软件要求
- python 3.x(推荐)
- OneFlow 0.x
- scikit-learn
数据准备
准备了一个小的样本数据集,可以下载进行简单测试。
或者参考《使用Spark创建WDL数据集》中的步骤,从CriteoLabs官网下载原始数据集并制作成OneFlow所需要的OFRecord格式的数据集。
OneFlow-WDL脚本
OneFlow-WDL脚本只有一个文件wdl_train_eval.py,请从这里下载。
运行OneFlow-WDL脚本
EMBD_SIZE=1603616
DATA_ROOT=/path/to/wdl/ofrecord
python3 wdl_train_eval.py \
--train_data_dir $DATA_ROOT/train \
--train_data_part_num 256 \
--train_part_name_suffix_length=5 \
--eval_data_dir $DATA_ROOT/val \
--eval_data_part_num 256 \
--max_iter=300000 \
--loss_print_every_n_iter=1000 \
--eval_interval=1000 \
--batch_size=16384 \
--wide_vocab_size=$EMBD_SIZE \
--deep_vocab_size=$EMBD_SIZE \
--gpu_num 1
通常配置好数据集的位置DATA_ROOT后,上面的shell脚本就可以被执行了,如果屏幕上能够输出下面类似的结果,就表示已经正确运行。
1000 time 2020-07-08 00:28:08.066281 loss 0.503295350909233
1000 eval_loss 0.4846755236387253 eval_auc 0.7616240146992771
2000 time 2020-07-08 00:28:11.613961 loss 0.48661992555856703
2000 eval_loss 0.4816856697201729 eval_auc 0.765256583562705
3000 time 2020-07-08 00:28:15.149135 loss 0.48245503094792364
3000 eval_loss 0.47835959643125536 eval_auc 0.7715609382514008
4000 time 2020-07-08 00:28:18.686327 loss 0.47975033831596375
4000 eval_loss 0.47925308644771575 eval_auc 0.7781267916810946
测试结果及说明
在一台有8块12G显存的TitanV的服务器上对OneFlow-WDL进行了一组测试,并使用HugeCTR提供的docker容器做了同样参数的测试。
多GPU性能测试
主要测试目的是在batch size = 16384的情况下,测量不同GPU数量处理每个批次的平均时延(latency)。 测试配置了7个1024神经单元的隐藏层。
结果如下图:
同时记录了,测试时实际最大占用显存的大小,结果如下图:
综合上面结果表明,1卡到8卡,OneFlow-WDL在占用较少的显存的情况下,速度要比HugeCTR快。
batch size=16384每卡,多卡性能测试
主要测试目的是在保证每GPU卡处理16384batch size情况下,使用1至8GPU卡进行训练每个批次的平均时延(latency)。 测试配置了7个1024神经单元的隐藏层。
结果如下图:
同时记录了,测试时实际最大占用显存的大小,结果如下图:
综合上面结果表明,随着卡数的增加,时延增加,OneFlow-WDL在占用较少的显存的情况下,速度要比HugeCTR快;因为每卡保证16384 batch size,OneFlow每卡占用的内存并无显著变化。
单GPU卡不同batch size性能测试
主要测试目的是在一个GPU卡情况下,测量不同batch size每个批次的平均时延(latency)。 测试配置了2个1024神经单元的隐藏层。
结果如下图:
超大词表测试
OneFlow-WDL中配置了两个Embedding Table: - wide_embedding 大小是vocab_size x 1 - deep_embedding 大小是vocab_size x 16
HugeCTR中词表大小(vocab_size)是1603616。从3200000开始测起,一直到支持4亿的词表大小,结果如下图:
上面的图中,蓝色柱子是批次训练的平均时延(latency),红色曲线代表GPU显存的占用。
结论:随着词表大小的增大,内存随之增大,但latency没有明显的变化。
收敛性测试1
选取了batch size=512进行了收敛性能的测试。
下面这张图是,前500步的结果,每一步训练都在验证集中选取20条记录进行验证,图中的曲线分别是loss和AUC:
结论:AUC迅速就增长到超过了0.75。
收敛性测试2
和收敛性测试1同样的情况,这一次是每训练1000步打印训练loss的平均值,然后选取20条验证集数据进行验证,一共训练30万步,结果如下:
结论与分析: 1. 蓝色的train loss曲线有明显向下的台阶,因为整个训练集有36674623条数据,batch_size=512的情况下,大概71630步就过了整个数据集(一个epoch),30万步就把训练数据集用了4次多,蓝色曲线的台阶印证了这些。OneFlow在训练过程中支持数据的打乱,每当数据集被完整的用完一遍之后,数据会被重新打乱,减少过拟合。 2. 橙色的曲线是验证集loss,在前两个epoch的时候基本保持下降的趋势,从第三个epoch开始,loss开始有上升的趋势,表明已经过拟合了。 3. 灰色是验证集的AUC,AUC也是在第二个epoch的时候达到了峰值,超过了0.8,后面几个epoch就开始下降。
Wide & Deep的OneFlow网络训练的更多相关文章
- 【RS】Wide & Deep Learning for Recommender Systems - 广泛和深度学习的推荐系统
[论文标题]Wide & Deep Learning for Recommender Systems (DLRS'16) [论文作者] Heng-Tze Cheng, Levent Koc, ...
- 深度排序模型概述(一)Wide&Deep/xDeepFM
本文记录几个在广告和推荐里面rank阶段常用的模型.广告领域机器学习问题的输入其实很大程度了影响了模型的选择,因为输入一般维度非常高,稀疏,同时包含连续性特征和离散型特征.模型即使到现在DeepFM类 ...
- 深度学习在美团点评推荐平台排序中的应用&& wide&&deep推荐系统模型--学习笔记
写在前面:据说下周就要xxxxxxxx, 吓得本宝宝赶紧找些广告的东西看看 gbdt+lr的模型之前是知道怎么搞的,dnn+lr的模型也是知道的,但是都没有试验过 深度学习在美团点评推荐平台排序中的运 ...
- Wide & Deep Learning Model
Generalized linear models with nonlinear feature transformations (特征工程 + 线性模型) are widely used for l ...
- wide&deep模型演化
推荐系统模型演化 LR-->GBDT+LR FM-->FFM-->GBDT+FM|FFM FTRL-->GBDT+FTRL Wide&DeepModel (Deep l ...
- 巨经典论文!推荐系统经典模型Wide & Deep
今天我们剖析的也是推荐领域的经典论文,叫做Wide & Deep Learning for Recommender Systems.它发表于2016年,作者是Google App Store的 ...
- Deep Learning 学习随记(五)Deep network 深度网络
这一个多周忙别的事去了,忙完了,接着看讲义~ 这章讲的是深度网络(Deep Network).前面讲了自学习网络,通过稀疏自编码和一个logistic回归或者softmax回归连接,显然是3层的.而这 ...
- Caffe-python interface 学习|网络训练、部署、測试
继续python接口的学习.剩下还有solver.deploy文件的生成和模型的測试. 网络训练 solver文件生成 事实上我认为用python生成solver并不如直接写个配置文件,它不像net配 ...
- 推荐系统系列(六):Wide&Deep理论与实践
背景 在CTR预估任务中,线性模型仍占有半壁江山.利用手工构造的交叉组合特征来使线性模型具有"记忆性",使模型记住共现频率较高的特征组合,往往也能达到一个不错的baseline,且 ...
随机推荐
- 硬件篇-03-SLAM移动底盘电气设计
最近因为在忙毕设,专栏已经1个多月没更,对于托更我很抱歉.不过这几周真的没什么时间,Rick&Morty的最新集我到现在都还没看哈哈. 现在毕设已经搞得差不多了,水专栏文章的快乐生 ...
- hdu1671 字典树记录前缀出现次数
题意: 给你一堆电话号,问你这些电话号后面有没有相互冲突的,冲突的条件是当前这个电话号是另一个电话号的前缀,比如有 123456789 123,那么这两个电话号就冲突了,直接输出NO. 思 ...
- poj2175费用流消圈算法
题意: 有n个建筑,每个建筑有ai个人,有m个避难所,每个避难所的容量是bi,ai到bi的费用是|x1-x2|+|y1-y2|+1,然后给你一个n*m的矩阵,表示当前方案,问当前避难方案是否 ...
- 利用 Windows 线程池定制的 4 种方式完成任务(Windows 核心编程)
Windows 线程池 说起底层的线程操作一般都不会陌生,Windows 提供了 CreateThread 函数来创建线程,为了同步线程的操作,Windows 提供了事件内核对象.互斥量内核对象.关键 ...
- CVE-2018-0802:Microsoft office 公式编辑器 font name 字段二次溢出漏洞调试分析
\x01 前言 CVE-2018-0802 是继 CVE-2017-11882 发现的又一个关于 font name 字段的溢出漏洞,又称之为 "第二代噩梦公式",巧合的是两个漏洞 ...
- PWD 好网站
http://angelboy.logdown.com/ https://wizardforcel.gitbooks.io/sploitfun-linux-x86-exp-tut/content/ h ...
- Bugku-文件包含2
文件包含2 目录 文件包含2 题目描述 解题过程 参考 题目描述 没有描述 解题过程 文件包含题目大多都是php环境的, 所以先试试伪协议 发现php://被ban了 继续尝试,发现file://协议 ...
- scrapy爬虫案例--爬取阳关热线问政平台
阳光热线问政平台:http://wz.sun0769.com/political/index/politicsNewest?id=1&page=1 爬取最新问政帖子的编号.投诉标题.投诉内容以 ...
- 【Unity】实验二 游戏场景搭建
实验要求 实验二 游戏场景搭建 实验目的:掌握游戏场景搭建. 实验要求:能够使用Unity的地形引擎创建地形,熟悉场景中的光照与阴影,掌握天空盒和雾化效果等. 实验内容: 地形的绘制:使用高度图绘制: ...
- java随堂笔记
JAVA 1只要是字符串,必然就是对象. 2API文档的基本使用 3如何创建字符串: a直接赋值双引号,也是一个字符串对象. b可以通过new关键字来调用String的构造方法 public Stri ...