pytorch seq2seq闲聊机器人beam search返回结果
decoder.py
"""
实现解码器
"""
import heapq import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
embedding_dim=config.chatbot_decoder_embedding_dim,
padding_idx=config.target_ws.PAD) #需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
hidden_size=config.chatbot_decoder_hidden_size,
num_layers=config.chatbot_decoder_number_layer,
bidirectional=False,
batch_first=True,
dropout=config.chatbot_decoder_dropout) #假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(config.chatbot_decoder_hidden_size,len(config.target_ws))
self.attn = Attention(method="general")
self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False) def forward(self, encoder_hidden,target,encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [1,batch_size,128*2]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) #out:[batch_size,1,128*2]
#decoder_hidden :[1,bathc_size,128*2]
# print(decoder_hidden.size())
out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden) ##### 开始attention ############
### 1. 计算attention weight
attn_weight = self.attn(decoder_hidden,encoder_outputs) #[batch_size,1,encoder_max_len]
### 2. 计算context vector
#encoder_ouputs :[batch_size,encoder_max_len,128*2]
context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
### 3. 计算 attention的结果
#[batch_size,128*2] #context_vector:[batch_size,128*2] --> 128*4
#attention_result = [batch_size,128*4] --->[batch_size,128*2]
attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
# attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
#### attenion 结束 # print("decoder_hidden size:",decoder_hidden.size())
#out :【batch_size,1,hidden_size】 # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc,decoder_hidden def evaluate(self,encoder_hidden,encoder_outputs): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
config.device) predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
predict_result.append(index.cpu().detach().numpy()) #[[batch],[batch]...]
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
# predict_result.append(decoder_input)
#把结果转化为ndarray,每一行是一条预测结果
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
"""
使用beam search完成评估,只能输入一个句子,得到一个输出
:param encoder_hidden:
:param encoder_outputs:
:return:
"""
# 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
assert batch_size == 1, "beam search的过程中,batch_size只能为1"
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1] prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden) while True:
cur_beam = Beam()
for prob, complete, seq_list, decoder_input, decoder_hidden in prev_beam:
if complete: # 有可能前一次已经到达eos了,但是概率不是最大的
cur_beam.add(prob, complete, seq_list, decoder_input, decoder_hidden)
else:
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs) value, index = torch.topk(decoder_output_t, config.beam_width)
# print("value index size:",value[0].size(),index[0].size())
for m, n in zip(value[0], index[0]):
# print("m,n size:",m.size(),n.size(),m,n)
cur_prob = prob * m.item()
decoder_input = torch.LongTensor([[n.item()]]).to(config.device)
cur_seq_list = seq_list + [decoder_input]
if n == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden) best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete or len(best_seq) - 1 == config.chatbot_target_max_len: best_seq = [i.item() for i in best_seq]
if best_seq[0] == config.target_ws.SOS:
best_seq = best_seq[1:]
if best_seq[-1] == config.target_ws.EOS:
best_seq = best_seq[:-1]
return best_seq else:
prev_beam = cur_beam class Beam:
"""保存每一个时间步的数据""" def __init__(self):
self.heapq = list()
self.beam_width = config.beam_width def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
# 保证最终只有一个beam width个结果
if len(self.heapq) > self.beam_width:
heapq.heappop(self.heapq) def __iter__(self):
for item in self.heapq:
yield item
seq2seq.py
"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result def evaluate_with_beam_search(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
best_seq = self.decoder.evaluate_with_beam_search(encoder_hidden, encoder_outputs)
return best_seq
eval.py
"""
进行模型的评估
""" import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq def eval():
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) loss_list = []
data_loader = get_dataloader(train=False)
bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
with torch.no_grad():
for idx,(input,target,input_len,target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device) decoder_outputs,predict_result = model.evaluate(input,input_len) #[batch_Size,max_len,vocab_size]
loss = F.nll_loss(decoder_outputs.view(-1,len(config.target_ws)),target.view(-1),ignore_index=config.input_ws.PAD)
loss_list.append(loss.item())
bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))
print("当前的平均损失为:",np.mean(loss_list)) def interface():
from chatbot.cut_sentence import cut
import config
#加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) #准备待预测的数据
while True:
origin_input =input("me>>:")
# if "你是谁" in origin_input or "你叫什么" in origin_input:
# result = "我是小智。"
# elif "你好" in origin_input or "hello" in origin_input:
# result = "Hello"
# else:
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input,max_len=config.chatbot_input_max_len)]).to(config.device) outputs,predict = model.evaluate(_input,input_len)
result = config.target_ws.inverse_transform(predict[0])
print("chatbot>>:",result) def interface_with_beamsearch():
from chatbot.cut_sentence import cut
import config
# 加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) # 准备待预测的数据
while True:
origin_input = input("me>>:")
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
config.device) best_seq = model.evaluate_with_beam_search(_input, input_len)
result = config.target_ws.inverse_transform(best_seq)
print("chatbot>>:", result) if __name__ == '__main__':
# interface()
interface_with_beamsearch()
pytorch seq2seq闲聊机器人beam search返回结果的更多相关文章
- pytorch seq2seq闲聊机器人
cut_sentence.py """ 实现句子的分词 注意点: 1. 实现单个字分词 2. 实现按照词语分词 2.1 加载词典 3. 使用停用词 "" ...
- pytorch seq2seq闲聊机器人加入attention机制
attention.py """ 实现attention """ import torch import torch.nn as nn im ...
- 实现nlp文本生成中的beam search解码器
自然语言处理任务,比如caption generation(图片描述文本生成).机器翻译中,都需要进行词或者字符序列的生成.常见于seq2seq模型或者RNNLM模型中. 这篇博文主要介绍文本生成解码 ...
- Beam Search快速理解及代码解析(下)
Beam Search的问题 先解释一下什么要对Beam Search进行改进.因为Beam Search虽然比贪心强了不少,但还是会生成出空洞.重复.前后矛盾的文本.如果你有文本生成经验,一定对这些 ...
- Beam Search快速理解及代码解析
目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...
- 【NLP】选择目标序列:贪心搜索和Beam search
构建seq2seq模型,并训练完成后,我们只要将源句子输入进训练好的模型,执行一次前向传播就能得到目标句子,但是值得注意的是: seq2seq模型的decoder部分实际上相当于一个语言模型,相比于R ...
- Beam Search快速理解及代码解析(上)
Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...
- Beam Search(集束搜索/束搜索)
找遍百度也没有找到关于Beam Search的详细解释,只有一些比较泛泛的讲解,于是有了这篇博文. 首先给出wiki地址:http://en.wikipedia.org/wiki/Beam_searc ...
- 关于Beam Search
Wiki定义:In computer science, beam search is a heuristic search algorithm that explores a graph by exp ...
随机推荐
- ||,&&短路规则测试
短路规则:a||b中若a为真,则直接判断整个表达式为真,不再判断b是真或假, a&&b中若a为假,则直接判断整个表达式为假,不再单独判断b是真或假. 想要测试这个规则的话,可以将 ...
- 面向对象(OO)第一阶段学习总结
前言:对OO本阶段作业情况说明 本阶段一共完成三次作业,第一次主要是在主方法里面进行编程,也就是和之前C差不多,而随着学习的深入,慢慢了解到面向对象与面向过程的区别.作业的难度也在慢慢增大,后两次都用 ...
- 关于 IDEA 启动 springboot 项目异常 - Disconnected from the target VM, address: '127.0.0.1:59770', transport: 'socket'
关于 IDEA 启动 springboot 项目异常 - Disconnected from the target VM, address: '127.0.0.1:59770', transport: ...
- Python面向对象之异常处理
1:什么是异常 异常就是在我们的程序在运行过程中由于某种错误而引发Python抛出的错误: 异常就是程序运行时发生错误的信号(在程序出现错误时,则会产生一个异常,若程序没有处理它,则会抛出该异常,程序 ...
- Oracle 数据库创建导入
Oracle 数据库创建导入 由 Alma 创建, 最后一次修改 2018-06-04 14:37:50 在本章教程中,将教大家如何在Oracle 中创建导入数据库. 注意:本教程中的有些命令您可能并 ...
- echarts以地图形式显示中国疫情情况实现点击省份下钻
首先要导入对应的包.下钻用到各个省份的json文件等内容导入之后进行相关的操作. 首先是从数据库中读取相应的数据文件.通过list方式.只有在ser出转化为json文件.在jsp页面通过ajax来进行 ...
- django rest framework用户认证
django rest framework用户认证 进入rest framework的Apiview @classmethod def as_view(cls, **initkwargs): &quo ...
- linux中忘记了mysql的root用户的密码怎么办?
1.vim /etc/my.cnf skip-grant-tables #取消此行的注释 2.重启mysql systemctl restart mysqld 3.mysql 登陆mysql mys ...
- asap异步执行实现原理
目录 为什么分析asap asap概述 asap源码解析-Node版 参考 1.为什么分析asap 在之前的文章 async和await是如何实现异步编程? 中的浅谈Promise如何实现异步执行小节 ...
- Markdown语法详解-cnblog
博客的重要性 博客,英文名为Blog,它的正式名称为网络日记. 为什么要写博客? 需要总结和思考.有时候我们一直在赶路,却忘了放慢脚步 提升文笔组织能力 提升学习总结能力 提升逻辑思维能力 帮助他人, ...