词向量嵌入需要高效率处理大规模文本语料库。word2vec。简单方式,词送入独热编码(one-hot encoding)学习系统,长度为词汇表长度的向量,词语对应位置元素为1,其余元素为0。向量维数很高,无法刻画不同词语的语义关联。共生关系(co-occurrence)表示单词,解决语义关联,遍历大规模文本语料库,统计每个单词一定距离范围内的周围词汇,用附近词汇规范化数量表示每个词语。类似语境中词语语义相似。用PCA或类似方法降维出现向量(occurrence vector),得到更稠密表示。性能好,追踪所有词汇共生矩阵,宽度、高度为词汇表长度。2013年,Mikolov、Tomas等提出上下文计算词表示方法,《Efficient estimation of word representations in vector space》(arXiv preprint arXiv:1301.3781(2013))。skip-gram模型,从随机表示开始,依据当前词语预测上下文词语简单分类器,误差通过分类器权值和词表示传播,对两者调整减少预测误差。大规模语料库训练模型表示赂量逼近压缩后共生向量。

数据集, 英文维基百科转储文件包含所有页面完整修订历史,当前页面版本100GB,https://dumps.wikimedia.org/backup-index.html。

下载转储文件,提取页面词语。统计词语出现次数,构建常见词汇表。用词汇表对提取页面编码。逐行读取文件,结果立即写入磁盘。在不同步骤间保存检查点,避免程序崩溃重来。

__iter__遍历词语索引列表页面。encode获取字符串词语词汇索引。decode依据词汇索引返回字符串词语。_read_pages从维基百科转储文件(压缩XML)提取单词,保存到页面文件,每个页面一行空格分隔的单词。bz2模块open函数读取文件。中间结果压缩处理。正则表达式捕捉任意连续字母序列或单独特殊字母。_build_vocabulary统计页面文件单词数,出现频率高词语写入文件。独热编码需要词汇表。词汇表索引编码。移除拼写错误、极不常见词语,词汇表只包含vocabulary_size - 1个最常见词语。所有不在词汇表词语<unk>标记,未出现单词词向量。

动态形成训练样本,组织到大批数据,分类器不占大量内存。skip-gram模型预测当前词语的上下文词语。遍历文本,当前词语数据,周围词语目标,创建训练样本。上下文尺寸R,每个单词生成2R样本,当前词左右各R个词。语义上下文,距离近重要,尽量少创建远上下文词语训练样本,范围[1,D=10]随机选择词上下文尺寸。依据skip-gram模型形成训练对。Numpy数组生成数值流批数据。

初始,单词表示为随机向量。分类器根据中层表示预测上下文单词当前表示。传播误差,微调权值、输入单词表示。MomentumOptimizer 模型优化,智能不足,效率高。

分类器是模型核心。噪声对比估计损失(noisecontrastive estimation loss)性能优异。softmax分类器建模。tf.nn.nce_loss 新随机向量负样本(对比样本),近似softmax分类器。

训练模型结束,最终词向量写入文件。维基百科语料库子集,普通CPU训练5小时,得到NumPy数组嵌入表示。完整语料库: https://dumps.wikimedia.org/enwiki/20160501/enwiki-20160501-pages-meta-current.xml.bz2 。AttrDict类等价Python dict,键可属性访问。

import bz2
import collections
import os
import re
from lxml import etree
from helpers import download
class Wikipedia:
    TOKEN_REGEX = re.compile(r'[A-Za-z]+|[!?.:,()]')
    def __init__(self, url, cache_dir, vocabulary_size=10000):
        self._cache_dir = os.path.expanduser(cache_dir)
        self._pages_path = os.path.join(self._cache_dir, 'pages.bz2')
        self._vocabulary_path = os.path.join(self._cache_dir, 'vocabulary.bz2')
        if not os.path.isfile(self._pages_path):
            print('Read pages')
            self._read_pages(url)
        if not os.path.isfile(self._vocabulary_path):
            print('Build vocabulary')
            self._build_vocabulary(vocabulary_size)
        with bz2.open(self._vocabulary_path, 'rt') as vocabulary:
            print('Read vocabulary')
            self._vocabulary = [x.strip() for x in vocabulary]
        self._indices = {x: i for i, x in enumerate(self._vocabulary)}
    def __iter__(self):
        with bz2.open(self._pages_path, 'rt') as pages:
            for page in pages:
                words = page.strip().split()
                words = [self.encode(x) for x in words]
                yield words
    @property
    def vocabulary_size(self):
        return len(self._vocabulary)
    def encode(self, word):
        return self._indices.get(word, 0)
    def decode(self, index):
        return self._vocabulary[index]
    def _read_pages(self, url):
        wikipedia_path = download(url, self._cache_dir)
        with bz2.open(wikipedia_path) as wikipedia, \
                bz2.open(self._pages_path, 'wt') as pages:
            for _, element in etree.iterparse(wikipedia, tag='{*}page'):
                if element.find('./{*}redirect') is not None:
                    continue
                page = element.findtext('./{*}revision/{*}text')
                words = self._tokenize(page)
                pages.write(' '.join(words) + '\n')
                element.clear()
    def _build_vocabulary(self, vocabulary_size):
        counter = collections.Counter()
        with bz2.open(self._pages_path, 'rt') as pages:
            for page in pages:
                words = page.strip().split()
                counter.update(words)
        common = ['<unk>'] + counter.most_common(vocabulary_size - 1)
        common = [x[0] for x in common]
        with bz2.open(self._vocabulary_path, 'wt') as vocabulary:
            for word in common:
                vocabulary.write(word + '\n')
    @classmethod
    def _tokenize(cls, page):
        words = cls.TOKEN_REGEX.findall(page)
        words = [x.lower() for x in words]
        return words

import tensorflow as tf
import numpy as np
from helpers import lazy_property
class EmbeddingModel:
    def __init__(self, data, target, params):
        self.data = data
        self.target = target
        self.params = params
        self.embeddings
        self.cost
        self.optimize
    @lazy_property
    def embeddings(self):
        initial = tf.random_uniform(
            [self.params.vocabulary_size, self.params.embedding_size],
            -1.0, 1.0)
        return tf.Variable(initial)
    @lazy_property
    def optimize(self):
        optimizer = tf.train.MomentumOptimizer(
            self.params.learning_rate, self.params.momentum)
        return optimizer.minimize(self.cost)
    @lazy_property
    def cost(self):
        embedded = tf.nn.embedding_lookup(self.embeddings, self.data)
        weight = tf.Variable(tf.truncated_normal(
            [self.params.vocabulary_size, self.params.embedding_size],
            stddev=1.0 / self.params.embedding_size ** 0.5))
        bias = tf.Variable(tf.zeros([self.params.vocabulary_size]))
        target = tf.expand_dims(self.target, 1)
        return tf.reduce_mean(tf.nn.nce_loss(
            weight, bias, embedded, target,
            self.params.contrastive_examples,
            self.params.vocabulary_size))

import collections
import tensorflow as tf
import numpy as np
from batched import batched
from EmbeddingModel import EmbeddingModel
from skipgrams import skipgrams
from Wikipedia import Wikipedia
from helpers import AttrDict
WIKI_DOWNLOAD_DIR = './wikipedia'
params = AttrDict(
    vocabulary_size=10000,
    max_context=10,
    embedding_size=200,
    contrastive_examples=100,
    learning_rate=0.5,
    momentum=0.5,
    batch_size=1000,
)
data = tf.placeholder(tf.int32, [None])
target = tf.placeholder(tf.int32, [None])
model = EmbeddingModel(data, target, params)
corpus = Wikipedia(
    'https://dumps.wikimedia.org/enwiki/20160501/'
    'enwiki-20160501-pages-meta-current1.xml-p000000010p000030303.bz2',
    WIKI_DOWNLOAD_DIR,
    params.vocabulary_size)
examples = skipgrams(corpus, params.max_context)
batches = batched(examples, params.batch_size)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
average = collections.deque(maxlen=100)
for index, batch in enumerate(batches):
    feed_dict = {data: batch[0], target: batch[1]}
    cost, _ = sess.run([model.cost, model.optimize], feed_dict)
    average.append(cost)
    print('{}: {:5.1f}'.format(index + 1, sum(average) / len(average)))
    if index > 100000:
        break
embeddings = sess.run(model.embeddings)
np.save(WIKI_DOWNLOAD_DIR + '/embeddings.npy', embeddings)

参考资料:
《面向机器智能的TensorFlow实践》

欢迎加我微信交流:qingxingfengzi
我的微信公众号:qingxingfengzigz
我老婆张幸清的微信公众号:qingqingfeifangz

学习笔记TF018:词向量、维基百科语料库训练词向量模型的更多相关文章

  1. 广师大学习笔记之文本统计(jieba库好玩的词云)

    1.jieba库,介绍如下: (1) jieba 库的分词原理是利用一个中文词库,将待分词的内容与分词词库进行比对,通过图结构和动态规划方法找到最大概率的词组:除此之外,jieba 库还提供了增加自定 ...

  2. 学习笔记之TCP/IP协议分层与OSI參考模型

    1.协议的分层      ISO在制定标准化OSI之前,对网络体系结构相关的问题进行了充分的讨论, 终于提出了作为通信协议设计指标的OSI參考模型.这一模型将通信协议中必要 的功能分成了7层.通过这些 ...

  3. 学习笔记(22)- plato-训练端到端的模型

    原始文档 Train an end-to-end model To get started we can train a very simple model using Ludwig (feel fr ...

  4. tensorflow学习笔记(三十四):Saver(保存与加载模型)

    Savertensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., nam ...

  5. 开源共享一个训练好的中文词向量(语料是维基百科的内容,大概1G多一点)

    使用gensim的word2vec训练了一个词向量. 语料是1G多的维基百科,感觉词向量的质量还不错,共享出来,希望对大家有用. 下载地址是: http://pan.baidu.com/s/1boPm ...

  6. 基于51单片机IIC通信的PCF8591学习笔记

    引言 PCF8591 是单电源,低功耗8 位CMOS 数据采集器件,具有4 个模拟输入.一个输出和一个串行I2C 总线接口.3 个地址引脚A0.A1 和A2 用于编程硬件地址,允许将最多8 个器件连接 ...

  7. thinkphp学习笔记7—多层MVC

    原文:thinkphp学习笔记7-多层MVC ThinkPHP支持多层设计. 1.模型层Model 使用多层目录结构和命名规范来设计多层的model,例如在项目设计中如果需要区分数据层,逻辑层,服务层 ...

  8. AKKA学习笔记

    AKKA学习笔记总结 01. AKKA 1. 介绍: Akka基于Actor模型,提供了一个用于构建可扩展的(Scalable).弹性的(Resilient).快速响应的(Responsive)应用程 ...

  9. DNN模型训练词向量原理

    转自:https://blog.csdn.net/fendouaini/article/details/79821852 1 词向量 在NLP里,最细的粒度是词语,由词语再组成句子,段落,文章.所以处 ...

随机推荐

  1. 解决failed to push some refs to

    由于github我使用了dev和feature分支,团队合作合并到dev,个人开发都是feature....今天在本地feature中git pull origin dev 出现 在使用git 对源代 ...

  2. (转)什么是P问题、NP问题和NPC问题

    这或许是众多OIer最大的误区之一.    你会经常看到网上出现"这怎么做,这不是NP问题吗"."这个只有搜了,这已经被证明是NP问题了"之类的话.你要知道,大 ...

  3. SQL Server 死锁概念和分析

    锁的概念 锁是什么 锁是数据库中在并发操作情形下保护资源的机制.通常(具体要看锁兼容性)只有锁的拥有者才能对被锁的资源进行操作,从而保证数据一致性. 锁的概念可分为几部分 锁资源(锁住什么) 锁模式( ...

  4. php写流程管理

    流程控制即某个人发起一个流程,通过一层一层审核,通过后,完成整个流程,若有一层审核未通过,中断整个流程.即结束! 比如请假流程: 某一员工发起一个请假流程,那么这个流程的节点人员即他的上级,上上级,上 ...

  5. LVS+Keepalived实现DBProxy的高可用

    背景 在上一篇文章美团点评DBProxy读写分离使用说明实现了读写分离,但在最后提了二个问题:一是代理不管MySQL主从的复制状态,二是DBProxy本身是一个单点的存在.对于第一个可以通过自己定义的 ...

  6. Kafka协议兼容性改进

    在Kafka 0.10.2.0之前,Kafka服务器端和客户端版本之间的兼容性是"单向"的,即高版本的broker可以处理低版本client的请求.反过来,低版本的broker不能 ...

  7. 【算法系列学习三】[kuangbin带你飞]专题二 搜索进阶 之 A-Eight 反向bfs打表和康拓展开

    [kuangbin带你飞]专题二 搜索进阶 之 A-Eight 这是一道经典的八数码问题.首先,简单介绍一下八数码问题: 八数码问题也称为九宫问题.在3×3的棋盘,摆有八个棋子,每个棋子上标有1至8的 ...

  8. jquery获取文件名称

    $("#fileupload").on("change",function(){ var filePath=$(this).val(); if(filePath ...

  9. xmlplus 组件设计系列之十 - 网格(DataGrid)

    这一章我们要实现是一个网格组件,该组件除了最基本的数据展示功能外,还提供排序以及数据过滤功能. 数据源 为了测试我们即将编写好网格组件,我们采用如下格式的数据源.此数据源包含两部分的内容,分别是表头数 ...

  10. Java爬虫(一)利用GET和POST发送请求,获取服务器返回信息

    本人所使用软件 eclipse fiddle UC浏览器 分析请求信息 以知乎(https://www.zhihu.com)为例,模拟登陆请求,获取登陆后首页,首先就是分析请求信息. 用UC浏览器F1 ...