Bert实战---情感分类
1.情感分析语料预处理
使用酒店评论语料,正面评论和负面评论各5000条,用BERT参数这么大的模型, 训练会产生严重过拟合,,泛化能力差的情况, 这也是我们下面需要解决的问题;
2.sigmoid二分类
回顾在BERT的训练中Next Sentence Prediction中, 我们取出$[cls]$对应的那一条向量, 然后把他映射成1个数值并用$sigmoid$函数激活:
$$\hat{y} = sigmoid(Linear(cls\_vector)) \quad \hat{y} \in (0, \ 1)$$
3.动态学习率和提前终止$(early \ stop)$
训练方式是,每个$epoch$,用训练集训练。对模型性能的衡量标准是$AUC$, $AUC$的衡量标准对二分类非常易用。当前$epoch$训练完毕之后, 用测试集衡量当前训练结果,并记下当前$epoch$的$AUC$, 如果当前的$AUC$较上一个$epoch$没有提升,那就降低学习率,实际操作是让当前的学习率降低$1/5$, 直到$10$个$epoch$测试集的$AUC$都没有提升, 就终止训练。
初始学习率是$1e-6$, 因为我们是在维基百科预训练语料的基础上进行训练的, 属于下游任务,只需要微调预训练模型就好。
4.解决过拟合问题
但在实际操作中, 使用$\hat{y} = sigmoid(Linear(cls\_vector)) \quad \hat{y} \in (0, \ 1)$的方式, 发现虽然在训练集和测试集上$AUC$都很高, 但实际随便输入一些从各种网上随便找的一些酒店评论后, 发现泛化能力不好. 这是因为训练数据集非常小,即使区分训练集和测试集,但因为整体数据形态比较单一,模型遇到自己没见过的情况就很容易无法做出正确判断,为了提高模型的泛化性能,尝试了另一种模型结构:

(1)mean-max-pool
一种把隐藏层的序列转换为一条向量的方式,其实就是沿着sequence length 的维度分别求均值和max,之后拼起来成为一条向量,之后同样映射成一个值再激活。
$X_{hidden}: [batch\_size, \ seq\_len, \ embedding\_dim]$
$mean\_pooled = mean(X_{hidden}, \ dimension=seq\_len) \quad [batch\_size, \ embedding\_dim]$
$max\_pooled = max(X_{hidden}, \ dimension=seq\_len) \quad [batch\_size, \ embedding\_dim]$
$mean\_max\_pooled = concatenate(mean\_pooled, \ max\_pooled, \ dimension=embedding\_dim ) \quad [batch\_size, \ embedding\_dim * 2]$
上式中mean_max_pooled 也就是我们得到的一句话的数学表达,含有这句话的信息, 其实这也是一种DOC2VEC的方法, 也就是把一句话转换成一条向量,而且无论这句话有多长,转换出来向量的维度都是一样的,之后可以用这些向量做一些分类聚类等任务。
下一步我们同样做映射, 之后用$sigmoid$激活:
$\hat{y} = sigmoid(Linear(mean\_max\_pooled)) \quad \hat{y} \in (0, \ 1)$
怎样理解这样的操作呢, 隐藏层就是一句话的数学表达, 我们求均值和最大值正数学表达对这句话的平均响应, 和最大响应, 之后我们用线性映射来识别这些响应, 从而得到模型的推断结果。
(2)weight decay权重衰减
其实就是L2 normalization,在PyTorch里有接口可以直接调用, 其实$L2$正则的作用就是防止参数的值变得过大或过小,我们可以设想一下,由于我们的训练数据很少,所以实际使用模型进行推断的时候有些字和词或者句子结构的组合模型都是没见过的, 模型里面参数的值很大的话会造成遇到某一些特别的句子或者词语的时候, 模型对句子的响应过大, 导致最终输出的值偏离实际, 其实我们希望模型更从容淡定一些, 所以我们加入$L2 \ normalization$.
除此之外, 我们预训练的BERT有6个transformer block, 我们在情感分析的时候,只用了3个,因为后面实在是参数太多,容易导致过拟合,所以在第三个transformer block之后,就截出隐藏层进行$pooling$了,后面的transformer block都没有用到。
(3)dropout
$dropout$设为了$0.4$,因为模型参数是在是太多,所以在训练的时候直接让$40\%$的参数失能,防止过拟合。
经过以上方法, 模型训练集和测试机的$AUC$都达到了$0.95$以上, 而且经过实际的测试, 模型也可以基本比较正确的分辨出语句的情感极性.
5.阈值微调
经过模型的推断, 输出的值介于0到1之间, 我们可以认为只要这个值在0.5以上, 就是正样本, 如果在0.5以下, 就是副样本, 其实这是不一定的, 0.5通常不是最佳的分类边界, 所以我写了一个用来寻找最佳阈值的脚本, 在./metrics/\_\_init\_\_.py里面.
这个脚本的方法是从0.01到0.99定义99个阈值, 高于阈值算正样本, 低于算副样本, 然后与测试集计算$f1 \ score$, 之后选出可以使$f1 \ score$最高的阈值, 在训练中, 每一个$epoch$都会运行一次寻找阈值的脚本.
参考文献:
【2】ymcui / Chinese-PreTrained-XLNet:预培训的中文XLNet(中文XLNet预训练模型)
【4】Self-Attention & Transformer - w55100的博客 - CSDN博客
【5】使用google的Bert获得中文的词向量 - u014553172的博客 - CSDN博客
【6】aespresso/a_journey_into_math_of_ml: 汉语自然语言处理视频教程-开源学习资料
Bert实战---情感分类的更多相关文章
- 使用BERT进行情感分类预测及代码实例
文章目录 0. BERT介绍 1. BERT配置 1.1. clone BERT 代码 1.2. 数据处理 1.2.1预训练模型 1.2.2数据集 训练集 测试集 开发集 2. 修改代码 2.1 加入 ...
- 使用bert进行情感分类
2018年google推出了bert模型,这个模型的性能要远超于以前所使用的模型,总的来说就是很牛.但是训练bert模型是异常昂贵的,对于一般人来说并不需要自己单独训练bert,只需要加载预训练模型, ...
- 基于Bert的文本情感分类
详细代码已上传到github: click me Abstract: Sentiment classification is the process of analyzing and reaso ...
- 在Keras中用Bert进行情感分析
之前在BERT实战——基于Keras一文中介绍了两个库 keras_bert 和 bert4keras 但是由于 bert4keras 处于开发阶段,有些函数名称和位置等等发生了变化,那篇文章只用了 ...
- 关于情感分类(Sentiment Classification)的文献整理
最近对NLP中情感分类子方向的研究有些兴趣,在此整理下个人阅读的笔记(持续更新中): 1. Thumbs up? Sentiment classification using machine lear ...
- kaggle——Bag of Words Meets Bags of Popcorn(IMDB电影评论情感分类实践)
kaggle链接:https://www.kaggle.com/c/word2vec-nlp-tutorial/overview 简介:给出 50,000 IMDB movie reviews,进行0 ...
- NLP文本情感分类传统模型+深度学习(demo)
文本情感分类: 文本情感分类(一):传统模型 摘自:http://spaces.ac.cn/index.php/archives/3360/ 测试句子:工信处女干事每月经过下属科室都要亲口交代24口交 ...
- kaggle之电影评论文本情感分类
电影文本情感分类 Github地址 Kaggle地址 这个任务主要是对电影评论文本进行情感分类,主要分为正面评论和负面评论,所以是一个二分类问题,二分类模型我们可以选取一些常见的模型比如贝叶斯.逻辑回 ...
- PaddlePaddle︱开发文档中学习情感分类(CNN、LSTM、双向LSTM)、语义角色标注
PaddlePaddle出教程啦,教程一部分写的很详细,值得学习. 一期涉及新手入门.识别数字.图像分类.词向量.情感分析.语义角色标注.机器翻译.个性化推荐. 二期会有更多的图像内容. 随便,帮国产 ...
随机推荐
- input 控件常用属性
- DQN的三大改进:
Double DQN:https://www.jianshu.com/p/fae51b5fe000 Prioritised Replay:https://www.jianshu.com/p/db14f ...
- 【2019.7.20 NOIP模拟赛 T1】A(A)(暴搜)
打表+暴搜 这道题目,显然是需要打表的,不过打表的方式可以有很多. 我是打了两个表,分别表示每个数字所需的火柴棒根数以及从一个数字到另一个数字,除了需要去除或加入的火柴棒外,至少需要几根火柴棒. 然后 ...
- 【转】ServletContext介绍及用法
1.1. 介绍 ServletContext官方叫servlet上下文.服务器会为每一个工程创建一个对象,这个对象就是ServletContext对象.这个对象全局唯一,而且工程内部的所有servl ...
- Paper | U-Net: Convolutional Networks for Biomedical Image Segmentation
目录 故事背景 U-Net 具体结构 损失 数据扩充 发表在2015 MICCAI.原本是一篇医学图像分割的论文,但由于U-Net杰出的网络设计,得到了8k+的引用. 摘要 There is larg ...
- Nginx+Tomcat+Memcache 实现session共享
Nginx + Tomcat + Memcache 实现session共享 1. Nginx 部署 1.上传源码包到服务器,解压安装 下载地址:http://nginx.org/en/download ...
- Ubuntu 16.04上anaconda安装和使用教程,安装jupyter扩展等 | anaconda tutorial on ubuntu 16.04
本文首发于个人博客https://kezunlin.me/post/23014ca5/,欢迎阅读最新内容! anaconda tutorial on ubuntu 16.04 Guide versio ...
- TensorFlow函数: tf.stop_gradient
停止梯度计算. 在图形中执行时,此操作按原样输出其输入张量. 在构建计算梯度的操作时,这个操作会阻止将其输入的共享考虑在内.通常情况下,梯度生成器将操作添加到图形中,通过递归查找有助于其计算的输入来计 ...
- jvm的组成入门
JVM的组成分为整体组成部分和运行时数据区组成部分. JVM的整体组成 JVM的整体组成可以分为4个部分:类加载器(Classloader).运行时数据区(Runtime Data Area).执行引 ...
- Kubernetes Job与CronJob(离线业务)
Kubernetes Job与CronJob(离线业务) Job Job分为普通任务(Job) 一次性执行 应用场景:离线数据处理,视频解码等业务 官方文档:https://kubernetes.i ...