import sys
import codecs
import tensorflow as tf # 1.参数设置。
# 读取checkpoint的路径。9000表示是训练程序在第9000步保存的checkpoint。
CHECKPOINT_PATH = "F:\\temp\\seq2seq_ckpt-9000" # 模型参数。必须与训练时的模型参数保持一致。
HIDDEN_SIZE = 1024 # LSTM的隐藏层规模。
NUM_LAYERS = 2 # 深层循环神经网络中LSTM结构的层数。
SRC_VOCAB_SIZE = 10000 # 源语言词汇表大小。
TRG_VOCAB_SIZE = 4000 # 目标语言词汇表大小。
SHARE_EMB_AND_SOFTMAX = True # 在Softmax层和词向量层之间共享参数。 # 词汇表文件
SRC_VOCAB = "F:\\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\en.vocab"
TRG_VOCAB = "F:\\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\zh.vocab" # 词汇表中<sos>和<eos>的ID。在解码过程中需要用<sos>作为第一步的输入,并将检查
# 是否是<eos>,因此需要知道这两个符号的ID。
SOS_ID = 1
EOS_ID = 2
# 2.定义NMT模型和解码步骤。
# 定义NMTModel类来描述模型。
class NMTModel(object):
# 在模型的初始化函数中定义模型要用到的变量。
def __init__(self):
# 定义编码器和解码器所使用的LSTM结构。
self.enc_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)for _ in range(NUM_LAYERS)])
self.dec_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(NUM_LAYERS)]) # 为源语言和目标语言分别定义词向量。
self.src_embedding = tf.get_variable("src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])
self.trg_embedding = tf.get_variable("trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE]) # 定义softmax层的变量
if SHARE_EMB_AND_SOFTMAX:
self.softmax_weight = tf.transpose(self.trg_embedding)
else:
self.softmax_weight = tf.get_variable("weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE]) def inference(self, src_input):
# 虽然输入只有一个句子,但因为dynamic_rnn要求输入是batch的形式,因此这里
# 将输入句子整理为大小为1的batch。
src_size = tf.convert_to_tensor([len(src_input)], dtype=tf.int32)
src_input = tf.convert_to_tensor([src_input], dtype=tf.int32)
src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input) # 使用dynamic_rnn构造编码器。这一步与训练时相同。
with tf.variable_scope("encoder"):
enc_outputs, enc_state = tf.nn.dynamic_rnn(self.enc_cell, src_emb, src_size, dtype=tf.float32) # 设置解码的最大步数。这是为了避免在极端情况出现无限循环的问题。
MAX_DEC_LEN=100 with tf.variable_scope("decoder/rnn/multi_rnn_cell"):
# 使用一个变长的TensorArray来存储生成的句子。
init_array = tf.TensorArray(dtype=tf.int32, size=0,dynamic_size=True, clear_after_read=False)
# 填入第一个单词<sos>作为解码器的输入。
init_array = init_array.write(0, SOS_ID)
# 构建初始的循环状态。循环状态包含循环神经网络的隐藏状态,保存生成句子的
# TensorArray,以及记录解码步数的一个整数step。
init_loop_var = (enc_state, init_array, 0) # tf.while_loop的循环条件:
# 循环直到解码器输出<eos>,或者达到最大步数为止。
def continue_loop_condition(state, trg_ids, step):
return tf.reduce_all(tf.logical_and(tf.not_equal(trg_ids.read(step), EOS_ID),tf.less(step, MAX_DEC_LEN-1))) def loop_body(state, trg_ids, step):
# 读取最后一步输出的单词,并读取其词向量。
trg_input = [trg_ids.read(step)]
trg_emb = tf.nn.embedding_lookup(self.trg_embedding,trg_input)
# 这里不使用dynamic_rnn,而是直接调用dec_cell向前计算一步。
dec_outputs, next_state = self.dec_cell.call(state=state, inputs=trg_emb)
# 计算每个可能的输出单词对应的logit,并选取logit值最大的单词作为
# 这一步的而输出。
output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
logits = (tf.matmul(output, self.softmax_weight)+ self.softmax_bias)
next_id = tf.argmax(logits, axis=1, output_type=tf.int32)
# 将这一步输出的单词写入循环状态的trg_ids中。
trg_ids = trg_ids.write(step+1, next_id[0])
return next_state, trg_ids, step+1 # 执行tf.while_loop,返回最终状态。
state, trg_ids, step = tf.while_loop(continue_loop_condition, loop_body, init_loop_var)
return trg_ids.stack()
def main():
# 定义训练用的循环神经网络模型。
with tf.variable_scope("nmt_model", reuse=None):
model = NMTModel() # 定义个测试句子。
test_en_text = "This is a test . <eos>"
print(test_en_text) # 根据英文词汇表,将测试句子转为单词ID。
with codecs.open(SRC_VOCAB, "r", "utf-8") as f_vocab:
src_vocab = [w.strip() for w in f_vocab.readlines()]
src_id_dict = dict((src_vocab[x], x) for x in range(len(src_vocab)))
test_en_ids = [(src_id_dict[token] if token in src_id_dict else src_id_dict['<unk>'])
for token in test_en_text.split()]
print(test_en_ids) # 建立解码所需的计算图。
output_op = model.inference(test_en_ids)
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, CHECKPOINT_PATH) # 读取翻译结果。
output_ids = sess.run(output_op)
print(output_ids) # 根据中文词汇表,将翻译结果转换为中文文字。
with codecs.open(TRG_VOCAB, "r", "utf-8") as f_vocab:
trg_vocab = [w.strip() for w in f_vocab.readlines()]
output_text = ''.join([trg_vocab[x] for x in output_ids]) # 输出翻译结果。
print(output_text.encode('utf8').decode(sys.stdout.encoding))
sess.close() if __name__ == "__main__":
main()

吴裕雄--天生自然 pythonTensorFlow自然语言处理:Seq2Seq模型--测试的更多相关文章

  1. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:Seq2Seq模型--训练

    import tensorflow as tf # 1.参数设置. # 假设输入数据已经用9.2.1小节中的方法转换成了单词编号的格式. SRC_TRAIN_DATA = "F:\\Tens ...

  2. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:Attention模型--训练

    import tensorflow as tf # 1.参数设置. # 假设输入数据已经转换成了单词编号的格式. SRC_TRAIN_DATA = "F:\\TensorFlowGoogle ...

  3. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:Attention模型--测试

    import sys import codecs import tensorflow as tf # 1.参数设置. # 读取checkpoint的路径.9000表示是训练程序在第9000步保存的ch ...

  4. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:PTB 语言模型

    import numpy as np import tensorflow as tf # 1.设置参数. TRAIN_DATA = "F:\TensorFlowGoogle\\201806- ...

  5. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:文本数据预处理--生成训练文件

    import sys import codecs # 1. 参数设置 MODE = "PTB_TRAIN" # 将MODE设置为"PTB_TRAIN", &qu ...

  6. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:交叉熵损失函数

    import tensorflow as tf # 1. sparse_softmax_cross_entropy_with_logits样例. # 假设词汇表的大小为3, 语料包含两个单词" ...

  7. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:循环神经网络预测正弦函数

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 定义RNN的参数. HIDDEN_SIZE = ...

  8. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:数据集高层操作

    import tempfile import tensorflow as tf # 1. 列举输入文件. # 输入数据生成的训练和测试数据. train_files = tf.train.match_ ...

  9. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:数据集基本使用方法

    import tempfile import tensorflow as tf # 1. 从数组创建数据集. input_data = [1, 2, 3, 5, 8] dataset = tf.dat ...

随机推荐

  1. PHP实现简单的双色球机选号码

    <?php header('Content-Type: text/html; charset=utf-8'); //PHP实现双色球机选号码 $red = range(1, 33);//初次设定 ...

  2. 使用 prototype 定义方法和属性

    除了可以在类的构造器方法中定义方法和属性外,也可以使用 prototype 定义方法和属性.每个类都有这个属性,该属性是一个静态属性,因此无需实例化,只需使用类引用该属性即可. 1.1 使用 prot ...

  3. python counter、闭包、generator、解数学方程、异常

    1.counter 2.闭包 3.generator 4.解数学方程 5.异常 1.python库——counter from collections import Counter breakfast ...

  4. ORACLE增删改查以及case when的基本用法

    1.创建table create table test01( id int not null primary key, name ) not null, gender ) not null, age ...

  5. Html 常见标签,css基础样式,css的浮动和清流,浏览器的兼容

    1.html模板<!DOCTYPE html><html><head> <meta charset="UTF-8"> <tit ...

  6. php优惠券生成-去重

    记录一次优惠券生成-去重 方法一 /** * 生成批量礼品消费券 */ public function giftCardAddOp() { //接收get值 $num = $_GET['gift_nu ...

  7. CSS3新特性—过渡、转换

    过渡 转换 2D转换 2D转换包括四个方面:位移,缩放,旋转,倾斜 位移[让元素移动位置] transform: translate(100px,100px); 备注: 1. 如果只设置一个值,那么代 ...

  8. 吴裕雄--天生自然 JAVASCRIPT开发学习: JSON

    <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...

  9. vscode 集成git bash, mingw, mintty 的terminal

    设置 右上角打开json文件的设置 输入以下代码: "terminal.external.windowsExec": "D:\\Program Files\\Git\\b ...

  10. MySQL--InnoDB 体系架构

    InnoDB 体系架构 后台线程 Master Thread Master Thread 是一个非常核心的后台线程,主要负责将缓冲池中的数据异步刷新到磁盘,保证数据的一致性,包括脏页的刷新.合并插入缓 ...