『TensotFlow』RNN/LSTM古诗生成
往期RNN相关工程实践文章
『TensotFlow』RNN中文文本_下_暨研究生开学感想
张量分析
预处理结果是二维数据,相当于batch条一维数据,每个数据对应一首诗,每个字是一个scalar;
embedding之后,将每个字映射为一个rnn_size大小的向量,数据变为三维;
经过递归神经网络,输出维度不变;
将之调整为二维数据,这里面第二维度(即每一行)对应一个字;
全连接映射,将每一个字由rnnsize的向量映射为一个长度为总字数的向量,这样方便计算loss,实际计算loss时,会将label(二维向量,一行一首诗,字为scalar)拉伸为一维向量,每行只有一个字scalar,one_hot之后和此时数据正好对应,方便计算
SoftMax不改变张量形状,只是将结果以概率分布的形式输出
工程分析
代码见Github
1、文件简介
LSTM_model.py
:LSTM网络模型,提供了end_points接口,被其他部分调用
poetry_porcess.py
:数据读取、预处理部分,会返回打包好的batch,被main调用
gen_poetry.py
:古诗生成程序,拥有可选的风格参数,被main调用
main.py
:主函数,既可以调用前两个程序获取预处理数据并使用LSTM网络进行训练,也可以调用gen_poetry.py生成古诗
2、调用指令
在main.py
最后有如下指令,
if __name__ == "__main__":
words,poetry_vector,to_num,x_batches,y_batches = poetry_porcess.poetry_process()
# train(words, poetry_vector, x_batches, y_batches)
# gen_poetry(words, to_num)
generate(words_, to_num_, style_words="狂沙将军战燕然,大漠孤烟黄河骑。")
此时实际上处于生成模式,对于最后的三行, train:表示训练 gen_poetry:表示根据首字符生成 generate:表示根据首句和风格句生成古诗
训练时注释掉后两行,保留train行,
if __name__ == "__main__":
words,poetry_vector,to_num,x_batches,y_batches = poetry_porcess.poetry_process()
train(words, poetry_vector, x_batches, y_batches)
# gen_poetry(words, to_num)
# generate(words_, to_num_, style_words="狂沙将军战燕然,大漠孤烟黄河骑。")
生成时不需要修改,但是
generate(words_, to_num_, style_words="狂沙将军战燕然,大漠孤烟黄河骑。")
可以替换style_word为任何你想要的风格句,注意最好使用7言或者5言,因为这句会大概率影响到你生成的古诗的句子长度(不绝对),这只是风格提取,你可以输入任意长度;在运行了脚本后,屏幕会提示输入起始句,输入的句子一般5或者7个字,这个由于会拿来直接做首句(由结果示范可以看到),输入长度不宜过长。
对于上面的两种情况,修改完成后运行脚本即可,
python main.py
即可显示结果
3、结果示范
head:床前明月光 + style:黄沙百战金甲: 床前明月光辉,魏武征夫血絮红。
数步崩云复遗主,缟衣东,帝京举,玉轮还满出书初。
秋秋惨惨垂杨柳,梦断黄莺欲断肠。
花凋柳映阮家几,屋前病,歇马空留门。
当年皆月林,独往深山有素。 head:少小离家老大回 + style:山雨欲来风满楼: 少小离家老大回,四壁百月弄鸦飞。
扫香花间春风地,隔天倾似烂桃香。
近来谁伴清明日,两株愁味在罗帏。
仍通西疾空何处,轧轧凉吹日方明。 head:少小离家老大回 + style:铁马冰河入梦来: 少小离家老大回,化空千里便成丝。
官抛十里同牛颔,莫碍风光雪片云。
饮水远涛飞汉地,云连城户翠微低。
一树铁门万象耸,白云三尺各关高。
同言东甸西游子,谁道承阳要旧忧。 少小离家老大回,含颦玉烛拂楼台。
初齐去府芙蓉死,细缓行云向国天
RNN结构补充
原网络结构如下,实际上不需要像下面这样写了,不过当时费了好大事,所以保留一下原来版本的代码,
with tf.variable_scope('placeholder'):
input_vec = tf.placeholder(tf.int32,[None,None])
output_targets = tf.placeholder(tf.int32,[None,None]) def rnn_network(rnn_size=128,num_layers=2):
def lstm_cell():
l_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size,state_is_tuple=True,reuse=tf.get_variable_scope().reuse)
return l_cell
cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(num_layers)])
initial_state = cell.zero_state(batch_size, tf.float32) # 初始化LSTM网络节点,参数为尺寸
with tf.variable_scope('LSTM'):
with tf.variable_scope('embedding'):
E = tf.get_variable('embedding',[len(words_list) + 1,rnn_size])
input_embedding = tf.nn.embedding_lookup(E,input_vec)
output_embedding, last_state = tf.nn.dynamic_rnn(cell, input_embedding, initial_state=initial_state,scope='lstm')
output = tf.reshape(output_embedding,[-1,rnn_size])
with tf.variable_scope('output'):
W = tf.get_variable('W', [rnn_size,len(words_list)+1])
b = tf.get_variable('b', [len(words_list)+1])
logits = tf.matmul(output,W) + b
probs = tf.nn.softmax(logits)
return logits, last_state, probs, cell, initial_state
另外,直接使用tf.nn.rnn_cell而不是用tf.contrib.rnn也可以。
『TensotFlow』RNN/LSTM古诗生成的更多相关文章
- 『TensotFlow』RNN中文文本_下_暨研究生开学感想
承前 接上节代码『TensotFlow』RNN中文文本_上, import numpy as np import tensorflow as tf from collections import Co ...
- 『TensotFlow』RNN中文文本_上
中文文字预处理流程 文本处理 读取+去除特殊符号 按照字段长度排序 辅助数据结构生成 生成 {字符:出现次数} 字典 生成按出现次数排序好的字符list 生成 {字符:序号} 字典 生成序号list ...
- 『计算机视觉』Mask-RCNN_锚框生成
Github地址:Mask_RCNN 『计算机视觉』Mask-RCNN_论文学习 『计算机视觉』Mask-RCNN_项目文档翻译 『计算机视觉』Mask-RCNN_推断网络其一:总览 『计算机视觉』M ...
- 『cs231n』RNN之理解LSTM网络
概述 LSTM是RNN的增强版,1.RNN能完成的工作LSTM也都能胜任且有更好的效果:2.LSTM解决了RNN梯度消失或爆炸的问题,进而可以具有比RNN更为长时的记忆能力.LSTM网络比较复杂,而恰 ...
- 『TensotFlow』转置卷积
网上解释 作者:张萌链接:https://www.zhihu.com/question/43609045/answer/120266511来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业 ...
- 『计算机视觉』Mask-RCNN_推断网络其三:RPN锚框处理和Proposal生成
一.RPN锚框信息生成 上文的最后,我们生成了用于计算锚框信息的特征(源代码在inference模式中不进行锚框生成,而是外部生成好feed进网络,training模式下在向前传播时直接生成锚框,不过 ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『TensorFlow』专题汇总
TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...
- 『计算机视觉』Mask-RCNN_从服装关键点检测看KeyPoints分支
下图Github地址:Mask_RCNN Mask_RCNN_KeyPoints『计算机视觉』Mask-RCNN_论文学习『计算机视觉』Mask-RCNN_项目文档翻译『计算机视觉』Mas ...
随机推荐
- pythonl类继承例子
#coding=utf-8 class Person(object): def __init__(self,name,age): self.name=name sel ...
- Linux基础命令---split
split 将一个大文件切割成较小的文件,默认情况下每1000行就会切割一次.分割后的文件,默认以xaa.xab.xac等命名.用户亦可以指定名字的前缀,例如指定前缀test,那么分割后的文件是tes ...
- AI+教育落地,百度大脑如何让校园更智能?
人工智能作为影响社会底层技术革命逐渐向传统行业渗透,“AI+”已经替代“互联网+”成为创业创新的新引擎,出人意料的是,在AI在教育业的率先落地并且相当火爆. 现在,人工智能教育已成为从业者心目中的“教 ...
- OpenCV-跟我一起学数字图像处理之拉普拉斯算子
https://www.cnblogs.com/german-iris/p/4840647.html Laplace算子和Sobel算子一样,属于空间锐化滤波操作.起本质与前面的Spatial Fil ...
- Elasticsearch 疑难解惑
Elasticsearch是如何实现Master选举的? Elasticsearch的选主是ZenDiscovery模块负责的,主要包含Ping(节点之间通过这个RPC来发现彼此)和Unicast(单 ...
- Redis的两种持久化方式-快照持久化(RDB)和AOF持久化
Redis为了内部数据的安全考虑,会把本身的数据以文件形式保存到硬盘中一份,在服务器重启之后会自动把硬盘的数据恢复到内存(redis)的里边,数据保存到硬盘的过程就称为“持久化”效果. redis有两 ...
- 项目梳理5——修改已生成.nuspec文件
xxxx.nuspec格式如下 <?xml version="1.0"?> <package > <metadata> <id>$i ...
- 如何解决Visual Studio2010 编译时提示系统找不到指定文件问题
前一段时间,开始使用vs2010编写程序,可是在编译的时候总是报错,提示系统找不到指定文件,导致无法正常运行程序,花了好久时间终于找到原因解决,是因为常规的输出目录 要与链接的常规的输出文件要相对应. ...
- 【TCP/IP详解 卷一:协议】第四章 ARP:地址解析协议 以及其他部分的一些知识
4.1 引言 数据链路 如以太网(Ethernet) 或者 令牌环网 都有自己的寻址机制(一般为 48 bit 的地址). 一个网络(数据链路层) 可以同时被多个不同的网络使用.比如,一组使用TCP/ ...
- UVa 1602 网格动物(回溯)
https://vjudge.net/problem/UVA-1602 题意:计算n连通块不同形态的个数. 思路: 实在是不知道该怎么做好,感觉判重实在是太麻烦了. 判重就是判断所有格子位置是否都相同 ...