Tensorflow --BeamSearch
github:https://github.com/zle1992/Seq2Seq-Chatbot
1、 注意在infer阶段,需要需要reuse,
2、If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:
- The encoder output has been tiled to
beam_widthviatf.contrib.seq2seq.tile_batch(NOTtf.tile). - The
batch_sizeargument passed to thezero_statemethod of this wrapper is equal totrue_batch_size * beam_width. - The initial state created with
zero_stateabove contains acell_statevalue containing properly tiled final state from the encoder.
import tensorflow as tf
from tensorflow.python.layers.core import Dense BEAM_WIDTH = 5
BATCH_SIZE = 128 # INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE]) # ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
inputs = tf.contrib.layers.embed_sequence(X, 10000, 128),
sequence_length = X_seq_len,
dtype = tf.float32) # DECODER COMPONENTS
Y_vocab_size = 10000
decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size) # ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
num_units = 128,
memory = encoder_out,
memory_sequence_length = X_seq_len) decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
attention_mechanism = attention_mechanism,
attention_layer_size = 128) # DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
inputs = tf.nn.embedding_lookup(decoder_embedding, Y),
sequence_length = Y_seq_len,
time_major = False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
cell = decoder_cell,
helper = training_helper,
initial_state = decoder_cell.zero_state(BATCH_SIZE,tf.float32).clone(cell_state=encoder_state),
output_layer = projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
decoder = training_decoder,
impute_finished = True,
maximum_iterations = tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output # BEAM SEARCH TILE
encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=BEAM_WIDTH) # ATTENTION (PREDICTING)
with tf.variable_scope('shared_attention_mechanism', reuse=True):
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
num_units = 128,
memory = encoder_out,
memory_sequence_length = X_seq_len) decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
attention_mechanism = attention_mechanism,
attention_layer_size = 128) # DECODER (PREDICTING)
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
cell = decoder_cell,
embedding = decoder_embedding,
start_tokens = tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
end_token = 2,
initial_state = decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH,tf.float32).clone(cell_state=encoder_state),
beam_width = BEAM_WIDTH,
output_layer = projection_layer,
length_penalty_weight = 0.0)
with tf.variable_scope('decode_with_shared_attention', reuse=True):
predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
decoder = predicting_decoder,
impute_finished = False,
maximum_iterations = 2 * tf.reduce_max(Y_seq_len))
predicting_logits = predicting_decoder_output.predicted_ids[:, :, 0] print('successful')
参考:
https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder
https://github.com/tensorflow/nmt#beam-search
Tensorflow --BeamSearch的更多相关文章
- tensorflow 笔记13:了解机器翻译,google NMT,Attention
一.关于Attention,关于NMT 未完待续... 以google 的 nmt 代码引入 探讨下端到端: 项目地址:https://github.com/tensorflow/nmt 机器翻译算是 ...
- Effective Tensorflow[转]
Effective TensorFlow Table of Contents TensorFlow Basics Understanding static and dynamic shapes Sco ...
- Tensorflow 官方版教程中文版
2015年11月9日,Google发布人工智能系统TensorFlow并宣布开源,同日,极客学院组织在线TensorFlow中文文档翻译.一个月后,30章文档全部翻译校对完成,上线并提供电子书下载,该 ...
- tensorflow学习笔记二:入门基础
TensorFlow用张量这种数据结构来表示所有的数据.用一阶张量来表示向量,如:v = [1.2, 2.3, 3.5] ,如二阶张量表示矩阵,如:m = [[1, 2, 3], [4, 5, 6], ...
- 用Tensorflow让神经网络自动创造音乐
#————————————————————————本文禁止转载,禁止用于各类讲座及ppt中,违者必究————————————————————————# 前几天看到一个有意思的分享,大意是讲如何用Ten ...
- tensorflow 一些好的blog链接和tensorflow gpu版本安装
pading :SAME,VALID 区别 http://blog.csdn.net/mao_xiao_feng/article/details/53444333 tensorflow实现的各种算法 ...
- tensorflow中的基本概念
本文是在阅读官方文档后的一些个人理解. 官方文档地址:https://www.tensorflow.org/versions/r0.12/get_started/basic_usage.html#ba ...
- kubernetes&tensorflow
谷歌内部--Borg Google Brain跑在数十万台机器上 谷歌电商商品分类深度学习模型跑在1000+台机器上 谷歌外部--Kubernetes(https://github.com/kuber ...
- tensorflow学习
tensorflow安装时遇到gcc: error trying to exec 'as': execvp: No such file or directory. 截止到2016年11月13号,源码编 ...
随机推荐
- asp.net 页面生命周期事件详细
(1)请求页面:页请求发生在页生命周期开始之前. (2)开始:在开始阶段,将设置页属性,如Request和Response.在此阶段,页还将确定请求是回发请求还是新请求,并设置IsPostBack属性 ...
- viewport的故事(一)
部分翻译 自原文 https://www.quirksmode.org/mobile/viewports.html 概念:设备像素和CSS像素 设备像素可以通过 screen.width/he ...
- oldboy es和logstash
logstash: input:https://www.elastic.co/guide/en/logstash/current/input-plugins.html input { file { p ...
- 团队项目-课程MS需求分析心得
我们的课程管理项目需求讲道理其实应该是比较简单的,但是在经过几次和老师讨论过后,项目需求已经多得让人脑门疼,后来继续跟老师聊,老师嘴上说着减减减,但是每次讨论下来需求还是会变得更多,以致于个人已经不再 ...
- java.text.DateFormat 线程不安全问题
java.text下的 DateFormat 是线程不安全的: 建议1: 1.使用threadLocal包装DateFormat(太复杂,不推荐) 2.使用org.apache.commons.lan ...
- 我了解到的新知识之----如何使用Python获取最新外汇汇率信息
这个需求本来是来源于公司同事工作中需求,用户需要使用数据分析工具Power BI抓取多页的中国银行官网上当天的外汇数据.但是没能研究出来. 我就开始在网络上找关于使用python来抓取当天汇率的案例分 ...
- Oracle JDK 1.8 openJDK11 定制化JDK
小结: 1. https://mp.weixin.qq.com/s/4rkgisFRJxokXZ4lyFXujw 京东JDK在大数据平台的探索与研究 臧琳 亿级流量网站架构 3月11日
- kubernetes in action - Volumes
Volume解决Kubernetes的存储的问题 对于Pod使用的存储,抽象为volume,volume伴随着Pod的创建而创建,消失而同时消失,不能单独的创建 这样的好处,是存储的塑胶不会因为某个c ...
- js基础--数据类型
1,数据类型 Number:包括小数与整数,负数,NaN ,Infinity无限大String字符串:‘abc’Boolean布尔值:true or falsenull 空undefined 未定义 ...
- react将字符串转义成html语句
在使用reactjs库的时候,会遇到将一段html的字符串,然后要将它插入页面中以html的形式展现,然而直接插入的话页面显示的就是这段字符串,而不会进行转义,可以用以下方法插入,便可以html的形式 ...