attention.py

"""
实现attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import config class Attention(nn.Module):
def __init__(self,method="general"):
super(Attention,self).__init__()
assert method in ["dot","general","concat"],"attention method error"
self.method = method
if method == "general":
self.W = nn.Linear(config.chatbot_encoder_hidden_size*2,config.chatbot_encoder_hidden_size*2,bias=False) if method == "concat":
self.W = nn.Linear(config.chatbot_decoder_hidden_size*4,config.chatbot_decoder_hidden_size*2,bias=False)
self.V = nn.Linear(config.chatbot_decoder_hidden_size*2,1,bias=False) def forward(self,decoder_hidden,encoder_outputs):
if self.method == "dot":
return self.dot_score(decoder_hidden,encoder_outputs) elif self.method == "general":
return self.general_socre(decoder_hidden,encoder_outputs) elif self.method == "concat":
return self.concat_socre(decoder_hidden,encoder_outputs) def dot_score(self,decoder_hidden,encoder_outputs):
"""H_t^T * H_s
:param decoder_hidden:[1,batch_size,128*2] --->[batch_size,128*2,1]
:param encoder_outputs:[batch_size,encoder_max_len,128*2] --->[batch_size,encoder_max_len,128*2]
:return:attention_weight:[batch_size,encoder_max_len]
"""
decoder_hidden_viewed = decoder_hidden.squeeze(0).unsqueeze(-1) #[batch_size,128*2,1]
attention_weight = torch.bmm(encoder_outputs,decoder_hidden_viewed).squeeze(-1)
return F.softmax(attention_weight,dim=-1) def general_socre(self,decoder_hidden,encoder_outputs):
"""
H_t^T *W* H_s
:param decoder_hidden:[1,batch_size,128*2]-->[batch_size,decode_hidden_size] *[decoder_hidden_size,encoder_hidden_size]--->[batch_size,encoder_hidden_size]
:param encoder_outputs:[batch_size,encoder_max_len,128*2]
:return:[batch_size,encoder_max_len]
"""
decoder_hidden_processed =self.W(decoder_hidden.squeeze(0)).unsqueeze(-1) #[batch_size,encoder_hidden_size*2,1]
attention_weight = torch.bmm(encoder_outputs, decoder_hidden_processed).squeeze(-1)
return F.softmax(attention_weight, dim=-1) def concat_socre(self,decoder_hidden,encoder_outputs):
"""
V*tanh(W[H_t,H_s])
:param decoder_hidden:[1,batch_size,128*2]
:param encoder_outputs:[batch_size,encoder_max_len,128*2]
:return:[batch_size,encoder_max_len]
"""
#1. decoder_hidden:[batch_size,128*2] ----> [batch_size,encoder_max_len,128*2]
# encoder_max_len 个[batch_size,128*2] -->[encoder_max_len,bathc_size,128*2] -->transpose--->[]
encoder_max_len = encoder_outputs.size(1)
batch_size = encoder_outputs.size(0)
decoder_hidden_repeated = decoder_hidden.squeeze(0).repeat(encoder_max_len,1,1).transpose(0,1) #[batch_size,max_len,128*2]
h_cated = torch.cat([decoder_hidden_repeated,encoder_outputs],dim=-1).view([batch_size*encoder_max_len,-1]) #[batch_size*max_len,128*4]
attention_weight = self.V(F.tanh(self.W(h_cated))).view([batch_size,encoder_max_len]) #[batch_size*max_len,1]
return F.softmax(attention_weight,dim=-1)

  decoder.py

"""
实现解码器
"""
import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
embedding_dim=config.chatbot_decoder_embedding_dim,
padding_idx=config.target_ws.PAD) #需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
hidden_size=config.chatbot_decoder_hidden_size,
num_layers=config.chatbot_decoder_number_layer,
bidirectional=False,
batch_first=True,
dropout=config.chatbot_decoder_dropout) #假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(config.chatbot_decoder_hidden_size,len(config.target_ws))
self.attn = Attention(method="general")
self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False) def forward(self, encoder_hidden,target,encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [1,batch_size,128*2]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) #out:[batch_size,1,128*2]
#decoder_hidden :[1,bathc_size,128*2]
# print(decoder_hidden.size())
out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden) ##### 开始attention ############
### 1. 计算attention weight
attn_weight = self.attn(decoder_hidden,encoder_outputs) #[batch_size,1,encoder_max_len]
### 2. 计算context vector
#encoder_ouputs :[batch_size,encoder_max_len,128*2]
context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
### 3. 计算 attention的结果
#[batch_size,128*2] #context_vector:[batch_size,128*2] --> 128*4
#attention_result = [batch_size,128*4] --->[batch_size,128*2]
attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
# attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
#### attenion 结束 # print("decoder_hidden size:",decoder_hidden.size())
#out :【batch_size,1,hidden_size】 # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc,decoder_hidden def evaluate(self,encoder_hidden,encoder_outputs): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
config.device) predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
predict_result.append(index.cpu().detach().numpy()) #[[batch],[batch]...]
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
# predict_result.append(decoder_input)
#把结果转化为ndarray,每一行是一条预测结果
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result

  

pytorch seq2seq闲聊机器人加入attention机制的更多相关文章

  1. pytorch seq2seq闲聊机器人beam search返回结果

    decoder.py """ 实现解码器 """ import heapq import torch.nn as nn import con ...

  2. pytorch seq2seq闲聊机器人

    cut_sentence.py """ 实现句子的分词 注意点: 1. 实现单个字分词 2. 实现按照词语分词 2.1 加载词典 3. 使用停用词 "" ...

  3. pytorch seq2seq模型中加入teacher_forcing机制

    在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加. 目标不确定,需要在循环外加. decoder.py 中的修改 """ 实现解码器 &q ...

  4. 深度学习之seq2seq模型以及Attention机制

    RNN,LSTM,seq2seq等模型广泛用于自然语言处理以及回归预测,本期详解seq2seq模型以及attention机制的原理以及在回归预测方向的运用. 1. seq2seq模型介绍 seq2se ...

  5. pytorch笔记:09)Attention机制

    刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*) 1.RNN中的attentionpytorch官方教程:https://pytorch.or ...

  6. DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)

    两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...

  7. 完全图解RNN、RNN变体、Seq2Seq、Attention机制

    完全图解RNN.RNN变体.Seq2Seq.Attention机制 本文主要是利用图片的形式,详细地介绍了经典的RNN.RNN几个重要变体,以及Seq2Seq模型.Attention机制.希望这篇文章 ...

  8. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  9. Multimodal —— 看图说话(Image Caption)任务的论文笔记(三)引入视觉哨兵的自适应attention机制

    在此前的两篇博客中所介绍的两个论文,分别介绍了encoder-decoder框架以及引入attention之后在Image Caption任务上的应用. 这篇博客所介绍的文章所考虑的是生成captio ...

随机推荐

  1. 关于getchar的疑惑

    最近做了一道题,我的代码有片段是这样的 while(scanf("%d",&n)) { if(n==0&&getchar()=='\n') break; . ...

  2. Java工程师常用Linux命令

    本文所列的Linux常用命令包含:文件相关(目录操作,内容查看,查找与比较,压缩与解压),进程管理,网络操作,系统管理,性能监测与优化,Java常用工具多个方面概述. 文件目录基本操作 ls 命令用来 ...

  3. Python第十二章-多进程和多线程01-多进程

    多进程和多线程 一.进程 1.1 进程的引入 现实生活中,有很多的场景中的事情是同时进行的,比如开车的时候 手和脚共同来驾驶汽车,再比如唱歌跳舞也是同时进行的:试想,如果把唱歌和跳舞这2件事情分开依次 ...

  4. [codevs1227]草地排水<Dinic网络流最大流>

    题目链接:http://codevs.cn/problem/1993/ https://www.luogu.org/problemnew/show/P2740 之前一直都没去管网络流这算法,但是老师最 ...

  5. python opencv 图片缺陷检测(讲解直方图以及相关系数对比法)

    一.利用直方图的方式进行批量的图片缺陷检测(方法简单) 二.步骤(完整代码见最后) 2.1灰度转换(将原图和要检测对比的图分开灰度化) 灰度化的作用是因为后面的直方图比较需要以像素256为基准进行相关 ...

  6. 如何编写优雅的异步代码 — CompletableFuture

    前言 在我们的意识里,同步执行的程序都比较符合人们的思维方式,而异步的东西通常都不好处理.在异步计算的情况下,以回调表示的动作往往会分散在代码中,也可能相互嵌套在内部,如果需要处理其中一个步骤中可能发 ...

  7. 物体的三维识别与6D位姿估计:PPF系列论文介绍(三)

    作者:仲夏夜之星 Date:2020-04-08 来源:物体的三维识别与6D位姿估计:PPF系列论文介绍(三) 文章“A Method for 6D Pose Estimation of Free-F ...

  8. 《Three.js 入门指南》3.1.1 - 基本几何形状 - 球体(SphereGeometry)

    3.1 基本几何形状 球体(SphereGeometry) 构造函数: THREE.SphereGeometry(radius, segmentsWidth, segmentsHeight, phiS ...

  9. 1030 Travel Plan (30分)(dijkstra 具有多种决定因素)

    A traveler's map gives the distances between cities along the highways, together with the cost of ea ...

  10. javaweb添加学生信息

    连接数据库已经进行判断 要求: 1登录账号:要求由6到12位字母.数字.下划线组成,只有字母可以开头:(1分) 2登录密码:要求显示“• ”或“*”表示输入位数,密码要求八位以上字母.数字组成.(1分 ...