Tensorflow-hub[例子解析2]
3 基于文本词向量的例子
3.1 创建Module
可以从Tensorflow-hub[例子解析1].中看出,hub相对之前减少了更多的工作量。
首先,假设有词向量文本文件
token1 1.0 2.0 3.0 4.0 5.0
token2 2.0 3.0 4.0 5.0 6.0
该例子就是通过读取该文件去生成TF-Hub Module,可以使用如下命令:
python export.py --embedding_file=/tmp/embedding.txt --export_path=/tmp/module
下面就是export.py的源码,通过跟踪代码中以序号进行注释的部分,可以得知Module的操作过程。
# 惯例导入需要的模块
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import shutil
import sys
import tempfile
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
FLAGS = None
EMBEDDINGS_VAR_NAME = "embeddings"
def parse_line(line):
"""该函数是为了解析./tmp/embedding.txt文件的每一行
Args:
line: (str) One line of the text embedding file.
Returns:
A token string and its embedding vector in floats.
"""
columns = line.split()
token = columns.pop(0)
values = [float(column) for column in columns]
return token, values
def load(file_path, parse_line_fn):
"""该函数是为了将/tmp/embedding.txt解析为numpy对象,并保存在内存中.
Args:
file_path: Path to the text embedding file.
parse_line_fn: callback function to parse each file line.
Returns:
A tuple of (list of vocabulary tokens, numpy matrix of embedding vectors).
Raises:
ValueError: if the data in the sstable is inconsistent.
"""
vocabulary = []
embeddings = []
embeddings_dim = None
for line in tf.gfile.GFile(file_path):
token, embedding = parse_line_fn(line)
if not embeddings_dim:
embeddings_dim = len(embedding)
elif embeddings_dim != len(embedding):
raise ValueError(
"Inconsistent embedding dimension detected, %d != %d for token %s",
embeddings_dim, len(embedding), token)
vocabulary.append(token)
embeddings.append(embedding)
return vocabulary, np.array(embeddings)
''' 该函数展示了如何使用Module '''
def make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
num_oov_buckets, preprocess_text):
"""Makes a module spec to simply perform token to embedding lookups.
Input of this module is a 1-D list of string tokens. For T tokens input and
an M dimensional embedding table, the lookup result is a [T, M] shaped Tensor.
Args:
vocabulary_file: Text file where each line is a key in the vocabulary.
vocab_size: The number of tokens contained in the vocabulary.
embeddings_dim: The embedding dimension.
num_oov_buckets: The number of out-of-vocabulary buckets.
preprocess_text: Whether to preprocess the input tensor by removing
punctuation and splitting on spaces.
Returns:
A module spec object used for constructing a TF-Hub module.
"""
''' 1 - 先创建函数module_fn:
通过tf.placeholder作为输入端占位符并构建整个graph;
调用hub.add_signature()执行类似注册操作'''
def module_fn():
"""Spec function for a token embedding module."""
tokens = tf.placeholder(shape=[None], dtype=tf.string, name="tokens")
embeddings_var = tf.get_variable(
initializer=tf.zeros([vocab_size + num_oov_buckets, embeddings_dim]),
name=EMBEDDINGS_VAR_NAME,
dtype=tf.float32)
lookup_table = tf.contrib.lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=num_oov_buckets,
)
ids = lookup_table.lookup(tokens)
combined_embedding = tf.nn.embedding_lookup(params=embeddings_var, ids=ids)
hub.add_signature("default", {"tokens": tokens},
{"default": combined_embedding})
''' 1 - 这个函数如上面的module_fn是互斥的:
通过tf.placeholder作为输入端占位符并构建整个graph;
调用hub.add_signature()执行类似注册操作 '''
def module_fn_with_preprocessing():
"""Spec function for a full-text embedding module with preprocessing."""
sentences = tf.placeholder(shape=[None], dtype=tf.string, name="sentences")
# Perform a minimalistic text preprocessing by removing punctuation and
# splitting on spaces.
normalized_sentences = tf.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
tokens = tf.string_split(normalized_sentences, " ")
# In case some of the input sentences are empty before or after
# normalization, we will end up with empty rows. We do however want to
# return embedding for every row, so we have to fill in the empty rows with
# a default.
tokens, _ = tf.sparse_fill_empty_rows(tokens, "")
# In case all of the input sentences are empty before or after
# normalization, we will end up with a SparseTensor with shape [?, 0]. After
# filling in the empty rows we must ensure the shape is set properly to
# [?, 1].
tokens = tf.sparse_reset_shape(tokens)
embeddings_var = tf.get_variable(
initializer=tf.zeros([vocab_size + num_oov_buckets, embeddings_dim]),
name=EMBEDDINGS_VAR_NAME,
dtype=tf.float32)
lookup_table = tf.contrib.lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=num_oov_buckets,
)
sparse_ids = tf.SparseTensor(
indices=tokens.indices,
values=lookup_table.lookup(tokens.values),
dense_shape=tokens.dense_shape)
combined_embedding = tf.nn.embedding_lookup_sparse(
params=embeddings_var,
sp_ids=sparse_ids,
sp_weights=None,
combiner="sqrtn")
hub.add_signature("default", {"sentences": sentences},
{"default": combined_embedding})
''' 2 - 通过调用hub.create_module_spec()创建ModuleSpec对象 '''
if preprocess_text:
return hub.create_module_spec(module_fn_with_preprocessing)
else:
return hub.create_module_spec(module_fn)
def export(export_path, vocabulary, embeddings, num_oov_buckets,
preprocess_text):
"""Exports a TF-Hub module that performs embedding lookups.
Args:
export_path: Location to export the module.
vocabulary: List of the N tokens in the vocabulary.
embeddings: Numpy array of shape [N+K,M] the first N rows are the
M dimensional embeddings for the respective tokens and the next K
rows are for the K out-of-vocabulary buckets.
num_oov_buckets: How many out-of-vocabulary buckets to add.
preprocess_text: Whether to preprocess the input tensor by removing
punctuation and splitting on spaces.
"""
# Write temporary vocab file for module construction.
tmpdir = tempfile.mkdtemp()
vocabulary_file = os.path.join(tmpdir, "tokens.txt")
with tf.gfile.GFile(vocabulary_file, "w") as f:
f.write("\n".join(vocabulary))
vocab_size = len(vocabulary)
embeddings_dim = embeddings.shape[1]
spec = make_module_spec(vocabulary_file, vocab_size, embeddings_dim,
num_oov_buckets, preprocess_text)
try:
''' 3 - 建立tf.Graph(),并使用hub.Module(spec)进行如y=f(x)的操作'''
with tf.Graph().as_default():
m = hub.Module(spec)
# The embeddings may be very large (e.g., larger than the 2GB serialized
# Tensor limit). To avoid having them frozen as constant Tensors in the
# graph we instead assign them through the placeholders and feed_dict
# mechanism.
p_embeddings = tf.placeholder(tf.float32)
load_embeddings = tf.assign(m.variable_map[EMBEDDINGS_VAR_NAME],
p_embeddings)
''' 4 - 建立Session(),进行初始化,训练,迭代等正常操作;最后通过调用module.export(export_path,sess)导出Module'''
with tf.Session() as sess:
sess.run([load_embeddings], feed_dict={p_embeddings: embeddings})
m.export(export_path, sess)
finally:
shutil.rmtree(tmpdir)
def maybe_append_oov_vectors(embeddings, num_oov_buckets):
"""Adds zero vectors for oov buckets if num_oov_buckets > 0.
Since we are assigning zero vectors, adding more that one oov bucket is only
meaningful if we perform fine-tuning.
Args:
embeddings: Embeddings to extend.
num_oov_buckets: Number of OOV buckets in the extended embedding.
"""
num_embeddings = np.shape(embeddings)[0]
embedding_dim = np.shape(embeddings)[1]
embeddings.resize(
[num_embeddings + num_oov_buckets, embedding_dim], refcheck=False)
def export_module_from_file(embedding_file, export_path, parse_line_fn,
num_oov_buckets, preprocess_text):
# Load pretrained embeddings into memory.
vocabulary, embeddings = load(embedding_file, parse_line_fn)
# Add OOV buckets if num_oov_buckets > 0.
maybe_append_oov_vectors(embeddings, num_oov_buckets)
# Export the embedding vectors into a TF-Hub module.
export(export_path, vocabulary, embeddings, num_oov_buckets, preprocess_text)
def main(_):
export_module_from_file(FLAGS.embedding_file, FLAGS.export_path, parse_line,
FLAGS.num_oov_buckets, FLAGS.preprocess_text)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--embedding_file",
type=str,
default=None,
help="Path to file with embeddings.")
parser.add_argument(
"--export_path",
type=str,
default=None,
help="Where to export the module.")
parser.add_argument(
"--preprocess_text",
type=bool,
default=False,
help="Whether to preprocess the input tensor by removing punctuation and "
"splitting on spaces. Use this if input is a dense tensor of untokenized "
"sentences.")
parser.add_argument(
"--num_oov_buckets",
type=int,
default="1",
help="How many OOV buckets to add.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
从上面创建的例子可以看出,该操作过程与Tensorflow-hub[例子解析1].相似
3.2 使用Module
下面就是使用创建好的Module的代码,这里用了几个test进行测试,通过跟踪下面的序号的注释,可以看出使用也是相当简单
# 导入所需要的模块
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import export
_MOCK_EMBEDDING = "\n".join(
["cat 1.11 2.56 3.45", "dog 1 2 3", "mouse 0.5 0.1 0.6"])
class ExportTokenEmbeddingTest(tf.test.TestCase):
def setUp(self):
self._embedding_file_path = os.path.join(self.get_temp_dir(),
"mock_embedding_file.txt")
with tf.gfile.GFile(self._embedding_file_path, mode="w") as f:
f.write(_MOCK_EMBEDDING)
def testEmbeddingLoaded(self):
vocabulary, embeddings = export.load(self._embedding_file_path,
export.parse_line)
self.assertEqual((3,), np.shape(vocabulary))
self.assertEqual((3, 3), np.shape(embeddings))
def testExportTokenEmbeddingModule(self):
''' 1 - 先调用生成Module的代码,生成一个Module'''
export.export_module_from_file(
embedding_file=self._embedding_file_path,
export_path=self.get_temp_dir(),
parse_line_fn=export.parse_line,
num_oov_buckets=1,
preprocess_text=False)
''' 2 - 创建一个tf.Graph():
调用hub.Module装载Module;
创建tf.Session()进行初始化,和如y=f(x)进行计算得到结果'''
with tf.Graph().as_default():
hub_module = hub.Module(self.get_temp_dir())
tokens = tf.constant(["cat", "lizard", "dog"])
embeddings = hub_module(tokens)
with tf.Session() as session:
session.run(tf.tables_initializer())
session.run(tf.global_variables_initializer())
self.assertAllClose(
session.run(embeddings),
[[1.11, 2.56, 3.45], [0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
def testExportFulltextEmbeddingModule(self):
''' 1 - 先调用生成Module的代码,生成一个Module'''
export.export_module_from_file(
embedding_file=self._embedding_file_path,
export_path=self.get_temp_dir(),
parse_line_fn=export.parse_line,
num_oov_buckets=1,
preprocess_text=True)
''' 2 - 创建一个tf.Graph():
调用hub.Module装载Module;
创建tf.Session()进行初始化,和如y=f(x)进行计算得到结果'''
with tf.Graph().as_default():
hub_module = hub.Module(self.get_temp_dir())
tokens = tf.constant(["cat", "cat cat", "lizard. dog", "cat? dog", ""])
embeddings = hub_module(tokens)
with tf.Session() as session:
session.run(tf.tables_initializer())
session.run(tf.global_variables_initializer())
self.assertAllClose(
session.run(embeddings),
[[1.11, 2.56, 3.45], [1.57, 3.62, 4.88], [0.70, 1.41, 2.12],
[1.49, 3.22, 4.56], [0.0, 0.0, 0.0]],
rtol=0.02)
if __name__ == "__main__":
tf.test.main()
Tensorflow-hub[例子解析2]的更多相关文章
- Tensorflow-hub[例子解析1]
0. 引言 Tensorflow于1.7之后推出了tensorflow hub,其是一个适合于迁移学习的部分,主要通过将tensorflow的训练好的模型进行模块划分,并可以再次加以利用.不过介于推出 ...
- Poco库网络模块例子解析1-------字典查询
Poco的网络模块在Poco::Net名字空间下定义 下面是字典例子解析 #include "Poco/Net/StreamSocket.h" //流式套接字 #include ...
- 如何使用TensorFlow Hub和代码示例
任何深度学习框架,为了获得成功,必须提供一系列最先进的模型,以及在流行和广泛接受的数据集上训练的权重,即与训练模型. TensorFlow现在已经提出了一个更好的框架,称为TensorFlow Hub ...
- Java字节码例子解析
举个简单的例子: public class Hello { public static void main(String[] args) { String string1 = ...
- Tensorflow之MNIST解析
要说2017年什么技术最火爆,无疑是google领衔的深度学习开源框架Tensorflow.本文简述一下深度学习的入门例子MNIST. 深度学习简单介绍 首先要简单区别几个概念:人工智能,机器学习,深 ...
- tensorflow源码解析之distributed_runtime
本篇主要介绍TF的分布式运行时的基本概念.为了对TF的分布式运行机制有一个大致的了解,我们先结合/tensorflow/core/protobuf中的文件给出对TF分布式集群的初步理解,然后介绍/te ...
- Poco C++库网络模块例子解析2-------HttpServer
//下面程序取自 Poco 库的Net模块例子----HTTPServer 下面开始解析代码 #include "Poco/Net/HTTPServer.h" //继承自TCPSe ...
- Tensorflow ActiveFunction激活函数解析
Active Function 激活函数 原创文章,请勿转载哦~!! 觉得有用的话,欢迎一起讨论相互学习~Follow Me Tensorflow提供了多种激活函数,在CNN中,人们主要是用tf.nn ...
- tensorflow finuetuning 例子
最近研究了下如何使用tensorflow进行finetuning,相比于caffe,tensorflow的finetuning麻烦一些,记录如下: 1.原理 finetuning原理很简单,利用一个在 ...
随机推荐
- JS之用ES6 Promise解决回调地狱(这里以小程序为例)
首先 写一个请求的方法,如: /** * 银行窗口 * 你需要给我提供相关的相关参数我帮你提交到服务器上 * 我会给你一个等待区的编号给你 你去等待区等待,我处理完成会去等待区通知你 * @param ...
- VSCode中怎么改变文件夹的图标
昨天更新了VSCode后我的文件夹图标莫名其妙的没有了,变成了下图这样 看着真的让我难受的头皮发麻,本来打代码就头发少,难道非要让我变成秃头,不可能不可能,所以我找了找怎么解决 来,各位看官上眼 如图 ...
- 【代码笔记】Web-HTML-表格
一,效果图. 二,代码. <!DOCTYPE html> <html> <head> <meta charset="utf-8"> ...
- docker 搭建maven 私服
# 搜索镜像 docker search nexus; #拉取nexus镜像docker pull sonatype/nexus; #运行 -id 创建守护式容器--privileged=true 授 ...
- 《Inside C#》笔记(十五) 非托管代码 上
为了保证向后兼容性,C#和.NET可以通过非托管的方式运行旧代码.非托管代码是指没有被.NET运行时管控的代码.非托管代码主要包括:平台调用服务(PlatformInvocation Services ...
- Fiddler抓包使用教程-模拟低速网络环境
转载请标明出处:http://blog.csdn.net/zhaoyanjun6/article/details/73467267 本文出自[赵彦军的博客] 在无线测试中,网络测试是必不可少的环节,通 ...
- matlab练习程序(FAST特征点检测)
算法思想:如果一个像素与它邻域的像素差别较大(过亮或过暗) , 那它更可能是角点. 算法步骤: 1.上图所示,一个以像素p为中心,半径为3的圆上,有16个像素点(p1.p2.....p16). 2.定 ...
- maven管理项目的特点
Maven介绍 我们在开发项目的过程中,会使用一些开源框架.第三方的工具等等,这些都是以jar包的方式被项目所引用,并且有些jar包还会依赖其他的jar包,我们同样需要添加到项目中,所有这些相关的ja ...
- vue-cli在控制台创建vue项目时乱码的问题
新装的win10系统,使用vue-cli在控制台创建项目时出现乱码,请问如何处理? 解决: 打开cmd,在控制台输入CHCP 65001,按回车键即可将编码格式设成utf-8,再创建就不会乱码了. 执 ...
- 【PAT】B1080 MOOC期终成绩(25 分)
还是c++好用,三部分输入直接用相同的方法, 用map映射保存学生在结构体数组中的下标. 结构体保存学生信息,其中期末成绩直接初始化为-1, 注意四舍五入 此题还算简单 #include<ios ...