自然语言处理任务,比如caption generation(图片描述文本生成)、机器翻译中,都需要进行词或者字符序列的生成。常见于seq2seq模型或者RNNLM模型中。

这篇博文主要介绍文本生成解码过程中用的greedy search 和beam search算法实现。其中,greedy search 比较简单,着重介绍beam search算法的实现。

我们在文本生成解码时,实际上是想找对最有的文本序列,或者说是概率,可能性最大的文本序列。而要在全局搜索这个最有解空间,往往是不可能的(因为词典太大),建设生成序列长度为N,词典大小为V, 则复杂度为 V^N次方。这实际上是一个NP难题。退而求其次,我们使用启发式算法,来找到可能的最优解,或者说足够好的解。

假设序列数据(假设每个位置词的概率都已经给出):

data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)

1、greedy search decoder

非常简单,我们用argmax就可以实现

# greedy decoder
def greedy_decoder(data):
# 每一行最大概率词的索引
return [argmax(s) for s in data]

完整代码

from numpy import array
from numpy import argmax # greedy decoder
def greedy_decoder(data):
# 每一行最大概率词的索引
return [argmax(s) for s in data] # 定义一个句子,长度为10,词典大小为5
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 使用greedy search解码
result = greedy_decoder(data)
print(result)

2. beam search

与greedy search不同,beam search返回多个最有可能的解码结果(具体多少个,由参数k执行)。

greedy search每一步都都采用最大概率的词,而beam search每一步都保留k个最有可能的结果,在每一步,基于之前的k个可能最优结果,继续搜索下一步。(参考下面示意图理解)

示例图(设置返回解码结果为2个):

from math import log
from numpy import array
from numpy import argmax # beam search
def beam_search_decoder(data, k):
sequences = [[list(), 1.0]]
for row in data:
all_candidates = list()
for i in range(len(sequences)):
seq, score = sequences[i]
for j in range(len(row)):
candidate = [seq + [j], score * -log(row[j])]
all_candidates.append(candidate)
# 所有候选根据分值排序
ordered = sorted(all_candidates, key=lambda tup:tup[1])
# 选择前k个
sequences = ordered[:k]
return sequences # 定义一个句子,长度为10,词典大小为5
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解码
result = beam_search_decoder(data, 3)
# print result
for seq in result:
print(seq)

相关资料:

实现nlp文本生成中的beam search解码器的更多相关文章

  1. 斯坦福NLP课程 | 第15讲 - NLP文本生成任务

    作者:韩信子@ShowMeAI,路遥@ShowMeAI,奇异果@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/36 本文地址:http://www. ...

  2. Beam Search

    Q: 什么是Beam Search? 它在NLP中的什么场景里会⽤到? 传统的广度优先策略能够找到最优的路径,但是在搜索空间非常大的情况下,内存占用是指数级增长,很容易造成内存溢出,因此提出了beam ...

  3. 【NLP】选择目标序列:贪心搜索和Beam search

    构建seq2seq模型,并训练完成后,我们只要将源句子输入进训练好的模型,执行一次前向传播就能得到目标句子,但是值得注意的是: seq2seq模型的decoder部分实际上相当于一个语言模型,相比于R ...

  4. NLP相关问题中文本数据特征表达初探

    1. NLP问题简介 0x1:NLP问题都包括哪些内涵 人们对真实世界的感知被成为感知世界,而人们用语言表达出自己的感知视为文本数据.那么反过来,NLP,或者更精确地表达为文本挖掘,则是从文本数据出发 ...

  5. 关于 Image Caption 中测试时用到的 beam search算法

    关于beam search 之前组会中没讲清楚的 beam search,这里给一个案例来说明这种搜索算法. 在 Image Caption的测试阶段,为了得到输出的语句,一般会选用两种搜索方式,一种 ...

  6. 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作

    目录 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作 前言 NLP相关的文本预处理 浅谈NLP 文本分类/情感分析 任务中的文本预处理工作 前言 之所以心血来潮想写这篇博客,是因为最近在关注N ...

  7. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  8. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  9. 使用 paddle来进行文本生成

    paddle 简单介绍 paddle 是百度在2016年9月份开源的深度学习框架. 就我最近体验的感受来说的它具有几大优点: 1. 本身内嵌了许多和实际业务非常贴近的模型比如个性化推荐,情感分析,词向 ...

随机推荐

  1. Java基础-IO流对象之打印流(PrintStream与PrintWriter)

    Java基础-IO流对象之打印流(PrintStream与PrintWriter) 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.打印流的特性 打印对象有两个,即字节打印流(P ...

  2. jquery中的$.each跳出循环并获取返回值

    jquery中使用each方法,类似于while或者for循环 一种退出循环的方式是:当数据全部被遍历完成,自然退出, 另一种方法是:当我们在循环中返回一个false时,会跳出循环 这里来使用第二种方 ...

  3. bzoj千题计划169:bzoj2463: [中山市选2009]谁能赢呢?

    http://www.lydsy.com/JudgeOnline/problem.php?id=2463 n为偶数时,一定可以被若干个1*2 矩形覆盖 先手每次从矩形的一端走向另一端,后手每次走向一个 ...

  4. bzoj千题计划119:bzoj1029: [JSOI2007]建筑抢修

    http://www.lydsy.com/JudgeOnline/problem.php?id=1029 把任务按截止时间从小到大排序 如果当前时间+当前任务耗时<=当前任务截止时间,把这个任务 ...

  5. JavaScript 数组元素排序

    var sortArray = new Array(3,6,8888,66); // 元素必须是数字 sortArray.sort(function(a,b){return a-b}); // a-b ...

  6. POJ 2438 Children’s Dining (哈密顿图模板题之巧妙建反图 )

    题目链接 Description Usually children in kindergarten like to quarrel with each other. This situation an ...

  7. bootstrap-table 应用

    更多内容推荐微信公众号,欢迎关注: 前端代码:js初始化表格,使用服务器端分页:<!DOCTYPE html> <html> <head> <meta cha ...

  8. [转]程序进行性能分析工具gprof使用入门

    性能分析工具 软件的性能是软件质量的重要考察点,不论是在线服务程序还是离线程序,甚至是终端应用,性能都是用户体验的关键.这里说的性能重大的范畴来讲包括了性能和稳定性两个方面,我们在做软件测试的时候也是 ...

  9. python模块-platform

    #author:Blood_Zero #coding:utf- import platform print dir(platform) #获取platform函数功能 platform.archite ...

  10. 利用SSLStrip截获https协议--抓取邮箱等密码

    1.SSL解析 SSL 是 Secure Socket Layer 的简称, 中文意思是安全套接字层,由 NetScape公司所开发,用以保障在 Internet 上数据传输的安全,确保数据在网络的传 ...