tensorflow world language model
上文提到了pytorch里的world language model,那么怎么能不说tensorflow的实现呢,还是以tensorflow ptb的代码为例说说。
地址:
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
大概处理流程是,一大段文章,然后转成ids,然后根据batchsize切割成。batchsize * M
num_steps是一个sequence的长度
epoch_size 就是进行多少轮训练,算法就是一个batch内文本的长度,除以sequence的长度。一个文本都切成多少个sequence就训练多少轮。
用strided_slice切割也很有意思,接受两个参数,第一个参数是左上角的点,第二个参数是右下角的点,你感受下下面是怎么切割,的横坐标不变,永远是一个batchsize,然后纵坐标开始从左到右开始切割。非常直观。
epoch_size = (batch_len - 1) // num_steps
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
x = tf.strided_slice(data, [0, i * num_steps],
[batch_size, (i + 1) * num_steps])
x.set_shape([batch_size, num_steps])
y = tf.strided_slice(data, [0, i * num_steps + 1],
[batch_size, (i + 1) * num_steps + 1])
y.set_shape([batch_size, num_steps])
return x, y
训练的代码也比较直观
inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
input_data的维度是batchsize * sequence,进行embding_lookup之后就是
batchsize * sequence*embsize
因为sequence上每一个词都有一个embding结果(每个词都去查表了)
下面的代码是核心:
对于num_steps也就是sequence上每一个词,进行循环。每一个time_step取出一个batchsize *embsize 和一个hidden作为输入,
然后输出一个batchsize *embsize 和一个hidden,这里把每一个时刻的输出都存到一个outputs 数组里面。所有outputs 的维度应该是:sequence*batchsize *embsize。
outputs = []
state = self._initial_state
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
这边把outputs进行一个变换,变成output的维度是value *embsize二维矩阵 。其中value长度等于sequence*batchsize。
在有了output之后,就可以根据wx+b生成一个logits,最后根据这个logits去和这个target进行求loss,很显然,这个logits的维度是
value*vocabsize 。这里很坑爹。logits的第一维度是sequence*batchsize的乘积,我们用起来就不是很爽了,所以如果你想拿某个词的具体的logit的话,可以固定batchsize =1
output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
softmax_w = tf.get_variable(
"softmax_w", [size, vocab_size], dtype=data_type())
softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
logits = tf.matmul(output, softmax_w) + softmax_b
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
[logits],
[tf.reshape(input_.targets, [-1])],
[tf.ones([batch_size * num_steps], dtype=data_type())])
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = state
tensorflow world language model的更多相关文章
- language model ——tensorflow 之RNN
代码结构 tf的代码看多了之后就知道其实官方代码的这个结构并不好: graph的构建和训练部分放在了一个文件中,至少也应该分开成model.py和train.py两个文件,model.py中只有一个P ...
- NLP问题特征表达基础 - 语言模型(Language Model)发展演化历程讨论
1. NLP问题简介 0x1:NLP问题都包括哪些内涵 人们对真实世界的感知被成为感知世界,而人们用语言表达出自己的感知视为文本数据.那么反过来,NLP,或者更精确地表达为文本挖掘,则是从文本数据出发 ...
- Sequence Models Week 1 Character level language model - Dinosaurus land
Character level language model - Dinosaurus land Welcome to Dinosaurus Island! 65 million years ago, ...
- Traditional Language Model
Traditional Language Model通常用于回答下述问题: How likely is a string of English words good English ? \(p_{LM ...
- [IR] Tolerant Retrieval & Spelling Correction & Language Model
Dictionary不一定是个list,它可以是多种形式. 放弃Hash的原因: 通常,tree是比较适合的结构. From: http://www.cnblogs.com/v-July-v/arch ...
- 用CNTK搞深度学习 (二) 训练基于RNN的自然语言模型 ( language model )
前一篇文章 用 CNTK 搞深度学习 (一) 入门 介绍了用CNTK构建简单前向神经网络的例子.现在假设读者已经懂得了使用CNTK的基本方法.现在我们做一个稍微复杂一点,也是自然语言挖掘中很火 ...
- A Neural Probabilistic Language Model
A Neural Probabilistic Language Model,这篇论文是Begio等人在2003年发表的,可以说是词表示的鼻祖.在这里给出简要的译文 A Neural Probabili ...
- 论文分享|《Universal Language Model Fine-tuning for Text Classificatio》
https://www.sohu.com/a/233269391_395209 本周我们要分享的论文是<Universal Language Model Fine-tuning for Text ...
- 将迁移学习用于文本分类 《 Universal Language Model Fine-tuning for Text Classification》
将迁移学习用于文本分类 < Universal Language Model Fine-tuning for Text Classification> 2018-07-27 20:07:4 ...
随机推荐
- (Detected problems with API compatibility(visit g.co/dev/appcompat for more info)
在applicaiton里面加载这么一段代码: private void closeAndroidPDialog(){ try { Class aClass = Class.forName(" ...
- 20190404 Oracle忘记登陆密码
记忆力不好,总是忘记Oracle账号的登陆密码 修改方式 Windows cmd 登陆修改后的密码即可
- Unity3D加密保护解决方案
精锐5加密锁支持Unity3D代码及资源保护,并提供授权方案 产品简介 可使用Virbox Protector加壳工具对Unity3D代码进行加密.Unity3D使用开源mono C#语法,代码会编译 ...
- java之数据库连接池-dbcp&c3p0&dbutils
介绍 因为数据库连接对象的创建比较消耗性能,所以可以在应用程序启动时就在内存中开辟一片空间(集合)存放多个数据库连接对象,后面需要连接时直接从该空间中取而不用新创建:使用完毕后归还连接(将连接重新放回 ...
- ORACLE——存储过程
存储过程procedure 被内容来自<oracle从入门到精通——明日科技>一书 存储过程是一种命名的PL/SQL程序快,存储过程被保存在数据库中,它不可以被SQL语句直接执行或调用,只 ...
- nodejs 癞子麻将
'use strict'; var _ = require('lodash'); var quick = require('quick-pomelo'); var P = quick.Promise; ...
- 【软件工程1916|W(福州大学)_助教博客】团队第一次作业成绩公示
题目 第一次作业 评分准则: 队名(最好能够体现项目内容,要求有亮点与个性):(1分) 拟作的团队项目描述:一句话(中英文不限):(1分) 队员风采:介绍每一名队员,包括成员性格.擅长的技术.编程的兴 ...
- sitecore开发入门Sitecore的CRUD操作 - 第一部分
在本文中,讨论如何使用Sitecore.Data.Items.Item并对这些项执行CRUD(创建,读取,更新和删除)操作.我还将介绍如何使用Glass和Fortis类库进行相同的操作,这些操作都是对 ...
- 关于django1.8版本的静态文件配置
环境:Python3.5.4,django1.8.1. 在页面使用js时,总是提示404找不到js文件. 于是,看看了settings文件 好像也没什么毛病.导入的方式也换了很多种,总是不行,于是只好 ...
- java基础hashmap
Iterator中hasNext(), next() 在Iterator类中,我们经常用到两个方法: hasNext(), next(),具体含义: next(), 是返回当前元素, 并指向下一个元 ...