github:https://github.com/zle1992/Seq2Seq-Chatbot

1、 注意在infer阶段,需要需要reuse,

2、If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:

  • The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile).
  • The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
  • The initial state created with zero_state above contains a cell_state value containing properly tiled final state from the encoder.
 import tensorflow as tf
from tensorflow.python.layers.core import Dense BEAM_WIDTH = 5
BATCH_SIZE = 128 # INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE]) # ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
inputs = tf.contrib.layers.embed_sequence(X, 10000, 128),
sequence_length = X_seq_len,
dtype = tf.float32) # DECODER COMPONENTS
Y_vocab_size = 10000
decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size) # ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
num_units = 128,
memory = encoder_out,
memory_sequence_length = X_seq_len) decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
attention_mechanism = attention_mechanism,
attention_layer_size = 128) # DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
inputs = tf.nn.embedding_lookup(decoder_embedding, Y),
sequence_length = Y_seq_len,
time_major = False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
cell = decoder_cell,
helper = training_helper,
initial_state = decoder_cell.zero_state(BATCH_SIZE,tf.float32).clone(cell_state=encoder_state),
output_layer = projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
decoder = training_decoder,
impute_finished = True,
maximum_iterations = tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output # BEAM SEARCH TILE
encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=
BEAM_WIDTH) # ATTENTION (PREDICTING)
with tf.variable_scope('shared_attention_mechanism', reuse=True):
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
num_units = 128,
memory = encoder_out,
memory_sequence_length = X_seq_len) decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
cell = tf.nn.rnn_cell.BasicLSTMCell(128),
attention_mechanism = attention_mechanism,
attention_layer_size = 128) # DECODER (PREDICTING)
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
cell = decoder_cell,
embedding = decoder_embedding,
start_tokens = tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
end_token = 2,
initial_state = decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH,tf.float32).clone(cell_state=encoder_state),
beam_width = BEAM_WIDTH,
output_layer = projection_layer,
length_penalty_weight = 0.0)
with tf.variable_scope('decode_with_shared_attention', reuse=True):
predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
decoder = predicting_decoder,
impute_finished = False,
maximum_iterations = 2 * tf.reduce_max(Y_seq_len))
predicting_logits = predicting_decoder_output.predicted_ids[:, :, 0] print('successful')

参考:

https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084

https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder

https://github.com/tensorflow/nmt#beam-search

Tensorflow --BeamSearch的更多相关文章

  1. tensorflow 笔记13:了解机器翻译,google NMT,Attention

    一.关于Attention,关于NMT 未完待续... 以google 的 nmt 代码引入 探讨下端到端: 项目地址:https://github.com/tensorflow/nmt 机器翻译算是 ...

  2. Effective Tensorflow[转]

    Effective TensorFlow Table of Contents TensorFlow Basics Understanding static and dynamic shapes Sco ...

  3. Tensorflow 官方版教程中文版

    2015年11月9日,Google发布人工智能系统TensorFlow并宣布开源,同日,极客学院组织在线TensorFlow中文文档翻译.一个月后,30章文档全部翻译校对完成,上线并提供电子书下载,该 ...

  4. tensorflow学习笔记二:入门基础

    TensorFlow用张量这种数据结构来表示所有的数据.用一阶张量来表示向量,如:v = [1.2, 2.3, 3.5] ,如二阶张量表示矩阵,如:m = [[1, 2, 3], [4, 5, 6], ...

  5. 用Tensorflow让神经网络自动创造音乐

    #————————————————————————本文禁止转载,禁止用于各类讲座及ppt中,违者必究————————————————————————# 前几天看到一个有意思的分享,大意是讲如何用Ten ...

  6. tensorflow 一些好的blog链接和tensorflow gpu版本安装

    pading :SAME,VALID 区别  http://blog.csdn.net/mao_xiao_feng/article/details/53444333 tensorflow实现的各种算法 ...

  7. tensorflow中的基本概念

    本文是在阅读官方文档后的一些个人理解. 官方文档地址:https://www.tensorflow.org/versions/r0.12/get_started/basic_usage.html#ba ...

  8. kubernetes&tensorflow

    谷歌内部--Borg Google Brain跑在数十万台机器上 谷歌电商商品分类深度学习模型跑在1000+台机器上 谷歌外部--Kubernetes(https://github.com/kuber ...

  9. tensorflow学习

    tensorflow安装时遇到gcc: error trying to exec 'as': execvp: No such file or directory. 截止到2016年11月13号,源码编 ...

随机推荐

  1. 2017-2018 ACM-ICPC Nordic Collegiate Programming Contest (NCPC 2017)

    A. Airport Coffee 设$f_i$表示考虑前$i$个咖啡厅,且在$i$处买咖啡的最小时间,通过单调队列优化转移. 时间复杂度$O(n)$. #include<cstdio> ...

  2. python学习:条件语句if、else

    条件语句: 1.if...else...; 2.if...elif...esle 举例: 1.if...else... “age_of_princal = 56   guess_age = int(i ...

  3. Tornado-Ajax

    介绍 AJAX = Asynchronous JavaScript and XML(异步的 JavaScript 和 XML).AJAX 不是新的编程语言,而是一种使用现有标准的新方法.AJAX是在不 ...

  4. react_app 项目开发 (5)_前后端分离_后台管理系统_开始

    项目描述 技术选型 react API 接口 接口文档,url,请求方式,参数类型, 根据文档描述的方法,进行 postman 测试,看是否能够得到理想的结果 collections - 创建文件取项 ...

  5. Spring 依赖注入中 Field 注入的有害性

    大致分为:Field 注入.构造注入.setter 注入 其中 Field 注入被认为有害的: 1. 违反了单一原则 当一个 class 中有多个依赖时,如果仅仅使用 Field 注入,则看不出有很多 ...

  6. Viterbi algorithm

    HMM(隐马尔可夫模型)是用来描述隐含未知参数的统计模型,是一个关于时序的概率模型,它描述了一个由隐藏的马尔可夫链生成状态序列,再由状态序列生成观测序列的过程.其中,状态之间的转换以及观测序列和状态序 ...

  7. LeetCode 538 Convert BST to Greater Tree 解题报告

    题目要求 Given a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the origi ...

  8. WSL(Windows Subsystem for Linux)笔记一安装与使用

    1.安装linux子系统 很简单直接在启动或关闭windows功能 中选择“适用于linux的windows子系统”,确定安装后重启即可,安装还是比较快的只用了几分钟. 也可以直接使用shell命令行 ...

  9. netstat lsof ps 常用场景

    1.netstat 命令可以帮助检查本机的网络状况实战应用1:公司内部的一个老服务运行在192.168.1.1:50060上,服务将于一周之后停用,再在要查一下本机上有没有进行在调用该服务.[root ...

  10. Java注解--笔记

    @Override标签的作用@Override是伪代码,所以是可写可不写的.它表示方法重写,写上会给我们带来好处. 1.可以当注释用,方便阅读. 2.告诉阅读你代码的人,这是方法的复写. 3.编译器可 ...