1,概述

  剪枝可以分为两种:一种是无序的剪枝,比如将权重中一些值置为0,这种也称为稀疏化,在实际的应用上这种剪枝基本没有意义,因为它只能压缩模型的大小,但很多时候做不到模型推断加速,而在当今的移动设备上更多的关注的是系统的实时相应,也就是模型的推断速度。另一种是结构化的剪枝,比如卷积中对channel的剪枝,这种不仅可以降低模型的大小,还可以提升模型的推断速度。剪枝之前在卷积上应用较多,而随着bert之类的预训练模型的出现,这一类模型通常比较大,且推断速度较慢。例如bert在文本分类的任务上,128的序列长度,其推断速度都只有80ms左右,这还只是单个模型,而一个大的系统,往往是有多个模型组成的。因此bert要想在工业界,尤其是移动端落地,是极度需要模型压缩的。

2,具体方法

  看完这篇论文之后,更多的感觉是这篇论文并没有在剪枝上有太多的贡献,更像是对multi head中head的数量做了一个实验性的工作,探索了在multi head中并不是所有的head都需要,有很多head提取的信息对最终的结果并没有什么影响,是冗余存在的。

  本论文在探讨在test阶段,去掉一部分head是否会影响模型的性能,得到的结论是大多数都不会,而且部分还会提升性能,作者给出了三种实验方法来证明这一点:

  1,每次去掉一层中一个head,测试模型的性能

  2,每次去掉一层中剩余的层,只保存一个head,测试模型的性能

  3,通过梯度来判断每个head的重要性,然后去掉一部分不重要的head,测试模型的性能

  为了实现上述的实验,作者对multi head的计算做了一些修改,修改后的公式如下:

    

  在这里引入了一个系数$\zeta_h$,该值的取值为0或1,它的作用是用来mask不重要的head。在训练时保持为1,到test的时候对部分head mask掉。

  作者在基于transformer的机器翻译模型上和基于bert的NLI任务上做了实验,我们来看看上面三个实验的结果

  Ablating One Head

  去掉一个head,作者给出了实验结果如下:

    

  从上面的图中可以看到大多数head去掉之后的模型分数还基本分布在baseline附近,从作者给的表格数据看会更加的清晰:

    

  上面给出的是机器翻译的表格数据,蓝色的值表示性能增加,红色的值表示性能下降,大多数情况下性能是增加的,只有少部分性能会有所下降,只有极少部分性能会下降的比较多。

  Ablating All Heads but One

  当去掉一层中的其余head只保留一个head时,我们来看下模型的结果,这回作者给出了一个离散图:

    

  同样的,大多数情况下的性能都分布在baseline附近,同样看看表格会更清晰:

    

  从上面来看除了机器翻译中的encoder-decoder之间的attention的最后一层会出现性能明显的下降,其他大多数情况都还好,甚至有的情况下性能反而上升。

  上面两种实验都有一个共同的弊端,就是每次实验只能对一层做head的mask,但实际过程中所有层的head都有可能会被去除,且至于去除哪些还和层与层之间的依赖性有关,因此第三种方法可以来改善这个问题。

  Head Importance Score for Pruning

  在这里作者引入了梯度来衡量head的重要性,首先给出一个公式如下:

    

  上面公式是对mask系数的偏导,我们知道偏导的值的大小可以衡量这个维度上对损失的影响大小,在这里作者对偏导取了个绝对值,避免在求期望的时候正负抵消,因为无论是正值还是负值,只要绝对值比较大,就可以衡量偏导对损失的影响是比较大的,这里的期望是对所有样本X的,因为单个batch是存在误差的,因此对全量样本计算的偏导求均值。

   对上面的公式做一个链式转换,可以得到:

    

  这样我们就可以用这个对head的期望梯度值来衡量其重要性,然后按百分比去除head,得到的结果如下:

    

  上面图中的实验是通过梯度来进行剪枝的,虚线是通过第一种方法中的分数来衡量head的重要性进行剪枝的,可以看到基于梯度的效果还是很明显的,但是剪枝范围也是有限的,超过这个范围之后,性能会急剧下降。

  作者还测了下剪枝后模型的推断速度,个人感觉这个推断速度的减小真的是毫无意义:

    

  如上图所示,只有在batch达到16的时候才有比较明显的速度提升,但是大多数线上运行的时候都是batch为1的。不过也不能就此下定论说减少head的数量是起不到加速效果的,个人感觉作者在这里测推断速度的时候是存在一些问题的:作者是先训练,后剪枝,但剪枝之后没有再训练,这也就意味着这些head仍然存在,只是将不需要的head前面的mask系数置为0而已。为什么做出这样的认定呢?因为在实际的multi head设计中,我们是要保证每个head得到的词向量拼接在一起等于原始的词向量,因为后面要进入到前向层,必须保持维度一致,我猜这里作者可能是将mask掉的head得到的向量置为0,这样这些值在下一层计算self-attention就没有意义了,至于为什么还是有加速,原因不明。以上个人猜测。

  此外单纯得减少head的数量好像对加速意义不大,只有配合减小embedding size才有意义,否则计算复杂度基本一致,因为我们在做multi-attention时映射到不同子空间时,实际上是一个大的矩阵映射,这个大的矩阵的维度取决于embedding size,映射完之后再分割成多个而已。从计算上来看self-attention是耗时的,因为减少embedding size,减小序列长度都可以极大的提速(减小序列长度还会影响到前向层的速度)。

bert剪枝系列——Are Sixteen Heads Really Better than One?的更多相关文章

  1. Bert不完全手册6. Bert在中文领域的尝试 Bert-WWM & MacBert & ChineseBert

    一章我们来聊聊在中文领域都有哪些预训练模型的改良方案.Bert-WWM,MacBert,ChineseBert主要从3个方向在预训练中补充中文文本的信息:词粒度信息,中文笔画信息,拼音信息.与其说是推 ...

  2. Transformer模型---decoder

    一.结构 1.编码器 Transformer模型---encoder - nxf_rabbit75 - 博客园 2.解码器 (1)第一个子层也是一个多头自注意力multi-head self-atte ...

  3. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  4. 就是要你明白机器学习系列--决策树算法之悲观剪枝算法(PEP)

    前言 在机器学习经典算法中,决策树算法的重要性想必大家都是知道的.不管是ID3算法还是比如C4.5算法等等,都面临一个问题,就是通过直接生成的完全决策树对于训练样本来说是“过度拟合”的,说白了是太精确 ...

  5. acdream 小晴天老师系列——晴天的后花园 (暴力+剪枝)

    小晴天老师系列——晴天的后花园 Time Limit: 10000/5000MS (Java/Others)    Memory Limit: 128000/64000KB (Java/Others) ...

  6. Bert系列 源码解读 四 篇章

    Bert系列(一)——demo运行 Bert系列(二)——模型主体源码解读 Bert系列(三)——源码解读之Pre-trainBert系列(四)——源码解读之Fine-tune 转载自: https: ...

  7. bert系列二:《BERT》论文解读

    论文<BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding> 以下陆续介绍ber ...

  8. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  9. 广告行业中那些趣事系列6:BERT线上化ALBERT优化原理及项目实践(附github)

    摘要:BERT因为效果好和适用范围广两大优点,所以在NLP领域具有里程碑意义.实际项目中主要使用BERT来做文本分类任务,其实就是给文本打标签.因为原生态BERT预训练模型动辄几百兆甚至上千兆的大小, ...

随机推荐

  1. Jmeter-Question之“HTTPS请求”

    前面在Jmeter-Question中有提到若干问题,有时间呢,我也会进行继续编写随笔,梳理自己的知识,本篇呢,便来记Jmeter发送https请求的过程 内容大致与http://blog.csdn. ...

  2. (转)idea创建Maven项目时Maven插件内看不到mybatis-generator

    转载地址:https://blog.csdn.net/yytwiligt/article/details/81010360 创建Maven项目时插件配置添加了mybatis-generator但是右侧 ...

  3. 剑指Offer-15.反转链表(C++/Java)

    题目: 输入一个链表,反转链表后,输出新链表的表头. 分析: 可以利用栈将链表元素依次压入栈中,再从栈中弹出元素重新建立链表,返回头节点. 也可以在原有的链表上来翻转,先保存当前节点的下一个节点,然后 ...

  4. SQL Server 迁移数据库 (四)备份和还原

    1. 备份 2. 复制 3. 粘贴 4. 还原 截图软件出问题了,估计重启下就好,但是备份还原比较简单,懂的都懂,马上下班了就不贴图了.

  5. [LOJ 2134][UOJ 132][BZOJ 4200][NOI 2015]小园丁与老司机

    [LOJ 2134][UOJ 132][BZOJ 4200][NOI 2015]小园丁与老司机 题意 给定平面上的 \(n\) 个整点 \((x_i,y_i)\), 一共有两个问题. 第一个问题是从原 ...

  6. xLua 学习

    xLua https://github.com/Tencent/xLua 文档 https://tencent.github.io/xLua/public/v1/guide/index.html FA ...

  7. LeetCode 283:移动零 Move Zeroes

    给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序. Given an array nums, write a function to move all 0' ...

  8. LeetCode28——实现strStr()

    6月中下旬辞职在家,7 月份无聊的度过了一个月.8 月份开始和朋友两个人写项目,一个后台和一个 APP ,APP 需要对接蓝牙打印机.APP 和蓝牙打印机都没有搞过,开始打算使用 MUI 开发 APP ...

  9. 阿里云CentOS7.x安装nodejs及pm2

    对之前文章的修订 您将了解 CentOS下如何安装nodejs CentOS下如何安装NVM CentOS下如何安装git CentOS下如何安装pm2 适用对象 本文档介绍如何在阿里云CentOS系 ...

  10. PHP高级进阶梳理

    基础篇 1.深入理解计算机系统 2.现代操作系统 3.C程序设计语言 4.C语言数据结构和算法 5.Unix环境高级编程 6.TCP/IP网络通信详解 7.Java面向对象编程 8.Java编程思想 ...