Source code URL: https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
There is a detailed, illustrated walkthrough of this code on Jianshu; you can read that article directly: http://www.jianshu.com/p/f682066f0586

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic word2vec example.""" from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import collections
import math
import os
import random
import zipfile import numpy as np
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf # Step 1: Download the data.
url = 'http://mattmahoney.net/dc/' def maybe_download(filename, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
filename, _ = urllib.request.urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', filename)
else:
print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename filename = maybe_download('text8.zip', 31344016) # Read the data into a list of strings.
def read_data(filename):
  """Extract the first file enclosed in a zip file as a list of words."""
  with zipfile.ZipFile(filename) as f:
    data = tf.compat.as_str(f.read(f.namelist()[0])).split()
  return data

vocabulary = read_data(filename)
print('Data size', len(vocabulary))

# Step 2: Build the dictionary and replace rare words with UNK token.
vocabulary_size = 50000

'''
build_dataset
input:
  words   - the original word list
  n_words - the number of words to keep in the vocabulary

output:
  data  - a list with the same length as the input words; each element is the
          index of the corresponding word in dictionary (equivalently, its row
          in count)
  count - a list of n_words rows with two columns: the first column is the
          word, the second column is its frequency in the input words; the
          first row is ['UNK', *] and the remaining rows are sorted in
          descending order of the second column
  dictionary - a key-value map whose key is the word and whose value is its
               position in count / dictionary
  reversed_dictionary - dictionary with keys and values swapped

(A toy example is given after the function definition below.)
'''

def build_dataset(words, n_words):
  """Process raw inputs into a dataset."""
  count = [['UNK', -1]]
  count.extend(collections.Counter(words).most_common(n_words - 1))
  dictionary = dict()
  for word, _ in count:
    dictionary[word] = len(dictionary)
  data = list()
  unk_count = 0
  for word in words:
    if word in dictionary:
      index = dictionary[word]
    else:
      index = 0  # dictionary['UNK']
      unk_count += 1
    data.append(index)
  count[0][1] = unk_count
  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
  return data, count, dictionary, reversed_dictionary

data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
                                                            vocabulary_size)
del vocabulary  # Hint to reduce memory.
print('Most common words (+UNK)', count[:5])
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])

data_index = 0
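
'''
A quick illustration (a minimal sketch, not part of the original script; the
toy word list below is an assumption): calling build_dataset on a tiny corpus
makes the four return values concrete.
'''
toy_words = ['the', 'cat', 'sat', 'on', 'the', 'mat', 'the', 'cat']
toy_data, toy_count, toy_dict, toy_rev = build_dataset(toy_words, 3)
# toy_count -> [['UNK', 3], ('the', 3), ('cat', 2)]   (the 3 rare words collapse into UNK)
# toy_dict  -> {'UNK': 0, 'the': 1, 'cat': 2}
# toy_data  -> [1, 2, 0, 0, 1, 0, 1, 2]               (each word replaced by its dictionary index)
# toy_rev   -> {0: 'UNK', 1: 'the', 2: 'cat'}
print(toy_count, toy_data)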

'''
generate_batch converts data into a training batch and its labels;
the values in batch and labels are the dictionary indices of the
corresponding words.
'''
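
'''
For intuition, here is a minimal self-contained sketch (not part of the
original script; the toy id sequence is an assumption) of the skip-gram
pairing that generate_batch performs for skip_window=1 and num_skips=2:
every center word is paired with each of its two immediate neighbours.
'''
demo_ids = [10, 11, 12, 13, 14, 15]         # pretend dictionary indices of consecutive words
demo_pairs = []
for center in range(1, len(demo_ids) - 1):  # centers that have a neighbour on both sides
  for context in (center - 1, center + 1):  # the two positions inside the window
    demo_pairs.append((demo_ids[center], demo_ids[context]))
print(demo_pairs)
# -> [(11, 10), (11, 12), (12, 11), (12, 13), (13, 12), (13, 14), (14, 13), (14, 15)]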

# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
  global data_index
  assert batch_size % num_skips == 0
  assert num_skips <= 2 * skip_window
  batch = np.ndarray(shape=(batch_size), dtype=np.int32)
  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
  span = 2 * skip_window + 1  # [ skip_window target skip_window ]
  buffer = collections.deque(maxlen=span)
  for _ in range(span):
    buffer.append(data[data_index])
    data_index = (data_index + 1) % len(data)
  for i in range(batch_size // num_skips):
    target = skip_window  # target label at the center of the buffer
    targets_to_avoid = [skip_window]
    for j in range(num_skips):
      while target in targets_to_avoid:
        target = random.randint(0, span - 1)
      targets_to_avoid.append(target)
      batch[i * num_skips + j] = buffer[skip_window]
      labels[i * num_skips + j, 0] = buffer[target]
    buffer.append(data[data_index])
    data_index = (data_index + 1) % len(data)
  # Backtrack a little bit to avoid skipping words in the end of a batch
  data_index = (data_index + len(data) - span) % len(data)
  return batch, labels

batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
  print(batch[i], reverse_dictionary[batch[i]],
        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])

# Step 4: Build and train a skip-gram model.

batch_size = 128
embedding_size = 128 # Dimension of the embedding vector.
skip_window = 1 # How many words to consider left and right.
num_skips = 2         # How many times to reuse an input to generate a label.

# We pick a random validation set to sample nearest neighbors. Here we limit the
# validation samples to the words that have a low numeric ID, which by
# construction are also the most frequent.
valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
num_sampled = 64    # Number of negative examples to sample.

graph = tf.Graph()

with graph.as_default():

  # Input data.
  train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
  train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # Ops and variables pinned to the CPU because of missing GPU implementation
  with tf.device('/cpu:0'):

    '''
    Initialize the embeddings with uniformly random values:
    the number of rows equals the vocabulary size,
    the number of columns equals the dimension of the embedding vector,
    and row i holds the vector for the word with index i in count / dictionary.
    embed below is the subset of embeddings looked up for train_inputs.
    '''

    # Look up embeddings for inputs.
    embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    embed = tf.nn.embedding_lookup(embeddings, train_inputs)
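
    # A quick sketch of what embedding_lookup does (illustration only; the toy
    # shapes are assumptions, not part of the original script): if embeddings
    # were a 5 x 3 matrix and train_inputs were [2, 0, 2], embed would be a
    # 3 x 3 tensor made of rows 2, 0 and 2 of embeddings. Here embed has shape
    # [batch_size, embedding_size] = [128, 128].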
    # Construct the variables for the NCE loss
    nce_weights = tf.Variable(
        tf.truncated_normal([vocabulary_size, embedding_size],
                            stddev=1.0 / math.sqrt(embedding_size)))
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

  # Compute the average NCE loss for the batch.
  # tf.nce_loss automatically draws a new sample of the negative labels each
  # time we evaluate the loss.
  loss = tf.reduce_mean(
      tf.nn.nce_loss(weights=nce_weights,
                     biases=nce_biases,
                     labels=train_labels,
                     inputs=embed,
                     num_sampled=num_sampled,
                     num_classes=vocabulary_size))
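
  # Roughly, for each (input, label) pair nce_loss treats the true label as a
  # positive example and num_sampled randomly drawn words as negatives, then
  # optimizes a binary logistic loss that separates them; this avoids computing
  # a full vocabulary_size-way softmax at every training step.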
  # Construct the SGD optimizer using a learning rate of 1.0.
  optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

  # Compute the cosine similarity between minibatch examples and all embeddings.
  norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
  normalized_embeddings = embeddings / norm
  valid_embeddings = tf.nn.embedding_lookup(
      normalized_embeddings, valid_dataset)
  similarity = tf.matmul(
      valid_embeddings, normalized_embeddings, transpose_b=True)
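
  # Why this matmul gives cosine similarity: every row of normalized_embeddings
  # has unit L2 norm, and the dot product of two unit vectors is the cosine of
  # the angle between them, so similarity[i, j] is the cosine similarity between
  # validation word i and vocabulary word j (shape [valid_size, vocabulary_size]).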
  # Add variable initializer.
  init = tf.global_variables_initializer()

# Step 5: Begin training.
num_steps = 100001

with tf.Session(graph=graph) as session:
  # We must initialize all variables before we use them.
  init.run()
  print('Initialized')

  average_loss = 0
  for step in xrange(num_steps):
    batch_inputs, batch_labels = generate_batch(
        batch_size, num_skips, skip_window)
    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}

    # We perform one update step by evaluating the optimizer op (including it
    # in the list of returned values for session.run()
    _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
    average_loss += loss_val

    if step % 2000 == 0:
      if step > 0:
        average_loss /= 2000
      # The average loss is an estimate of the loss over the last 2000 batches.
      print('Average loss at step ', step, ': ', average_loss)
      average_loss = 0

    # Note that this is expensive (~20% slowdown if computed every 500 steps)
    if step % 10000 == 0:
      sim = similarity.eval()
      for i in xrange(valid_size):
        valid_word = reverse_dictionary[valid_examples[i]]
        top_k = 8  # number of nearest neighbors
        nearest = (-sim[i, :]).argsort()[1:top_k + 1]
        log_str = 'Nearest to %s:' % valid_word
        for k in xrange(top_k):
          close_word = reverse_dictionary[nearest[k]]
          log_str = '%s %s,' % (log_str, close_word)
        print(log_str)
  final_embeddings = normalized_embeddings.eval()

# Step 6: Visualize the embeddings.


def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
  assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings'
  plt.figure(figsize=(18, 18))  # in inches
  for i, label in enumerate(labels):
    x, y = low_dim_embs[i, :]
    plt.scatter(x, y)
    plt.annotate(label,
                 xy=(x, y),
                 xytext=(5, 2),
                 textcoords='offset points',
                 ha='right',
                 va='bottom')

  plt.savefig(filename)

try:
  # pylint: disable=g-import-not-at-top
  from sklearn.manifold import TSNE
  import matplotlib.pyplot as plt

  tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
  plot_only = 500
  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
  labels = [reverse_dictionary[i] for i in xrange(plot_only)]
  plot_with_labels(low_dim_embs, labels)

except ImportError:
  print('Please install sklearn, matplotlib, and scipy to show embeddings.')
