在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加。

目标不确定,需要在循环外加。

decoder.py 中的修改

"""
实现解码器
"""
import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.ns),
embedding_dim=50,
padding_idx=config.ns.PAD) # 需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=50,
hidden_size=64,
num_layers=1,
bidirectional=False,
batch_first=True,
dropout=0) # 假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(64, len(config.ns)) def forward(self, encoder_hidden,target): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,encoder_hidden_size]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device) for t in range(config.max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
if random.randint(0,100) >70: #teacher forcing机制
decoder_input = target[:,t].unsqueeze(-1)
else:
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
return decoder_outputs, decoder_hidden def forward_step(self, decoder_input, decoder_hidden):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [batch_size,encoder_hidden_size]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden) # out :【batch_size,1,hidden_size】 out_squeezed = out.squeeze(dim=1) # 去掉为1的维度
out_fc = F.log_softmax(self.fc(out_squeezed), dim=-1) # [bathc_size,vocab_size]
# out_fc.unsqueeze_(dim=1) #[bathc_size,1,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc, decoder_hidden def evaluate(self, encoder_hidden): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,encoder_hidden_size]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device) decoder_predict = [] # [[],[],[]] #123456 ,targe:123456EOS,predict:123456EOS123
for t in range(config.max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
decoder_predict.append(index.cpu().detach().numpy()) # 返回预测值
decoder_predict = np.array(decoder_predict).transpose() # [batch_size,max_len]
return decoder_outputs, decoder_predict

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from encoder import Encoder
from decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input, input_len,target):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, decoder_hidden = self.decoder(encoder_hidden,target)
return decoder_outputs def evaluate(self, input, input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, decoder_predict = self.decoder.evaluate(encoder_hidden)
return decoder_outputs, decoder_predict

  train.py

"""
进行模型的训练
"""
import torch
import torch.nn.functional as F
from seq2seq import Seq2Seq
from torch.optim import Adam
from dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from matplotlib import pyplot as plt
from eval import eval
import os model = Seq2Seq().to(config.device)
optimizer = Adam(model.parameters()) if os.path.exists("./models/model.pkl"):
model.load_state_dict(torch.load("./models/model.pkl"))
optimizer.load_state_dict(torch.load("./models/optimizer.pkl")) loss_list = [] def train(epoch):
data_loader = get_dataloader(train=True)
bar = tqdm(data_loader, total=len(data_loader)) for idx, (input, target, input_len, target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device)
optimizer.zero_grad()
decoder_outputs = model(input, input_len,target) # [batch_Size,max_len,vocab_size]
predict = decoder_outputs.view(-1, len(config.ns))
target = target.view(-1)
loss = F.nll_loss(predict, target, ignore_index=config.ns.PAD)
loss.backward()
optimizer.step()
loss_list.append(loss.item())
bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, idx, np.mean(loss_list))) if idx % 100 == 0:
torch.save(model.state_dict(), "./models/model.pkl")
torch.save(optimizer.state_dict(), "./models/optimizer.pkl")
pickle.dump(loss_list, open("./models/loss_list.pkl", "wb")) if __name__ == '__main__':
for i in range(5):
train(i)
eval() plt.figure(figsize=(50, 8))
plt.plot(range(len(loss_list)), loss_list)
plt.show()

  

pytorch seq2seq模型中加入teacher_forcing机制的更多相关文章

  1. pytorch seq2seq闲聊机器人加入attention机制

    attention.py """ 实现attention """ import torch import torch.nn as nn im ...

  2. pytorch seq2seq模型示例

    以下代码可以让你更加熟悉seq2seq模型机制 """ test """ import numpy as np import torch i ...

  3. 分布式系统读写模型中的Quorum机制

    分布式系统的设计中会涉及到许多的协议.机制用来解决可靠性问题.数据一致性问题等,Quorum 机制就是其中的一种.我们通过分布式系统中的读写模型来简单介绍它. 分布式系统中的读写模型 分布式系统是由多 ...

  4. pytorch seq2seq模型训练测试

    num_sequence.py """ 数字序列化方法 """ class NumSequence: """ ...

  5. Seq2Seq模型与注意力机制

    Seq2Seq模型 基本原理 核心思想:将一个作为输入的序列映射为一个作为输出的序列 编码输入 解码输出 解码第一步,解码器进入编码器的最终状态,生成第一个输出 以后解码器读入上一步的输出,生成当前步 ...

  6. 深度学习之seq2seq模型以及Attention机制

    RNN,LSTM,seq2seq等模型广泛用于自然语言处理以及回归预测,本期详解seq2seq模型以及attention机制的原理以及在回归预测方向的运用. 1. seq2seq模型介绍 seq2se ...

  7. 注意力机制和Seq2seq模型

    注意力机制 在"编码器-解码器(seq2seq)"⼀节⾥,解码器在各个时间步依赖相同的背景变量(context vector)来获取输⼊序列信息.当编码器为循环神经⽹络时,背景变量 ...

  8. L11注意力机制和Seq2seq模型

    注意力机制 在"编码器-解码器(seq2seq)"⼀节⾥,解码器在各个时间步依赖相同的背景变量(context vector)来获取输⼊序列信息.当编码器为循环神经⽹络时,背景变量 ...

  9. Deep Learning基础--理解LSTM/RNN中的Attention机制

    导读 目前采用编码器-解码器 (Encode-Decode) 结构的模型非常热门,是因为它在许多领域较其他的传统模型方法都取得了更好的结果.这种结构的模型通常将输入序列编码成一个固定长度的向量表示,对 ...

随机推荐

  1. Functor、Applicative 和 Monad

    Functor.Applicative 和 Monad 是函数式编程语言中三个非常重要的概念,尤其是 Monad. 说明:本文中的主要代码为 Haskell 语言,它是一门纯函数式的编程语言. 一.结 ...

  2. SQL 实战(五)

    一. 将所有to_date为9999-01-01的全部更新为NULL,且 from_date更新为2001-01-01.CREATE TABLE IF NOT EXISTS titles_test ( ...

  3. Python——五分钟理解函数式编程与闭包

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是Python专题的第9篇文章,我们来聊聊Python的函数式编程与闭包. 函数式编程 函数式编程这个概念我们可能或多或少都听说过,刚听 ...

  4. Vue+Mock.js模拟登录和表格的增删改查

    有三类人不适合此篇文章: "喜欢站在道德制高点的圣母婊" -- 适合去教堂 "无理取闹的键盘侠" -- 国际新闻版块欢迎你去 "有一定基础但又喜欢逼逼 ...

  5. HTTP协议经典面试题整理及答案详解

    无论你是Java.PHP开发者,还是运维人员,只要从事互联网行业,面试时都可能被问到HTTP协议相关知识.历时多天的呕心沥血,为你总结了HTTP协议的经典面试题.由于涉及内容比较繁杂不方便记忆,建议收 ...

  6. Jmeter使用Websocket插件测试SingalR,外加还有阿里云PTS的Jmeter原生测试爬坑日志。

    题外话:距离我的上一篇博客已经过去7年多了,我实在是个不务正业的程序员,遇到测试方面的东西总想分享一下,因为可用的资料实在太少了(包括国外的资料). 本人不喜欢授人以鱼,所以不会直接给出问题和解决方案 ...

  7. AJ学IOS(36)UI之手势事件旋转_缩放_拖拽

    AJ分享,必须精品 效果 完成一个图片的捏合缩放,拖拽,旋转动作. 设计思路 拖拽: 首先是最简单的拖拽 //拖拽 -(void)panTest { UIPanGestureRecognizer *p ...

  8. 03-css3中的3D转换

    一.CSS3-3D转换 1.3D 特点:近大远小,物体和面遮挡不可见 1.1三维坐标系 x 轴:水平向右 -- x 轴右边是正值,左边是负值 y 轴:垂直向下 -- y 轴下面是正值,上面是负值 z ...

  9. Redis之事务操作

    1.Redis事务的概念: Redis 事务的本质是一组命令的集合.事务支持一次执行多个命令,一个事务中所有命令都会被序列化.在事务执行过程,会按照顺序串行化执行队列中的命令,其他客户端提交的命令请求 ...

  10. 【Java】Operator 运算符/操作符

    Operator 运算符/操作符 什么是操作符? 一个表示特定的数学或逻辑操作的符号 算术运算符 加 + 减 - 乘 * 除 / 取模 % 前置自运算 ++ a .--b 后置自运算 a++ .b-- ...