以下代码可以让你更加熟悉seq2seq模型机制

"""
test
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable # 创建字典
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
num_dict = {n:i for i,n in enumerate(char_arr)} # 网络参数
n_step = 5
n_hidden = 128
n_class = len(num_dict)
batch_size = len(seq_data) # 准备数据
def make_batch(seq_data):
input_batch, output_batch, target_batch =[], [], [] for seq in seq_data:
for i in range(2):
seq[i] = seq[i] + 'P' * (n_step-len(seq[i]))
input = [num_dict[n] for n in seq[0]]
ouput = [num_dict[n] for n in ('S'+ seq[1])]
target = [num_dict[n] for n in (seq[1]) + 'E'] input_batch.append(np.eye(n_class)[input])
output_batch.append(np.eye(n_class)[ouput])
target_batch.append(target) return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch)) input_batch, output_batch, target_batch = make_batch(seq_data) # 创建网络
class Seq2Seq(nn.Module):
"""
要点:
1.该网络包含一个encoder和一个decoder,使用的RNN的结构相同,最后使用全连接接预测结果
2.RNN网络结构要熟知
3.seq2seq的精髓:encoder层生成的参数作为decoder层的输入
"""
def __init__(self):
super().__init__()
# 此处的input_size是每一个节点可接纳的状态,hidden_size是隐藏节点的维度
self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
self.fc = nn.Linear(n_hidden, n_class) def forward(self, enc_input, enc_hidden, dec_input):
# RNN要求输入:(seq_len, batch_size, n_class),这里需要转置一下
enc_input = enc_input.transpose(0,1)
dec_input = dec_input.transpose(0,1)
_, enc_states = self.enc(enc_input, enc_hidden)
outputs, _ = self.dec(dec_input, enc_states)
pred = self.fc(outputs) return pred # training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(5000):
hidden = Variable(torch.zeros(1, batch_size, n_hidden)) optimizer.zero_grad()
pred = model(input_batch, hidden, output_batch)
pred = pred.transpose(0, 1)
loss = 0
for i in range(len(seq_data)):
temp = pred[i]
tar = target_batch[i]
loss += loss_fun(pred[i], target_batch[i])
if (epoch + 1) % 1000 == 0:
print('Epoch: %d Cost: %f' % (epoch + 1, loss))
loss.backward()
optimizer.step() # 测试
def translate(word):
input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
# hidden 形状 (1, 1, n_class)
hidden = Variable(torch.zeros(1, 1, n_hidden))
# output 形状(6,1, n_class)
output = model(input_batch, hidden, output_batch)
predict = output.data.max(2, keepdim=True)[1]
decoded = [char_arr[i] for i in predict]
end = decoded.index('E')
translated = ''.join(decoded[:end]) return translated.replace('P', '') print('girl ->', translate('girl'))

参考:https://blog.csdn.net/weixin_43632501/article/details/98525673

pytorch seq2seq模型示例的更多相关文章

  1. pytorch seq2seq模型中加入teacher_forcing机制

    在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加. 目标不确定,需要在循环外加. decoder.py 中的修改 """ 实现解码器 &q ...

  2. pytorch seq2seq模型训练测试

    num_sequence.py """ 数字序列化方法 """ class NumSequence: """ ...

  3. PyTorch专栏(六): 混合前端的seq2seq模型部署

    欢迎关注磐创博客资源汇总站: http://docs.panchuang.net/ 欢迎关注PyTorch官方中文教程站: http://pytorch.panchuang.net/ 专栏目录: 第一 ...

  4. 混合前端seq2seq模型部署

    混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...

  5. 时间序列深度学习:seq2seq 模型预测太阳黑子

    目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...

  6. seq2seq模型详解及对比(CNN,RNN,Transformer)

    一,概述 在自然语言生成的任务中,大部分是基于seq2seq模型实现的(除此之外,还有语言模型,GAN等也能做文本生成),例如生成式对话,机器翻译,文本摘要等等,seq2seq模型是由encoder, ...

  7. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  8. pytorch生成对抗示例

    pytorch生成对抗示例 本文对ML(机器学习)模型的安全漏洞的认识,并将深入了解对抗性机器学习的热门话题.图像添加难以察觉的扰动会导致模型性能大不相同.通过图像分类器上的示例探讨该主题.使用第一种 ...

  9. [炼丹术]使用Pytorch搭建模型的步骤及教程

    使用Pytorch搭建模型的步骤及教程 我们知道,模型有一个特定的生命周期,了解这个为数据集建模和理解 PyTorch API 提供了指导方向.我们可以根据生命周期的每一个步骤进行设计和优化,同时更加 ...

随机推荐

  1. System.Windows.Forms.Control.CheckForIllegalCrossThreadCalls = false

    多线程程序中,新创建的线程不能访问UI线程创建的窗口控件,这时如果想要访问窗口的控件,发现无法对其控制. 这时可将窗口构造函数中的CheckForIllegalCrossThreadCalls设置为f ...

  2. C#中的System.Type和System.RuntimeType之间的区别

    string str = string.Empty; Type strType = str.GetType(); Type strTypeType = strType.GetType(); strTy ...

  3. 25个特殊操作符(special operator)

    1. CLHS (Common-Lisp-Hyper-Spec) http://www.lispworks.com/documentation/HyperSpec/Body/03_ababa.htm ...

  4. PIE SDK加载WMS服务数据

    1.   功能简介 WMS服务,WMS是OGC标准中比较简单也是比较重要的标准之一.它全称是“Web Map Service”(网络地图服务):利用具有地理空间位置信息的数据制作地图.其中将地图定义为 ...

  5. Spring.yml配置文件读取字符串出现错误

    今天遇到一个诡异的问题,在配置文件中配置了一个值为字符串的属性,但是在用@Value注入时发现注入的值不是我配置的值,而且在全文都没有找到匹配的值 之后研究了好久,发现yml文件会把0开头的数组进行8 ...

  6. 我用Bash编写了一个扫雷游戏

    我在编程教学方面不是专家,但当我想更好掌握某一样东西时,会试着找出让自己乐在其中的方法.比方说,当我想在 shell 编程方面更进一步时,我决定用 Bash 编写一个扫雷游戏来加以练习. 我在编程教学 ...

  7. Spark调用Linux命令实现解压和压缩功能

    一.应用场景 在Spark程序中调用Linux命令,实现一些程序难以实现的功能,例如:发送模拟邮件.文件打包或解压等等 二.代码实现 package big.data.analyse.linux im ...

  8. ETHINK组件取值手册

    Ethink组件取值手册 一.取值 Sql查询配置中取值方式:所有可以对外过滤的组件都可以用id.output取值 就是取组件setOutput()里输出的值 ,具体分为以下两种: 1)$p{OBJ_ ...

  9. itextpdf5生成document生成pdf的简单dome

    package dbzx.pdf; import java.io.FileNotFoundException; import java.io.FileOutputStream; import org. ...

  10. Nginx 高级配置-实现多域名HTTPS

    Nginx 高级配置-实现多域名HTTPS 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.Nginx支持基于单个IP实现多域名的功能 Nginx支持基于单个IP实现多域名的功能 ...