最近一直在做多标签分类任务,学习了一种层次注意力模型,基本结构如下:

简单说,就是两层attention机制,一层基于词,一层基于句。

首先是词层面:

输入采用word2vec形成基本语料向量后,采用双向GRU抽特征:

一句话中的词对于当前分类的重要性不同,采用attention机制实现如下:

tensorflow代码实现如下:

···

def attention_word_level(self, hidden_state):

    """
input1:self.hidden_state: hidden_state:list,len:sentence_length,element:[batch_size*num_sentences,hidden_size*2]
input2:sentence level context vector:[batch_size*num_sentences,hidden_size*2]
:return:representation.shape:[batch_size*num_sentences,hidden_size*2]
"""
hidden_state_ = tf.stack(hidden_state, axis=1) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
# 0) one layer of feed forward network
hidden_state_2 = tf.reshape(hidden_state_, shape=[-1,
self.hidden_size * 2]) # shape:[batch_size*num_sentences*sequence_length,hidden_size*2]
# hidden_state_:[batch_size*num_sentences*sequence_length,hidden_size*2];W_w_attention_sentence:[,hidden_size*2,,hidden_size*2]
hidden_representation = tf.nn.tanh(tf.matmul(hidden_state_2,
self.W_w_attention_word) + self.W_b_attention_word) # shape:[batch_size*num_sentences*sequence_length,hidden_size*2]
hidden_representation = tf.reshape(hidden_representation, shape=[-1, self.sequence_length,
self.hidden_size * 2]) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
# attention process:1.get logits for each word in the sentence. 2.get possibility distribution for each word in the sentence. 3.get weighted sum for the sentence as sentence representation.
# 1) get logits for each word in the sentence.
hidden_state_context_similiarity = tf.multiply(hidden_representation,
self.context_vecotor_word) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
attention_logits = tf.reduce_sum(hidden_state_context_similiarity,
axis=2) # shape:[batch_size*num_sentences,sequence_length]
# subtract max for numerical stability (softmax is shift invariant). tf.reduce_max:Computes the maximum of elements across dimensions of a tensor.
attention_logits_max = tf.reduce_max(attention_logits, axis=1,
keep_dims=True) # shape:[batch_size*num_sentences,1]
# 2) get possibility distribution for each word in the sentence.
p_attention = tf.nn.softmax(
attention_logits - attention_logits_max) # shape:[batch_size*num_sentences,sequence_length]
# 3) get weighted hidden state by attention vector
p_attention_expanded = tf.expand_dims(p_attention, axis=2) # shape:[batch_size*num_sentences,sequence_length,1]
# below sentence_representation'shape:[batch_size*num_sentences,sequence_length,hidden_size*2]<----p_attention_expanded:[batch_size*num_sentences,sequence_length,1];hidden_state_:[batch_size*num_sentences,sequence_length,hidden_size*2]
sentence_representation = tf.multiply(p_attention_expanded,
hidden_state_) # shape:[batch_size*num_sentences,sequence_length,hidden_size*2]
sentence_representation = tf.reduce_sum(sentence_representation,
axis=1) # shape:[batch_size*num_sentences,hidden_size*2]
return sentence_representation # shape:[batch_size*num_sentences,hidden_size*2]

···

句子层面和词层面基本相同

双向GRU输入,softmax计算attention

最后基于句子层面的输出,计算分类

指数损失

github源代码:https://github.com/zhaowei555/multi_label_classify/tree/master/han

NLP文本多标签分类---HierarchicalAttentionNetwork的更多相关文章

  1. fastText、TextCNN、TextRNN……这里有一套NLP文本分类深度学习方法库供你选择

    https://mp.weixin.qq.com/s/_xILvfEMx3URcB-5C8vfTw 这个库的目的是探索用深度学习进行NLP文本分类的方法. 它具有文本分类的各种基准模型,还支持多标签分 ...

  2. NLP文本分类方法汇总

    模型: FastText TextCNN TextRNN RCNN 分层注意网络(Hierarchical Attention Network) 具有注意的seq2seq模型(seq2seq with ...

  3. NLP文本分类

    引言 其实最近挺纠结的,有一点点焦虑,因为自己一直都期望往自然语言处理的方向发展,梦想成为一名NLP算法工程师,也正是我喜欢的事,而不是为了生存而工作.我觉得这也是我这辈子为数不多的剩下的可以自己去追 ...

  4. 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作

    目录 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作 前言 NLP相关的文本预处理 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作 前言 之所以心血来潮想写这篇博客,是因为最近在关注N ...

  5. LM-MLC 一种基于完型填空的多标签分类算法

    LM-MLC 一种基于完型填空的多标签分类算法 1 前言 本文主要介绍本人在全球人工智能技术创新大赛[赛道一]设计的一种基于完型填空(模板)的多标签分类算法:LM-MLC,该算法拟合能力很强能感知标签 ...

  6. CSS.02 -- 样式表 及标签分类(块、行、行内块元素)、CSS三大特性、背景属性

    样式表书写位置  内嵌式写法 <head> <style type="text/css"> 样式表写法 </style> </head&g ...

  7. html(常用标签,标签分类),页面模板, CSS(css的三种引入方式),三种引入方式优先级

    HTML 标记语言为非编程语言负责完成页面的结构 组成: 标签:被<>包裹的由字母开头,可以结合合法字符( -|数字 ),能被浏览器解析的特殊符号,标签有头有尾 指令:被<>包 ...

  8. 从零开始学 Web 之 CSS(二)文本、标签、特性

    大家好,这里是「 Daotin的梦呓 」从零开始学 Web 系列教程.此文首发于「 Daotin的梦呓 」公众号,欢迎大家订阅关注.在这里我会从 Web 前端零基础开始,一步步学习 Web 相关的知识 ...

  9. Python-HTML 最强标签分类

    编程: 使用(展示)数据 存储数据 处理数据 前端 1. 前端是做什么的? 2. 我们为什么要学前端? 3. 前端都有哪些内容? 1. HTML 2. CSS 3. JavaScript 4.jQue ...

随机推荐

  1. 【译】使用 WebView2 将最好的 Web 带到 .NET 桌面应用程序中

    在去年的 Build 大会上,我们引入了 WebView2,这是一个浏览器控件,可以用新的基于 Chrome 的 Microsoft Edge 来呈现 Web 内容(HTML / CSS / Java ...

  2. mysql-2-where

    #进阶2:条件查询 /* 语法: SELECT 查询列表 FROM 表名 WHERE 筛选条件 分类: 1.按条件表达式筛选:> < = != <> >= <= 2 ...

  3. 【题解】CF1368C Even Picture

    \(\color{purple}{Link}\) \(\text{Solution:}\) 这是一道构造题. 题目要求恰好有\(n\)个点的四周全都是灰色点,所以直接输正方形是不行了. 考虑\(k=1 ...

  4. [C#.NET 拾遗补漏]09:数据标注与数据校验

    数据标注(Data Annotation)是类或类成员添加上下文信息的一种方式,在 C# 通常用特性(Attribute)类来描述.它的用途主要可以分为下面这三类: 验证 Validation:向数据 ...

  5. 在java中使用SFTP协议安全的传输文件

    本文介绍在Java中如何使用基于SSH的文件传输协议(SFTP)将文件从本地上传到远程服务器,或者将文件在两个服务器之间安全的传输.我们先来了解一下这几个协议 SSH 是较可靠,专为远程登录会话和其他 ...

  6. 【LGR-070】洛谷 3 月月赛-官方题解

    本次免费为大家提供[LGR-070]洛谷 3 月月赛的官方题解,点个赞再走呗! 代码就不上了,大家可以到别的博客上去找找!希望这篇博客能对你有所帮助!

  7. mysql物理优化器代价模型分析【原创】

    1. 引言 mysql的sql server在根据where condition检索数据的时候,一般会有多种数据检索的方法,其会根据各种数据检索方法代价的大小,选择代价最小的那个数据检索方法. 比如说 ...

  8. oracle统计同一字段0和1

    SELECT 班级表.班级编号,班级表.班级名称,SUM(DECODE(性别, '1', 1)) 女生人数,SUM(DECODE(性别, '0', 1)) 男生人数FROM 学生表, 班级表WHERE ...

  9. jquery $.ajax 获取josn数据

    <script type="text/javascript" src="jquery-1.9.1.js"></script> <s ...

  10. git -- Authentication failed for 报错如何解决?

    昨天拉代码拉不下来,报这个错误:fatal: Authentication failed for .... 有很多网上的解释是 $  git config --global --replace-all ...