decoder.py

"""
实现解码器
"""
import heapq import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
embedding_dim=config.chatbot_decoder_embedding_dim,
padding_idx=config.target_ws.PAD) #需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
hidden_size=config.chatbot_decoder_hidden_size,
num_layers=config.chatbot_decoder_number_layer,
bidirectional=False,
batch_first=True,
dropout=config.chatbot_decoder_dropout) #假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(config.chatbot_decoder_hidden_size,len(config.target_ws))
self.attn = Attention(method="general")
self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False) def forward(self, encoder_hidden,target,encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [1,batch_size,128*2]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) #out:[batch_size,1,128*2]
#decoder_hidden :[1,bathc_size,128*2]
# print(decoder_hidden.size())
out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden) ##### 开始attention ############
### 1. 计算attention weight
attn_weight = self.attn(decoder_hidden,encoder_outputs) #[batch_size,1,encoder_max_len]
### 2. 计算context vector
#encoder_ouputs :[batch_size,encoder_max_len,128*2]
context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
### 3. 计算 attention的结果
#[batch_size,128*2] #context_vector:[batch_size,128*2] --> 128*4
#attention_result = [batch_size,128*4] --->[batch_size,128*2]
attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
# attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
#### attenion 结束 # print("decoder_hidden size:",decoder_hidden.size())
#out :【batch_size,1,hidden_size】 # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc,decoder_hidden def evaluate(self,encoder_hidden,encoder_outputs): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
config.device) predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
predict_result.append(index.cpu().detach().numpy()) #[[batch],[batch]...]
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
# predict_result.append(decoder_input)
#把结果转化为ndarray,每一行是一条预测结果
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
"""
使用beam search完成评估,只能输入一个句子,得到一个输出
:param encoder_hidden:
:param encoder_outputs:
:return:
"""
# 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
assert batch_size == 1, "beam search的过程中,batch_size只能为1"
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1] prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden) while True:
cur_beam = Beam()
for prob, complete, seq_list, decoder_input, decoder_hidden in prev_beam:
if complete: # 有可能前一次已经到达eos了,但是概率不是最大的
cur_beam.add(prob, complete, seq_list, decoder_input, decoder_hidden)
else:
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs) value, index = torch.topk(decoder_output_t, config.beam_width)
# print("value index size:",value[0].size(),index[0].size())
for m, n in zip(value[0], index[0]):
# print("m,n size:",m.size(),n.size(),m,n)
cur_prob = prob * m.item()
decoder_input = torch.LongTensor([[n.item()]]).to(config.device)
cur_seq_list = seq_list + [decoder_input]
if n == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden) best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete or len(best_seq) - 1 == config.chatbot_target_max_len: best_seq = [i.item() for i in best_seq]
if best_seq[0] == config.target_ws.SOS:
best_seq = best_seq[1:]
if best_seq[-1] == config.target_ws.EOS:
best_seq = best_seq[:-1]
return best_seq else:
prev_beam = cur_beam class Beam:
"""保存每一个时间步的数据""" def __init__(self):
self.heapq = list()
self.beam_width = config.beam_width def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
# 保证最终只有一个beam width个结果
if len(self.heapq) > self.beam_width:
heapq.heappop(self.heapq) def __iter__(self):
for item in self.heapq:
yield item

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result def evaluate_with_beam_search(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
best_seq = self.decoder.evaluate_with_beam_search(encoder_hidden, encoder_outputs)
return best_seq

  eval.py

"""
进行模型的评估
""" import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq def eval():
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) loss_list = []
data_loader = get_dataloader(train=False)
bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
with torch.no_grad():
for idx,(input,target,input_len,target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device) decoder_outputs,predict_result = model.evaluate(input,input_len) #[batch_Size,max_len,vocab_size]
loss = F.nll_loss(decoder_outputs.view(-1,len(config.target_ws)),target.view(-1),ignore_index=config.input_ws.PAD)
loss_list.append(loss.item())
bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))
print("当前的平均损失为:",np.mean(loss_list)) def interface():
from chatbot.cut_sentence import cut
import config
#加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) #准备待预测的数据
while True:
origin_input =input("me>>:")
# if "你是谁" in origin_input or "你叫什么" in origin_input:
# result = "我是小智。"
# elif "你好" in origin_input or "hello" in origin_input:
# result = "Hello"
# else:
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input,max_len=config.chatbot_input_max_len)]).to(config.device) outputs,predict = model.evaluate(_input,input_len)
result = config.target_ws.inverse_transform(predict[0])
print("chatbot>>:",result) def interface_with_beamsearch():
from chatbot.cut_sentence import cut
import config
# 加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) # 准备待预测的数据
while True:
origin_input = input("me>>:")
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
config.device) best_seq = model.evaluate_with_beam_search(_input, input_len)
result = config.target_ws.inverse_transform(best_seq)
print("chatbot>>:", result) if __name__ == '__main__':
# interface()
interface_with_beamsearch()

  

pytorch seq2seq闲聊机器人beam search返回结果的更多相关文章

  1. pytorch seq2seq闲聊机器人

    cut_sentence.py """ 实现句子的分词 注意点: 1. 实现单个字分词 2. 实现按照词语分词 2.1 加载词典 3. 使用停用词 "" ...

  2. pytorch seq2seq闲聊机器人加入attention机制

    attention.py """ 实现attention """ import torch import torch.nn as nn im ...

  3. 实现nlp文本生成中的beam search解码器

    自然语言处理任务,比如caption generation(图片描述文本生成).机器翻译中,都需要进行词或者字符序列的生成.常见于seq2seq模型或者RNNLM模型中. 这篇博文主要介绍文本生成解码 ...

  4. Beam Search快速理解及代码解析(下)

    Beam Search的问题 先解释一下什么要对Beam Search进行改进.因为Beam Search虽然比贪心强了不少,但还是会生成出空洞.重复.前后矛盾的文本.如果你有文本生成经验,一定对这些 ...

  5. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  6. 【NLP】选择目标序列:贪心搜索和Beam search

    构建seq2seq模型,并训练完成后,我们只要将源句子输入进训练好的模型,执行一次前向传播就能得到目标句子,但是值得注意的是: seq2seq模型的decoder部分实际上相当于一个语言模型,相比于R ...

  7. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  8. Beam Search(集束搜索/束搜索)

    找遍百度也没有找到关于Beam Search的详细解释,只有一些比较泛泛的讲解,于是有了这篇博文. 首先给出wiki地址:http://en.wikipedia.org/wiki/Beam_searc ...

  9. 关于Beam Search

    Wiki定义:In computer science, beam search is a heuristic search algorithm that explores a graph by exp ...

随机推荐

  1. POJ - 3255 SPFA+邻接表求次短路径

    题意:给出m条边 , n个顶点,u [ i ]到v [ i ] 的距离w [ i ],求除了最短路的那条最短的边的长度. 思路:之前有做过相似的题,使用迪杰斯特拉算法求单源最短路径,并且记录路径,枚举 ...

  2. Web Scraper 性能测试 (-_-)

    刚在研究 Python 爬虫的时候,看到了个小白工具,叫 Web Scraper,于是来测试下好不好用. Web Scraper 是什么? 它是一个谷歌浏览器的插件, 用于批量抓取网页信息, 主要特点 ...

  3. 【强烈推荐】适合Flutter初学者的完整项目

    简介 Flutter Fly是什么?Flutter Fly是一款开源的Flutter 项目,非常适合初学者进行学习.App内集成了160+Flutter基础控件的详细介绍及用法,内容来源于:http: ...

  4. Python python对象 deque

    # deque对象 ''' class collections.deque([ iterable [,maxlen ] ] ) 返回一个从左到右(使用append())初始化的新deque对象,其中包 ...

  5. ArrayList中的Iterator详解

    每个实现Iterable接口的类必须提供一个iterator方法,返回一个Iterator对象,ArrayList也不例外 public Iterator<E> iterator() { ...

  6. 《JAVA与模式》之责任链模式 【转载】

    转载自java_my_life的博客 原文地址:http://www.cnblogs.com/java-my-life/archive/2012/05/28/2516865.html 在阎宏博士的&l ...

  7. Unity 游戏框架搭建 2019 (二十五) 类的第一个作用 与 Obselete 属性

    在上一篇我们整理到了第七个示例,我们今天再接着往下整理.我们来看第八个示例: #if UNITY_EDITOR using UnityEditor; #endif using UnityEngine; ...

  8. 新安装的eclipse配置好了环境变量后,打开还是出现A Java runtime environment错误

    新安装的eclipse配置好了环境变量后,打开还是出现如下图的A Java runtime environment错误; 解决方法: 第一步: Windows环境下:把C:\Users\你的用户名 目 ...

  9. 安装一个KVM服务器

                                                              安装一个KVM服务器     案例1:安装一个KVM服务器 案例2:KVM平台构建及 ...

  10. Linux基础:Day02

    Linux文件管理 创建 touch vim/vi echo重定向 touch 管理:atime mtime ctime touch 文件名  //如果文件不存在就创建文件 touch -a -t [ ...