char-rnn-tensorflow由飞飞弟子karpathy编写,展示了如何用tensorflow来搭建一个基本的RNN(LSTM)网络,并进行基于char的seq2seq进行训练。

数据读取部分

data文件夹下的input.txt为示例用的莎翁剧本,在数据读取阶段的preprocess函数中,将基于该文本文件生产词汇表文件vocal.pkl(记录词的索引)和data.npy(将训练用的文字转换成索引的文件)。

其中,self.vocab就是组织的字典文件,给出任意char能查询到它的index,如[(3:c),(25:y)....]。如果是第一次读取,Self.char(索引对应字符)被dump到vocab.pkl文件中。读取的文字素材内容也通过查字典文件的方式被

转成索引队列,存入tensor变量和data.npy文件中。

接下来确定每一批训练用的Batch的数据:

最终得到一个epoch需要训练的batch数量,训练数据x_batch和y_batch(y_batch[i]存储了x_batch[i+1]的char,并在结尾处循环),两者都是长度为num_batch的list,每一个item是一个batch需要训练的数据tensor(或标签数据),

即batch_size组数据(上图的第一个50),每一组数据有seq_length个word(上图的第二个50)。

model的建立

这里使用了rnn.BasicLSTMCell,也就是基本的LSTM来构建cell,隐层单元数为args.rnn_size(默认128个),一个cell中的layer层数为2.隐层单元的个数也就对应了训练出来的词的dense-vector的维度,隐层单元的矩阵类似于

word2vec中训练用的隐层矩阵。模型中的placeholder参数self.input_data和self.target的shape为[args.batch_size, args.seq_length],经过embedding查找后,被转化为稠密向量inputs(长度为batch_size的list,每一个item的

shape为[args.seq_length,args.rnn_size],如上所述,args.rnn_size即是dense-vector的维度,每一个word由dense-vector维度的词向量来表示)

关于embedding矩阵的构建,可以参考下图:

上例将embedding矩阵初始化为one-hot编码的对角矩阵,如果不进行初始化(就像char-rnn例子里面一样),则数据会在initial的时候被随机初始化,如下图:

最终,input会被压成一个长度为num_steps的列表,每个元素是[batch_size, input_size]的2-D维的tensor

loop函数,沿用模型的参数(w和b)来循环生成下一个词。训练的时候 inputs里面已经有完整的训练sample了,所以loop函数被设置为null,不使用,inference的时候这部分是缺失的,需要我们用loop来生成。

MultiRNNCell函数构造了一个时间步的多层rnn

legacy_seq2seq.rnn_decoder负责实现将其循环num_steps个时间步,这里的num_step就等价于seq_length。

这里附上rnn_decoder的伪代码

接下来就是手动求导 手动优化了:

训练完成后,Sample的流程总体上中规中矩,每次喂一个char给模型,吐出一个char,重复设定的num次。不过比较奇怪的是最后一步,在chars_size维的结果向量中选取最大概率的索引时,使用了一个奇怪的函数,

weighted_pick,这个函数的最终输出结果和随机数有关,这个随机性和系统的结果有什么关联,看不懂。经高手提醒,想明白了,这里没有直接使用softmax,而是在保证概率优先的前提下,加入了一些随机性,
让一个随机数落入weight组成的区间中去。

char-rnn-tensorflow源码解析及结构流程分析的更多相关文章

  1. Flink 源码解析 —— 项目结构一览

    Flink 源码项目结构一览 https://t.zsxq.com/MNfAYne 博客 1.Flink 从0到1学习 -- Apache Flink 介绍 2.Flink 从0到1学习 -- Mac ...

  2. Sentinel源码解析一(流程总览)

    引言 Sentinel作为ali开源的一款轻量级流控框架,主要以流量为切入点,从流量控制.熔断降级.系统负载保护等多个维度来帮助用户保护服务的稳定性.相比于Hystrix,Sentinel的设计更加简 ...

  3. Tensorflow源码解析1 -- 内核架构和源码结构

    1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...

  4. tensorflow源码解析系列文章索引

    文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...

  5. 顺序线性表 ---- ArrayList 源码解析及实现原理分析

    原创播客,如需转载请注明出处.原文地址:http://www.cnblogs.com/crawl/p/7738888.html ------------------------------------ ...

  6. Thinkphp6源码分析之解析,Thinkphp6路由,Thinkphp6路由源码解析,Thinkphp6请求流程解析,Thinkphp6源码

    Thinkphp6源码解析之分析 路由篇-请求流程 0x00 前言: 第一次写这么长的博客,所以可能排版啊,分析啊,什么的可能会比较乱.但是我大致的流程已经觉得是说的够清楚了.几乎是每行源码上都有注释 ...

  7. tensorflow源码解析之framework拾遗

    把framework中剩余的内容,按照文件名进行了简单解析.时间原因写的很仓促,算是占个坑,后面有了新的理解再来补充. allocation_description.proto 一个对单次内存分配结果 ...

  8. tensorflow源码解析之common_runtime-executor-上

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  9. tensorflow源码解析之framework-allocator

    目录 什么是allocator 内存分配器的管理 内存分配追踪 其它结构 关系图 涉及的文件 迭代记录 1. 什么是allocator Allocator是所有内存分配器的基类,它定义了内存分配器需要 ...

随机推荐

  1. 【bzoj1552/3506】[Cerc2007]robotic sort splay翻转,区间最值

    [bzoj1552/3506][Cerc2007]robotic sort Description Input 输入共两行,第一行为一个整数N,N表示物品的个数,1<=N<=100000. ...

  2. 移动端:UI图px单位转换rem单位的计算方法

    简单说一下 em em 单位是相对于父元素字体大小来去定的.比方说: font-size:12px; 元素宽度是2em; 那么实际的宽度是 24px.(具体为什么,可以去查询资料,今天主讲rem) 简 ...

  3. Codeforces932D. Tree

    n<=400000个在线操作:树上插入一个某点权.父亲为某点的点:查询这样的最长点序列:序列的某个数必须是上一个数的祖先之一:序列的点权和不能超过x:序列的某个点的点权必须不小于上一个,且相邻两 ...

  4. 洛谷—— P1656 炸铁路

    P1656 炸铁路 题目描述 因为某国被某红色政权残酷的高压暴力统治.美国派出将军uim,对该国进行战略性措施,以解救涂炭的生灵. 该国有n个城市,这些城市以铁路相连.任意两个城市都可以通过铁路直接或 ...

  5. SVN 学习笔记-高级操作

    所谓高级操作,只是曲高和寡,其实都不怎么用的.但是关键时候,可能会很有用. 这个高级只是针对基本操作而言.有些操作可能也是比较基本的. 清除锁 有时候我们在操作的时候,可能系统崩溃了,或者SVN非正常 ...

  6. linux C 中的volatile使用

    一个定义为volatile的变量是说这变量可能会被意想不到地改变,这样,编译器就不会去假设这个变量的值了.精确地说就是,优化器在用到这个变量时必须每次都小心地重新读取这个变量的值,而不是使用保存在寄存 ...

  7. 开发:异常收集之 DB2建表相关问题

    第一次用DB2数据库,因为考虑到建表语句可能不一样,所以採用手动建表的办法.一个个字段去填.并勾选主键.最后发现创建失败.看了下系统生成的sql语句 sql语句例如以下: CREATE TABLE F ...

  8. 怎样免费设置QQ空间背景音乐

    怎样免费设置QQ空间背景音乐 1.打开QQ空间,点击 2. 3. 4.这里它要求我们输入歌曲的在线路径,并且必须是MP3格式的,这就简单了,我们仅仅要去网上找在线的MP3音乐就能够了.可是如今非常多提 ...

  9. AIX下RAC搭建 Oracle10G(五)安装oracle、建立监听

    AIX下RAC搭建系列 AIX下RAC搭建 Oracle10G(五)安装oracle.建立监听 环境 节点 节点1 节点2 小机型号 IBM P-series 630 IBM P-series 630 ...

  10. mysql最新版中文参考手册在线浏览

    MySQL是最流行的开放源码SQL数据库管理系统,具有快速.可靠和易于使用的特点.同时MySQL也是一种关联数据库管理系统,具有很高的响应速度和灵活性.又因为mysql拥有良好的连通性.速度和安全性, ...