Learning word2vec: the Spark version
References:
http://ir.dlut.edu.cn/NewsShow.aspx?ID=291
http://www.douban.com/note/298095260/
word2vec is an important algorithm in NLP. It represents each word as a dense K-dimensional vector; the training set is a corpus with punctuation removed and sentences delimited by spaces. It can therefore be viewed as a feature-engineering method.
Main advantages:
- Additive compositionality: word vectors can be meaningfully added and subtracted (see the example after this list).
- Efficiency: a single machine can process roughly 20 million words per hour.
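A well-known illustration of the additive property (a standard example from the word2vec literature, not taken from the posts linked above):
\[ v(\text{king}) - v(\text{man}) + v(\text{woman}) \approx v(\text{queen}) \]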
Google's open-source release is the most authoritative implementation ( http://word2vec.googlecode.com/svn/trunk/ ), but I studied the algorithm through the Spark version.
I. Background
Distributed representation: a way of representing word features. Training maps each word to a K-dimensional real-valued vector (K is typically a hyperparameter of the model), and semantic similarity between words is judged by the distance between their vectors (e.g., cosine similarity or Euclidean distance).
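For reference, the cosine similarity between two word vectors $u$ and $v$ is the standard
\[ \cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}, \]
which is exactly what the model's findSynonyms method computes in the Spark source below.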
Language models: n-gram models and the like.
II. Model
0. The word window defines the context. A word $i$ is denoted $u_i$ when it is the center word and $v_i$ when it serves as context for another word (that is, its representation as context is a separate vector). Only words inside the word window count as context, and their order is irrelevant.
1. The probability model is the corpus log-likelihood
\[ \mathcal{L} = \sum_i \log p(u_i), \]
where $i$ ranges over positions (words), i.e., the accumulated log-probability of every word occurrence.
2. Taking skip-gram as an example (CBOW simply reverses the direction of the conditional probability), the contribution of position $i$ is
\[ \log p(u_i) = \sum_{-c \leq j \leq c,\; j \neq 0} \log p(v_{i+j} \mid u_{i}), \]
i.e., position $i$ interacts only with the words in its context window.
3. The conditional probability $p(v_{i+j} \mid u_i)$ is obtained from the K-dimensional vectors through a softmax, which turns them into a probability distribution.
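Concretely, the standard softmax form (spelled out here for completeness; the original post only names it), with $u_I$ the center-word vector and $v_O$ a context-word vector, is
\[ p(v_O \mid u_I) = \frac{\exp(v_O^{\top} u_I)}{\sum_{w=1}^{V} \exp(v_w^{\top} u_I)}, \]
where $V$ is the vocabulary size. The $O(V)$ normalization in the denominator is exactly what hierarchical softmax (Section III) avoids.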
III. Optimizations
The earliest implementations used a full neural network; later, hierarchical softmax and related techniques were adopted to reduce the time complexity, along with many implementation tricks such as a precomputed ExpTable.
a) Remove the hidden layer.
b) Use hierarchical softmax or negative sampling.
c) Drop words that occur fewer than minCount times.
d) Precompute an ExpTable (a sigmoid lookup table); see the sketch after this list.
e) Subsample frequent words: for each word compute, by the formula below, the probability that an occurrence is skipped; skipped occurrences are not updated. This saves time and improves the accuracy of infrequent words.
\[ prob(w) = 1 - \left( \sqrt{\frac{t}{f(w)}} + \frac{t}{f(w)} \right), \]
where $t$ is a preset threshold and $f(w)$ is the frequency of $w$.
f) The context window size is not fixed but sampled per position, which weights updates toward words closer to the center word.
g) Multi-threading, with no mutual exclusion between the update threads.
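A minimal Scala sketch of tricks d) and e), assuming the same constants as the C/Spark implementations (EXP_TABLE_SIZE = 1000, MAX_EXP = 6); the object and helper names here are mine for illustration, not from the Spark source:

object Word2VecTricks {
  val EXP_TABLE_SIZE = 1000
  val MAX_EXP = 6

  // d) Precomputed sigmoid table: expTable(i) ~ sigmoid(x) for x sweeping [-MAX_EXP, MAX_EXP).
  val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
    val x = (2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP
    (math.exp(x) / (math.exp(x) + 1.0)).toFloat
  }

  // Table lookup replacing math.exp in the training loop.
  // Valid only for -MAX_EXP < f < MAX_EXP (the training loop checks this before indexing).
  def sigmoid(f: Double): Float =
    expTable(((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt)

  // e) Probability of discarding one occurrence of a word with relative frequency freq,
  //    given the sampling threshold t (commonly 1e-3 to 1e-5).
  def discardProb(freq: Double, t: Double): Double =
    math.max(0.0, 1.0 - (math.sqrt(t / freq) + t / freq))
}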
IV. Spark source code analysis
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
 */

package org.apache.spark.mllib.feature

import java.lang.{Iterable => JavaIterable}

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
* Entry in vocabulary
*/
private case class VocabWord(
  var word: String,
  var cn: Int,
  var point: Array[Int],
  var code: Array[Int],
  var codeLen: Int
)

/**
* :: Experimental ::
* Word2Vec creates vector representation of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
* and then learns vector representation of words in the vocabulary.
* The vector representation can be used as features in
* natural language processing and machine learning algorithms.
*
* We used skip-gram model in our implementation and hierarchical softmax
* method to train the model. The variable names in the implementation
* matches the original C implementation.
*
* For original C implementation, see https://code.google.com/p/word2vec/
* For research papers, see
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality.
*/
@Experimental
class Word2VectorEX extends Serializable with Logging {

  private var vectorSize = 100
  private var startingAlpha = 0.025
  private var numPartitions = 1
  private var numIterations = 1
  private var seed = Utils.random.nextLong()

  /**
   * Sets vector size (default: 100).
   */
  def setVectorSize(vectorSize: Int): this.type = {
    this.vectorSize = vectorSize
    this
  }

  /**
   * Sets initial learning rate (default: 0.025).
   */
  def setLearningRate(learningRate: Double): this.type = {
    this.startingAlpha = learningRate
    this
  }

  /**
   * Sets number of partitions (default: 1). Use a small number for accuracy.
   */
  def setNumPartitions(numPartitions: Int): this.type = {
    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
    this.numPartitions = numPartitions
    this
  }

  /**
   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
   * partitions.
   */
  def setNumIterations(numIterations: Int): this.type = {
    this.numIterations = numIterations
    this
  }

  /**
   * Sets random seed (default: a random long integer).
   */
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000

  /** context words from [-window, window] */
  private val window = 5                 // radius of the context window

  /** minimum frequency to consider a vocabulary word */
  private val minCount = 5               // words occurring fewer times are dropped

  private var trainWordsCount = 0        // total word occurrences in the corpus (with repetitions)
  private var vocabSize = 0              // number of words kept in the vocabulary
  private var vocab: Array[VocabWord] = null                   // the vocabulary
  private var vocabHash = mutable.HashMap.empty[String, Int]   // reverse index: word -> vocabulary position

  // Builds the vocabulary and fills in the four fields above.
  private def learnVocab(words: RDD[String]): Unit = {
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .filter(_.cn >= minCount)
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo("trainWordsCount = " + trainWordsCount)
  }

  // Lookup table for the sigmoid, so the training loop avoids calling math.exp.
  private def createExpTable(): Array[Float] = {
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

  // Builds the Huffman tree used by hierarchical softmax and assigns each word its code and path.
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    var min1i = 0
    var min2i = 0

    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }

  /**
   * Computes the vector representation of each word in vocabulary.
   * @param dataset an RDD of words
   * @return a Word2VecModel
   */
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VectorModel = {

    // Flatten into one word sequence; sentence boundaries are carried by each Iterable element.
    val words = dataset.flatMap(x => x)

    learnVocab(words)    // build the vocabulary

    createBinaryTree()   // build the Huffman tree for hierarchical softmax

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val bcVocab = sc.broadcast(vocab)
    val bcVocabHash = sc.broadcast(vocabHash)

    // Re-chunk the word stream into sentences of at most MAX_SENTENCE_LENGTH words,
    // with each word replaced by its vocabulary index.
    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
      new Iterator[Array[Int]] {
        def hasNext: Boolean = iter.hasNext

        def next(): Array[Int] = {
          var sentence = new ArrayBuffer[Int]
          var sentenceLength = 0
          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
            val word = bcVocabHash.value.get(iter.next())
            word match {
              case Some(w) =>
                sentence += w
                sentenceLength += 1
              case None =>
            }
          }
          sentence.toArray
        }
      }
    }

    // Training: skip-gram with hierarchical softmax
    val newSentences = sentences.repartition(numPartitions).cache()
    val initRandom = new XORShiftRandom(seed)
    val syn0Global =
      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    val syn1Global = new Array[Float](vocabSize * vectorSize)
    var alpha = startingAlpha
    for (k <- 1 to numIterations) {
      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        // Per-partition RNG; updates are applied by stochastic gradient descent.
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
        val syn0Modify = new Array[Int](vocabSize)
        val syn1Modify = new Array[Int](vocabSize)
        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
            var wc = wordCount
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
              // TODO: discount by iteration?
              alpha =
                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
            wc += sentence.size
            var pos = 0
            while (pos < sentence.size) {
              val word = sentence(pos)
              val b = random.nextInt(window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {
                if (a != window) {
                  val c = pos - window + a
                  if (c >= 0 && c < sentence.size) {
                    val lastWord = sentence(c)
                    val l1 = lastWord * vectorSize
                    val neu1e = new Array[Float](vectorSize)
                    // Hierarchical softmax
                    var d = 0
                    while (d < bcVocab.value(word).codeLen) {
                      val inner = bcVocab.value(word).point(d)
                      val l2 = inner * vectorSize
                      // Propagate hidden -> output
                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
                        syn1Modify(inner) += 1
                      }
                      d += 1
                    }
                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                    syn0Modify(lastWord) += 1
                  }
                }
                a += 1
              }
              pos += 1
            }
            (syn0, syn1, lwc, wc)
        }
        val syn0Local = model._1
        val syn1Local = model._2
        // Only output modified vectors.
        Iterator.tabulate(vocabSize) { index =>
          if (syn0Modify(index) > 0) {
            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
          if (syn1Modify(index) > 0) {
            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten
      }
      val synAgg = partial.reduceByKey { case (v1, v2) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
    }
    newSentences.unpersist()

    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
    var i = 0
    while (i < vocabSize) {
      val word = bcVocab.value(i).word
      val vector = new Array[Float](vectorSize)
      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
      word2VecMap += word -> vector
      i += 1
    }

    new Word2VectorModel(word2VecMap.toMap)
  }

  /**
   * Computes the vector representation of each word in vocabulary (Java version).
   * @param dataset a JavaRDD of words
   * @return a Word2VecModel
   */
  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VectorModel = {
    fit(dataset.rdd.map(_.asScala))
  }
}

/**
 * :: Experimental ::
 * Word2Vec model
 */
@Experimental
class Word2VectorModel private[mllib] (
    private val model: Map[String, Array[Float]]) extends Serializable {

  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.snrm2(n, v1, 1)
    val norm2 = blas.snrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) return 0.0
    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
  }

  /**
   * Transforms a word to its vector representation
   * @param word a word
   * @return vector representation of word
   */
  def transform(word: String): Vector = {
    model.get(word) match {
      case Some(vec) =>
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }

  /**
   * Find synonyms of a word
   * @param word a word
   * @param num number of synonyms to find
   * @return array of (word, similarity)
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    findSynonyms(vector, num)
  }

  /**
   * Find synonyms of the vector representation of a word
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    // TODO: optimize top-k
    val fVector = vector.toArray.map(_.toFloat)
    model.mapValues(vec => cosineSimilarity(fVector, vec))
      .toSeq
      .sortBy(- _._2)
      .take(num + 1)
      .tail
      .toArray
  }

  def getModel(): Map[String, Array[Float]] = {
    model
  }
}