Deep Learning: Generating TV-Script Lines with an RNN
A script-line generator, written in PyTorch.
import os

# Load the raw script text.
def load_data(path):
    with open(path, 'r', encoding="utf-8") as f:
        data = f.read()
    return data

# Drop the first 81 characters (the file's header notice).
text = load_data('./moes_tavern_lines.txt')[81:]

# Character-level 60/20/20 split into train / validation / test.
train_count = int(len(text) * 0.6)
val_count = int(len(text) * 0.2)
test_count = int(len(text) * 0.2)
train_text = text[:train_count]
val_text = text[train_count: train_count + val_count]
test_text = text[train_count + val_count:]
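Before building anything on top of the split, a quick check that the three character-level slices partition the text — a minimal sketch, assuming the variables above:

# Sanity check (sketch): the three slices should cover the whole text.
# Sizes may differ slightly from exact 60/20/20 due to int() truncation.
print(len(train_text), len(val_text), len(test_text))
assert len(train_text) + len(val_text) + len(test_text) == len(text)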
import numpy as np

view_sentence_range = (0, 10)

print("Dataset stats")
print("Roughly the number of unique words: {}".format(len({word: None for word in text.split()})))
scenes = text.split("\n\n")
print("Number of scenes: {}".format(len(scenes)))
sentence_count_scene = [scene.count('\n') for scene in scenes]
print('Average number of sentences in each scene: {}'.format(np.average(sentence_count_scene)))
sentences = [sentence for scene in scenes for sentence in scene.split('\n')]
print("Number of lines: {}".format(len(sentences)))
word_count_sentence = [len(sentence.split()) for sentence in sentences]
print('Average number of words in each line: {}'.format(np.average(word_count_sentence)))
print()
print('Sentences {} to {}:'.format(*view_sentence_range))
print('\n'.join(text.split('\n')[view_sentence_range[0]:view_sentence_range[1]]))
# Map punctuation to word-like tokens so that split() treats punctuation
# as separate words.
def token_lookup():
    return {
        '.': '||Period||',
        ',': '||Comma||',
        '"': '||Quotation_Mark||',
        ';': '||Semicolon||',
        '!': '||Exclamation_mark||',
        '?': '||Question_mark||',
        '(': '||Left_Parentheses||',
        ')': '||Right_Parentheses||',
        '--': '||Dash||',
        '\n': '||Return||',
    }
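To see what the lookup buys us: without it, split() would glue punctuation onto the preceding word ("Tavern." vs "Tavern"). A minimal sketch with a made-up line:

# Sketch with a hypothetical line: replace each mark with its token.
sample = "Hello, Moe!\n"
for ch, tok in token_lookup().items():
    sample = sample.replace(ch, " {} ".format(tok))
print(sample)   # Hello ||Comma||  Moe ||Exclamation_mark||  ||Return||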
import os
import torch

# Bidirectional word <-> id mapping.
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

# Holds the tokenized train/valid/test splits as LongTensors of word ids.
class Corpus(object):
    def __init__(self, train, val, test):
        self.dictionary = Dictionary()
        self.train = self.tokenize(train)
        self.valid = self.tokenize(val)
        self.test = self.tokenize(test)

    def tokenize(self, text):
        words = text.split()
        ids = torch.LongTensor(len(words))
        for i, word in enumerate(words):
            self.dictionary.add_word(word)
            ids[i] = self.dictionary.word2idx[word]
        return ids
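A quick round trip shows how the two classes cooperate: ids produced by Corpus.tokenize map back to words through idx2word. A sketch with a hypothetical toy corpus:

# Sketch: first-seen order determines the ids.
toy = Corpus("a b c a", "a b", "c")
print(toy.train.tolist())                                        # [0, 1, 2, 0]
print([toy.dictionary.idx2word[i] for i in toy.train.tolist()])  # ['a', 'b', 'c', 'a']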
import numpy as np
import torch

i_dict = token_lookup()

# Replace punctuation with the word-like tokens above so a later split()
# keeps it as separate vocabulary entries.
def create_data(text):
    new_text = ""
    for t in text:
        if t in i_dict:
            new_text += " {} ".format(i_dict[t])
        else:
            new_text += t
    # Note: the scan is character by character, so the two-character
    # '--' key never matches; dashes pass through unchanged.
    return new_text
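Applied to a made-up line, create_data yields text where punctuation survives a later split() — a sketch:

# Sketch: punctuation becomes standalone tokens.
print(create_data("Moe: Hello!\n").split())
# ['Moe:', 'Hello', '||Exclamation_mark||', '||Return||']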
import torch
import torch.nn as nn
from torch.autograd import Variable

# The RNN model: embedding -> GRU -> dropout -> linear decoder over the vocab.
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop = nn.Dropout(0.5)
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        # input: (seq_len, batch) of word ids.
        input = self.encoder(input)
        output, hidden = self.gru(input, hidden)
        output = self.drop(output)
        # Flatten time and batch dims for the linear layer, then restore them.
        decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, batch_size):
        return Variable(torch.zeros(self.n_layers, batch_size, self.hidden_size))
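A shape check helps here, since nn.GRU defaults to batch_first=False and expects (seq_len, batch) input. A sketch with made-up sizes, assuming the class above:

# Sketch: 10 time steps, batch of 4, a hypothetical 50-word vocabulary.
m = RNN(input_size=50, hidden_size=16, output_size=50)
h0 = m.init_hidden(batch_size=4)
x = Variable(torch.zeros(10, 4).long())
out, h1 = m(x, h0)
print(out.size())   # torch.Size([10, 4, 50]) -- one logit vector per step and sample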
# Batching: reshape the 1-D id stream into bsz columns of contiguous text.
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data
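A worked example makes the layout clear: each column of the result is one contiguous stretch of the corpus, so chunks read down a column stay in order. A sketch:

# Sketch: 11 ids, bsz=2 -> trim to 10, reshape to [5, 2].
demo = batchify(torch.arange(0, 11).long(), 2)
print(demo.size())        # torch.Size([5, 2])
print(demo.t().tolist())  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -- the two columns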
n_epochs = 3500
print_every = 500
plot_every = 10
hidden_size = 100
n_layers = 1
lr = 0.005
chunk_len = 10
batch_size = 20
val_batch_size = 10
# Build the datasets
train_data = create_data(train_text)
test_data = create_data(test_text)
val_data = create_data(val_text)
corpus = Corpus(train_data, val_data, test_data)
train_source = batchify(corpus.train, batch_size)
test_source = batchify(corpus.test, batch_size)
val_source = batchify(corpus.valid, batch_size)
n_tokens = len(corpus.dictionary)
# Model
model = RNN(n_tokens, hidden_size, n_tokens, n_layers)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Loss function (takes raw logits; CrossEntropyLoss applies log-softmax itself)
criterion = nn.CrossEntropyLoss()
def get_batch(source, i, evaluation=False):
    # A chunk of chunk_len steps as input; the same chunk shifted by one
    # step, flattened, as the next-word targets.
    seq_len = min(chunk_len, len(source) - 1 - i)
    data = Variable(source[i:i+seq_len], volatile=evaluation)
    target = Variable(source[i+1:i+1+seq_len].view(-1))
    return data, target
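Input and target are the same stream offset by one step — the model predicts each next word. A sketch reusing batchify from above:

# Sketch: with chunk_len = 10, data is [10, 2] and targets flatten to [20].
src = batchify(torch.arange(0, 50).long(), 2)
d, t = get_batch(src, 0)
print(d.size(), t.size())   # torch.Size([10, 2]) torch.Size([20])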
def repackage_hidden(h):
    # Detach the hidden state from the graph so backprop stops at the
    # chunk boundary (truncated BPTT).
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)
# Training
def train():
    model.train()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_source.size(0) - 1, chunk_len)):
        data, targets = get_batch(train_source, i)
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.data
        if batch % 10 == 0:
            print('epoch {} batch {} loss {}'.format(epoch, batch, loss.data))
# Validation
def evaluate(data_source):
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, chunk_len):
        data, targets = get_batch(data_source, i, evaluation=True)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        # Weight each chunk's loss by its length for a proper average.
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss[0] / len(data_source)
import math

# Start training: report validation loss and perplexity after each epoch.
for epoch in range(1, n_epochs + 1):
    train()
    val_loss = evaluate(val_source)
    print("epoch {} val_loss {:.4f} ppl {:.2f}".format(epoch, val_loss, math.exp(val_loss)))
# Generate a short passage
def gen(n_words):
    model.eval()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(1)
    # Start from a random word id.
    input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
    words = []
    for i in range(n_words):
        output, hidden = model(input, hidden)
        # exp() turns the logits into (unnormalized) weights for sampling.
        word_weights = output.squeeze().data.exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.data.fill_(word_idx)
        word = corpus.dictionary.idx2word[word_idx]
        # Map tokens like ||Comma|| back to their punctuation.
        restored = False
        for ch, tok in i_dict.items():
            if tok == word:
                words.append(ch)
                restored = True
                break
        if not restored:
            words.append(word)
    return words

words = gen(1000)
print(" ".join(words))
Summary
The RNN's hyperparameters rarely come out right on the first try; patient tuning is all it takes.