BiLSTM-CRF学习笔记(原理和理解) 维特比
BiLSTM-CRF 被提出用于NER或者词性标注,效果比单纯的CRF或者lstm或者bilstm效果都要好。
根据pytorch官方指南(https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html#bi-lstm-conditional-random-field-discussion),实现了BiLSTM-CRF一个toy级别的源码。下面是我个人的学习理解过程。
1. LSTM
LSTM的原理前人已经解释的非常清楚了:https://zhuanlan.zhihu.com/p/32085405
BiLSTM-CRF中,BiLSTM部分主要用于,根据一个单词的上下文,给出当前单词对应标签的概率分布,可以把BiLSTM看成一个编码层。
比如,对于标签集{N, V, O}和单词China,BiLSTM可能输出形如(0.88,-1.23,0.03)的非归一化概率分布。
这个分布我们看作是crf的特征分布输入,那么在CRF中我们需要学习的就是特征转移概率。
2. CRF
主要讲一下代码中要用到的CRF的预测(维特比解码)
维特比算法流程:
1.求出位置1的各个标记的非规范化概率δ1(j)δ1(j)
2.由递推公式(前后向概率计算)
每一步都保留当前所有可能的状态ll 对应的最大的非规范化概率,
并将最大非规范化概率状态对应的路径(当前状态得到最大概率时上一步的状态yiyi)记录
Ψi(l)=argmax(1≤j≤m){δi−1(j)+w∗Fi(yi−1=j,yi=l,x)}=argmaxδi(l),l=1,2,…,mΨi(l)=argmax(1≤j≤m){δi−1(j)+w∗Fi(yi−1=j,yi=l,x)}=argmaxδi(l),l=1,2,…,m
就是PijPij的取值有m*m个,对每一个yjyj,都确定一个(而不是可能的m个)能最大化概率的yiyi状态
3.递推到i=ni=n时终止
这时候求得非规范化概率的最大值为
最优路径终点
4.递归路径
由最优路径终点递归得到的最优路径(由当前最大概率状态状态对应的上一步状态,然后递归)
求得最优路径:
3. 损失函数
最后由CRF输出,损失函数的形式主要由CRF给出
在BiLSTM-CRF中,给定输入序列X,网络输出对应的标注序列y,得分为
(转移概率和状态概率之和)
利用softmax函数,我们为每一个正确的tag序列y定义一个概率值
在训练中,我们的目标就是最大化概率p(y│X) ,怎么最大化呢,用对数似然(因为p(y│X)中存在指数和除法,对数似然可以化简这些运算)
对数似然形式如下:
最大化这个对数似然,就是最小化他的相反数:
¥−log(p(y│X))=log(∑y′∈YXes(X,y′))−S(X,y)−log(p(y│X))=log(∑y′∈YXes(X,y′))−S(X,y)$
(loss function/object function)
最小化可以借助梯度下降实现
在对损失函数进行计算的时候,前一项S(X,y)S(X,y)很容易计算,
后一项log(∑y′∈YXes(X,y′))log(∑y′∈YXes(X,y′))比较复杂,计算过程中由于指数较大常常会出现上溢或者下溢,
由公式 log∑e(xi)=a+log∑e(xi−a)log∑e(xi)=a+log∑e(xi−a),可以借助a对指数进行放缩,通常a取xixi的最大值(即a=max[Xi]a=max[Xi]),这可以保证指数最大不会超过0,于是你就不会上溢出。即便剩余的部分下溢出了,你也能得到一个合理的值。
又因为log(∑yelog(∑xex)+y)log(∑yelog(∑xex)+y),在loglog取ee作为底数的情况下,可以化简为
log(∑yey∗elog(∑xex))=log(∑yey∗∑xex)=log(∑y∑xex+y)log(∑yey∗elog(∑xex))=log(∑yey∗∑xex)=log(∑y∑xex+y)。
log_sum_exp因为需要计算所有路径,那么在计算过程中,计算每一步路径得分之和和直接计算全局得分是等价的,就可以大大减少计算时间。
当前的分数可以由上一步的总得分+转移得分+状态得分得到,这也是pytorch范例中
next_tag_var = forward_var + trans_score + emit_score
的由来
注意,由于程序中比较好选一整行而不是一整列,所以调换i,j的含义,t[i][j]表示从j状态转移到i状态的转移概率
直接分析源码的前向传播部分,其中_get_lstm_features函数调用了pytorch的BiLSTM
def forward(self, sentence):
"""
重写前向传播
:param sentence: 输入的句子序列
:return:返回分数和标记序列
"""
lstm_feats = self._get_lstm_features(sentence)
score, tag_seq = self._viterbi_decode(lstm_feats)
return score, tag_seq
源码的维特比算法实现,在训练结束,还要使用该算法进行预测
def _viterbi_decode(self, feats):
"""
使用维特比算法预测
:param feats:lstm的所有输出
:return:返回最大概率和最优路径
"""
backpointers = []
# step1. 初始化
init_vvars = torch.full((1, self.tagset_size), -1000.)
# 初始化第一步的转移概率
init_vvars[0][self.tag_to_idx[START_TAG]] = 0
# 初始化每一步的非规范化概率
forward_var = init_vvars
# step2. 递推
# 遍历每一个单词通过bilstm输出的概率分布
for feat in feats:
# 每次循环重新统计
bptrs_t = []
viterbivars_t = []
for next_tag in range(self.tagset_size):
# 根据维特比算法
# 下一个tag_i+1的非归一化概率是上一步概率加转移概率(势函数和势函数的权重都统一看成转移概率的一部分)
next_tag_var = forward_var + self.transitions[next_tag]
# next_tag_var = tensor([[-3.8879e-01, 1.5657e+00, 1.7734e+00, -9.9964e+03, -9.9990e+03]])
# 计算所有前向概率(?)
# CRF是单步线性链马尔可夫,所以每个状态只和他上1个状态有关,可以用二维的概率转移矩阵表示
# 保存当前最大状态
best_tag_id = argmax(next_tag_var)
# best_tag_id = torch.argmax(next_tag_var).item()
bptrs_t.append(best_tag_id)
# 从一个1*N向量中取出一个值(标量),将这个标量再转换成一维向量
viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
# viterbivars 长度为self.tagset_size,对应feat的维度
forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
# 记录每一个时间i,每个状态取值l取最大非规范化概率对应的上一步状态
backpointers.append(bptrs_t)
# step3. 终止
terminal_var = forward_var + self.transitions[self.tag_to_idx[STOP_TAG]]
best_tag_id = argmax(terminal_var)
path_score = terminal_var[0][best_tag_id]
# step4. 返回路径
best_path = [best_tag_id]
for bptrs_t in reversed(backpointers):
best_tag_id = bptrs_t[best_tag_id]
best_path.append(best_tag_id)
# Pop off the start tag (we dont want to return that to the caller)
start = best_path.pop()
assert start == self.tag_to_idx[START_TAG] # Sanity check
best_path.reverse()
return path_score, best_path
源码的损失函数计算
def neg_log_likelihood(self, sentence, tags):
"""
实现负对数似然函数
:param sentence:
:param tags:
:return:
"""
# 返回句子中每个单词对应的标签概率分布
feats = self._get_lstm_features(sentence)
forward_score = self._forward_alg(feats)
gold_score = self._score_sentence(feats, tags) # 输出路径的得分(S(X,y))
# 返回负对数似然函数的结果
return forward_score - gold_score
def _forward_alg(self, feats):
"""
使用前向算法计算损失函数的第一项log(\sum(exp(S(X,y’))))
:param feats: 从BiLSTM输出的特征
:return: 返回
"""
init_alphas = torch.full((1, self.tagset_size), -10000.)
init_alphas[0][self.tag_to_idx[START_TAG]] = 0.
forward_var = init_alphas
for feat in feats:
# 存放t时刻的 概率状态
alphas_t = []
for current_tag in range(self.tagset_size):
# lstm输出的是非归一化分布概率
emit_score = feat[current_tag].view(1, -1).expand(1, self.tagset_size)
# self.transitions[current_tag] 就是从上一时刻所有状态转移到当前某状态的非归一化转移概率
# 取出的转移矩阵的行是一维的,这里调用view函数转换成二维矩阵
trans_score = self.transitions[current_tag].view(1, -1)
# trans_score + emit_score 等于所有特征函数之和
# forward 是截至上一步的得分
current_tag_var = forward_var + trans_score + emit_score
alphas_t.append(log_sum_exp(current_tag_var).view(1))
forward_var = torch.cat(alphas_t).view(1, -1) # 调用view函数转换成1*N向量
terminal_var = forward_var + self.transitions[self.tag_to_idx[STOP_TAG]]
alpha = log_sum_exp(terminal_var)
return alpha
def _score_sentence(self, feats, tags):
"""
返回S(X,y)
:param feats: 从BiLSTM输出的特征
:param tags: CRF输出的标记路径
:return:
"""
score = torch.zeros(1)
tags = torch.cat([torch.tensor([self.tag_to_idx[START_TAG]], dtype=torch.long),tags])
for i, feat in enumerate(feats):
score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
score = score + self.transitions[self.tag_to_idx[STOP_TAG],tags[-1]]
return score
BiLSTM-CRF学习笔记(原理和理解) 维特比的更多相关文章
- Java四种引用--《深入理解Java虚拟机》学习笔记及个人理解(四)
Java四种引用--<深入理解Java虚拟机>学习笔记及个人理解(四) 书上P65. StrongReference(强引用) 类似Object obj = new Object() 这类 ...
- Java虚拟机内存溢出异常--《深入理解Java虚拟机》学习笔记及个人理解(三)
Java虚拟机内存溢出异常--<深入理解Java虚拟机>学习笔记及个人理解(三) 书上P39 1. 堆内存溢出 不断地创建对象, 而且保证创建的这些对象不会被回收即可(让GC Root可达 ...
- 【学习笔记】深入理解js原型和闭包(0)——目录
文章转载:https://www.cnblogs.com/wangfupeng1988/p/4001284.html 说明: 本篇文章一共16篇章,外加两篇后补的和一篇自己后来添加的学习笔记,一共19 ...
- 【疯狂Java学习笔记】【理解面向对象】
[学习笔记]1.Java语言是纯粹的面向对象语言,这体现在Java完全支持面向对象的三大基本特征:封装.继承.多态.抽象也是面向对象的重要组成部分,不过它不是面向对象的特征之一,因为所有的编程语言都需 ...
- CSS学习笔记09 简单理解BFC
引子 在讲BFC之前,先来看看一个例子 <!DOCTYPE html> <html lang="en"> <head> <meta cha ...
- elasticsearch学习笔记--原理介绍
前言:上一篇中我们对ES有了一个比较大概的概念,知道它是什么,干什么用的,今天给大家主要讲一下他的工作原理 介绍:ElasticSearch是一个基于Lucene的搜索服务器.它提供了一个分布式多用户 ...
- Java虚拟机运行时栈帧结构--《深入理解Java虚拟机》学习笔记及个人理解(二)
Java虚拟机运行时栈帧结构(周志明书上P237页) 栈帧是什么? 栈帧是一种数据结构,用于虚拟机进行方法的调用和执行. 栈帧是虚拟机栈的栈元素,也就是入栈和出栈的一个单元. 2018.1.2更新(在 ...
- 深入剖析Kubernetes学习笔记:深入理解镜像(08)
一.Python 应用案例环境 [root@k8s-node1 Flask]# pwd /opt/Dockerfile/Flask [root@k8s-node1 Flask]# ll total 1 ...
- Vxlan学习笔记——原理
1. 为什么需要Vxlan 普通的VLAN数量只有4096个,无法满足大规模云计算IDC的需求,而IDC为何需求那么多VLAN呢,因为目前大部分IDC内部结构主要分为两种L2,L3.L2结构里面,所有 ...
- 【学习笔记】深入理解HTTP协议
参考:关于HTTP协议,一篇就够了,感谢作者认真细致的总结,本文在理解的基础上修改了内容,加深印象的同时也希望对大家有所帮助 HTTP是什么? HTTP协议是Hyper Text Transfer P ...
随机推荐
- Win7系统中wmiprvse.exe占用CPU高如何解决
该进程的详细路径是在:C:\WINDOWS\System32\Wbem 我们可以在任务管理器中“wmiprvse.exe”进程上单击右键,选择“打开文件位置”即可看到,如果该文件不在该文件夹中,那么 ...
- python中bisect模块的使用
一般用于二分查找, 当然列表应该是有序表 参考于: http://blog.csdn.net/xiaocaiju/article/details/6975714
- 【JZOJ3297】【SDOI2013】逃考(escape)
Mission 高考又来了,对于不认真读书的来讲真不是个好消息.为了小杨能在家里认真读书,他的亲戚决定驻扎在他的家里监督他学习,有爷爷奶奶.外公外婆.大舅.大嫂.阿姨-- 小杨实在是忍无可忍了,这种生 ...
- zt 比较各JAX-RS实现:CXF,Jersey,RESTEasy,Restlet
http://news.misuland.com/20080926/1222396399411.html JavaSE/EE执行委员批准了JSR 311 JAX-RS作为支持RESTful web服务 ...
- phpcms推送文章同时推送自定义字段
首先进入phpcms后台,模型管理-字段管理里,新建字段,新建字段必须是主表字段,如图所示 2 来到网站根目录,寻找phpcms\modules\content\classes\push_api.cl ...
- 技巧专题3(cdq分治、整体二分等)
cdq分治与整体二分 cdq来源于2008年国家集训队作业陈丹琦(雅礼巨佬),用一个log的代价完成从静态到动态(很多时候是减少时间那一维的). 对于一个时间段[L, R],我们取mid = (L + ...
- ThinkPHP中_after_update、_before_update等的用法
https://blog.csdn.net/aslackers/article/details/50339163 TP系统\Think\Model类里隐藏了几个有用的方法: _before_inser ...
- python中字母的大小写转换
1. capitalize(): 首字母大写,其余全部小写 2. upper() :全转换成大写 3. lower(): 全转换成小写 4. title() :标题首字大写,如 &q ...
- IDEA(JAVA)使用json
首先介绍一下json SON是一种取代XML的数据结构,和xml相比,它更小巧但描述能力却不差,由于它的小巧所以网络传输数据将减少更多流量从而加快速度. JSON就是一串字符串 只不过元素会使用特定的 ...
- SPSS和Mplus如何做非线性中介调节效应分析?如倒U形曲线
SPSS和Mplus如何做非线性中介调节效应分析?如倒U形曲线 传统的线性回归模型用的比较多,但有时候变量之间的关系更符合非线性关系,此时使用非线性模型其拟合度会更好,模型预测效果更佳.在非线性关系中 ...