decoder.py

"""
实现解码器
"""
import heapq import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
embedding_dim=config.chatbot_decoder_embedding_dim,
padding_idx=config.target_ws.PAD) #需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
hidden_size=config.chatbot_decoder_hidden_size,
num_layers=config.chatbot_decoder_number_layer,
bidirectional=False,
batch_first=True,
dropout=config.chatbot_decoder_dropout) #假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(config.chatbot_decoder_hidden_size,len(config.target_ws))
self.attn = Attention(method="general")
self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False) def forward(self, encoder_hidden,target,encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [1,batch_size,128*2]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) #out:[batch_size,1,128*2]
#decoder_hidden :[1,bathc_size,128*2]
# print(decoder_hidden.size())
out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden) ##### 开始attention ############
### 1. 计算attention weight
attn_weight = self.attn(decoder_hidden,encoder_outputs) #[batch_size,1,encoder_max_len]
### 2. 计算context vector
#encoder_ouputs :[batch_size,encoder_max_len,128*2]
context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
### 3. 计算 attention的结果
#[batch_size,128*2] #context_vector:[batch_size,128*2] --> 128*4
#attention_result = [batch_size,128*4] --->[batch_size,128*2]
attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
# attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
#### attenion 结束 # print("decoder_hidden size:",decoder_hidden.size())
#out :【batch_size,1,hidden_size】 # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc,decoder_hidden def evaluate(self,encoder_hidden,encoder_outputs): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
config.device) predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
predict_result.append(index.cpu().detach().numpy()) #[[batch],[batch]...]
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
# predict_result.append(decoder_input)
#把结果转化为ndarray,每一行是一条预测结果
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
"""
使用beam search完成评估,只能输入一个句子,得到一个输出
:param encoder_hidden:
:param encoder_outputs:
:return:
"""
# 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
assert batch_size == 1, "beam search的过程中,batch_size只能为1"
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1] prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden) while True:
cur_beam = Beam()
for prob, complete, seq_list, decoder_input, decoder_hidden in prev_beam:
if complete: # 有可能前一次已经到达eos了,但是概率不是最大的
cur_beam.add(prob, complete, seq_list, decoder_input, decoder_hidden)
else:
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs) value, index = torch.topk(decoder_output_t, config.beam_width)
# print("value index size:",value[0].size(),index[0].size())
for m, n in zip(value[0], index[0]):
# print("m,n size:",m.size(),n.size(),m,n)
cur_prob = prob * m.item()
decoder_input = torch.LongTensor([[n.item()]]).to(config.device)
cur_seq_list = seq_list + [decoder_input]
if n == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden) best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete or len(best_seq) - 1 == config.chatbot_target_max_len: best_seq = [i.item() for i in best_seq]
if best_seq[0] == config.target_ws.SOS:
best_seq = best_seq[1:]
if best_seq[-1] == config.target_ws.EOS:
best_seq = best_seq[:-1]
return best_seq else:
prev_beam = cur_beam class Beam:
"""保存每一个时间步的数据""" def __init__(self):
self.heapq = list()
self.beam_width = config.beam_width def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
# 保证最终只有一个beam width个结果
if len(self.heapq) > self.beam_width:
heapq.heappop(self.heapq) def __iter__(self):
for item in self.heapq:
yield item

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result def evaluate_with_beam_search(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
best_seq = self.decoder.evaluate_with_beam_search(encoder_hidden, encoder_outputs)
return best_seq

  eval.py

"""
进行模型的评估
""" import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq def eval():
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) loss_list = []
data_loader = get_dataloader(train=False)
bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
with torch.no_grad():
for idx,(input,target,input_len,target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device) decoder_outputs,predict_result = model.evaluate(input,input_len) #[batch_Size,max_len,vocab_size]
loss = F.nll_loss(decoder_outputs.view(-1,len(config.target_ws)),target.view(-1),ignore_index=config.input_ws.PAD)
loss_list.append(loss.item())
bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))
print("当前的平均损失为:",np.mean(loss_list)) def interface():
from chatbot.cut_sentence import cut
import config
#加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) #准备待预测的数据
while True:
origin_input =input("me>>:")
# if "你是谁" in origin_input or "你叫什么" in origin_input:
# result = "我是小智。"
# elif "你好" in origin_input or "hello" in origin_input:
# result = "Hello"
# else:
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input,max_len=config.chatbot_input_max_len)]).to(config.device) outputs,predict = model.evaluate(_input,input_len)
result = config.target_ws.inverse_transform(predict[0])
print("chatbot>>:",result) def interface_with_beamsearch():
from chatbot.cut_sentence import cut
import config
# 加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) # 准备待预测的数据
while True:
origin_input = input("me>>:")
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
config.device) best_seq = model.evaluate_with_beam_search(_input, input_len)
result = config.target_ws.inverse_transform(best_seq)
print("chatbot>>:", result) if __name__ == '__main__':
# interface()
interface_with_beamsearch()

  

pytorch seq2seq闲聊机器人beam search返回结果的更多相关文章

  1. pytorch seq2seq闲聊机器人

    cut_sentence.py """ 实现句子的分词 注意点: 1. 实现单个字分词 2. 实现按照词语分词 2.1 加载词典 3. 使用停用词 "" ...

  2. pytorch seq2seq闲聊机器人加入attention机制

    attention.py """ 实现attention """ import torch import torch.nn as nn im ...

  3. 实现nlp文本生成中的beam search解码器

    自然语言处理任务,比如caption generation(图片描述文本生成).机器翻译中,都需要进行词或者字符序列的生成.常见于seq2seq模型或者RNNLM模型中. 这篇博文主要介绍文本生成解码 ...

  4. Beam Search快速理解及代码解析(下)

    Beam Search的问题 先解释一下什么要对Beam Search进行改进.因为Beam Search虽然比贪心强了不少,但还是会生成出空洞.重复.前后矛盾的文本.如果你有文本生成经验,一定对这些 ...

  5. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  6. 【NLP】选择目标序列:贪心搜索和Beam search

    构建seq2seq模型,并训练完成后,我们只要将源句子输入进训练好的模型,执行一次前向传播就能得到目标句子,但是值得注意的是: seq2seq模型的decoder部分实际上相当于一个语言模型,相比于R ...

  7. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  8. Beam Search(集束搜索/束搜索)

    找遍百度也没有找到关于Beam Search的详细解释,只有一些比较泛泛的讲解,于是有了这篇博文. 首先给出wiki地址:http://en.wikipedia.org/wiki/Beam_searc ...

  9. 关于Beam Search

    Wiki定义:In computer science, beam search is a heuristic search algorithm that explores a graph by exp ...

随机推荐

  1. 基于STM32F030F4P9和STM32 CUBEMX 输出PWM波形

    STM32F030F4P9定时器功能比较丰富,在此记录项目中使用其自动输出PWM波形(频率:50HZ).CubeMX配置定时器如下图说明. 在此定时器基础时钟为48MHZ,配置中不做分频处理,预分频系 ...

  2. 从火车站车次公示栏来学Java读写锁

    Java多线程并发之读写锁 本文主要内容:读写锁的理论:通过生活中例子来理解读写锁:读写锁的代码演示:读写锁总结.通过理论(总结)-例子-代码-然后再次总结,这四个步骤来让大家对读写锁的深刻理解. 本 ...

  3. jQuery学习笔记01

    1.jQuery介绍 1.1什么是jQuery ? jQuery,顾名思义,也就是JavaScript和查询(Query),它就是辅助JavaScript开发的js类库. 1.2 jQuery核心思想 ...

  4. H - 蓬松的头发 HDU - 5504

    给你一个N个整数的序列. 你应该选择一些数字(至少一个),并使它们的乘积尽可能大. 它保证你在初始序列中选择的任何数的乘积的绝对值不会大于263−1. Input 在第一行有一个数字T(表示样例数). ...

  5. C#通用类库整理--字符串处理类

    在程序开发中通常需要将字符串转为自己想要的结果,以下三个类库主要实现: 1.GetStrArray(string str, char speater, bool toLower)  把字符串按照分隔符 ...

  6. Socket探索1-两种Socket服务端实现

    介绍 一次简单的Socket探索之旅,分别对Socket服务端的两种方式进行了测试和解析. CommonSocket 代码实现 实现一个简单的Socket服务,基本功能就是接收消息然后加上结束消息时间 ...

  7. MES Auto Logout

    如果您在车间使用MES,可能存在这种情况有人在仍然登录的情况下偶尔离开终端.如果一段时间不使用终端,我们是否可以让用户自动注销. 1 首先,我们有一条using语句: using System.Run ...

  8. 【Mongodb】聚合查询 && 固定集合

    概述 数据存储是为了可查询,统计.若数据只需存储,不需要查询,这种数据也没有多大价值 本篇介绍Mongodb 聚合查询(Aggregation) 固定集合(Capped Collections) 准备 ...

  9. go 锁和sync包

    一.什么是锁? sync.Mutex 是一个互斥锁,它的作用是守护在临界区入口来确保同一时间只能有一个线程进入临界区 在 sync 包中还有一个 RWMutex 锁:他能通过 RLock() 来允许同 ...

  10. c++动态数组的优点,创建和删除

    动态数组可以有两种使用方式: 1:不能预先知道数组的大小使用动态数组 传统数组(静态数组)是需要在程序运行前,就指定大小,比如说 int i = 10; int a[i]; 这种就是不合法的. 因为函 ...