RNN (Recurrent Neural Networks): Derivation and Implementation

http://x-algo.cn/index.php/2016/04/25/rnn-recurrent-neural-networks-derivation-and-implementation/
2016-04-25 Category: Deep Learning / NLP / RNN

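This post is based mainly on the WildML blog, and all of the code is written in Python without using any deep learning framework. The derivations are dry, but working through them once gives a much deeper understanding of RNNs. Before reading, you should already know the basics of traditional neural networks; if not, read the post "Neural Network Implementation" first.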

All of the runnable code is available on GitHub.


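Language Model

Anyone familiar with NLP will recognize this: a language model assigns a probability to a sentence in natural language. Specifically, if a sentence has $m$ words, the probability of generating the sentence is:

$$P(w_1, \dots, w_m) = \prod_{i=1}^{m} P(w_i \mid w_1, \dots, w_{i-1})$$

In other words, the probability of each word is conditioned only on the words that precede it in the sentence. For example, the probability of the sentence "He went to buy some chocolate" can be written as: P(He went to buy some chocolate) = P(He went to buy some) * P(chocolate | He went to buy some).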
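As a small illustration of the factorization above, the sketch below multiplies per-word conditional probabilities (in log space, to avoid underflow on long sentences). The probability values are made up for the example and do not come from any trained model.

```python
import math

# Hypothetical conditional probabilities P(w_i | w_1..w_{i-1}) for the
# sentence "He went to buy some chocolate"; the numbers are invented.
conditional_probs = [
    ("He", 0.05),
    ("went", 0.10),
    ("to", 0.40),
    ("buy", 0.08),
    ("some", 0.20),
    ("chocolate", 0.02),
]

def sentence_log_prob(cond_probs):
    """Sum of log P(w_i | w_1..w_{i-1}), i.e. the log of the product."""
    return sum(math.log(p) for _, p in cond_probs)

log_p = sentence_log_prob(conditional_probs)
sentence_prob = math.exp(log_p)

# P(prefix) * P(last word | prefix) gives the same value.
prefix_prob = math.exp(sentence_log_prob(conditional_probs[:-1]))
last_word_prob = conditional_probs[-1][1]
```

Here `sentence_prob` and `prefix_prob * last_word_prob` agree, which is exactly the two-factor form used in the chocolate example.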

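Data Preprocessing

Training a model requires a corpus. Here the corpus is Reddit comment data from Google BigQuery. Preprocessing removes low-frequency words to keep the vocabulary size under control and replaces them with a single shared token (UNKNOWN_TOKEN here). After preprocessing, every word is replaced by a unique integer index. To learn which words tend to start and end sentences, two special tokens, SENTENCE_START and SENTENCE_END, are added. See the code for the details.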
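The preprocessing steps described above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the post's actual preprocessing code: the toy corpus, the whitespace tokenizer, the tiny `vocabulary_size`, and the lowercase variable names are all hypothetical.

```python
from collections import Counter

# Assumed token spellings and a toy corpus; the real data are Reddit comments.
unknown_token = "UNKNOWN_TOKEN"
sentence_start_token = "SENTENCE_START"
sentence_end_token = "SENTENCE_END"
vocabulary_size = 6  # hypothetical; includes UNKNOWN_TOKEN

raw_sentences = ["he went to buy chocolate", "he went to sleep"]

# Whitespace tokenization is a simplification of a real tokenizer.
tokenized = [
    [sentence_start_token] + s.split() + [sentence_end_token]
    for s in raw_sentences
]

# Keep the most frequent words; everything else maps to UNKNOWN_TOKEN.
word_freq = Counter(w for sent in tokenized for w in sent)
index_to_word = [w for w, _ in word_freq.most_common(vocabulary_size - 1)]
index_to_word.append(unknown_token)
word_to_index = {w: i for i, w in enumerate(index_to_word)}

def encode(sentence):
    """Replace each word by its index, falling back to the unknown index."""
    unk = word_to_index[unknown_token]
    return [word_to_index.get(w, unk) for w in sentence]

encoded = [encode(sent) for sent in tokenized]
# Input is every token but the last; target is the same sequence shifted by one.
X_train = [sent[:-1] for sent in encoded]
y_train = [sent[1:] for sent in encoded]
```

Each training pair is the sentence and the same sentence shifted by one position, so the model learns to predict the next word at every step.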


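Network Structure

An RNN differs from a traditional NN, but it is still easy to understand. Its structure is shown in the figure below:

Figure: A recurrent neural network and the unfolding in time of the computation involved in its forward computation.

The difference is that an RNN is a "recurrent network" and has a notion of "state".

In the figure above, $t$ denotes the time step, $x_t$ is the input at step $t$, $s_t$ is the hidden-layer output (the hidden state) at step $t$, and $o_t$ is the output. What is special is that the hidden layer has two inputs: the current input $x_t$ and the hidden-layer output of the previous step, $s_{t-1}$. $U$, $V$, and $W$ are the parameters. The structure above can be written as follows, where $\hat{y}_t$ is the output $o_t$:

$$\begin{aligned} s_t &= \tanh(U x_t + W s_{t-1}) \\ \hat{y}_t &= \mathrm{softmax}(V s_t) \end{aligned}$$

With 100 hidden units and a vocabulary size of $C = 8000$, the dimensions are:

$$\begin{aligned} x_t &\in \mathbb{R}^{8000} \\ o_t &\in \mathbb{R}^{8000} \\ s_t &\in \mathbb{R}^{100} \\ U &\in \mathbb{R}^{100 \times 8000} \\ V &\in \mathbb{R}^{8000 \times 100} \\ W &\in \mathbb{R}^{100 \times 100} \end{aligned}$$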
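To make the dimensions concrete, here is a rough single-step sketch with the sizes given above. The random parameter values, the `softmax` helper, and the word id 42 are assumptions for illustration only.

```python
import numpy as np

hidden_dim, word_dim = 100, 8000
rng = np.random.default_rng(0)

# Random parameters with the shapes listed above (values are arbitrary).
U = rng.uniform(-0.01, 0.01, (hidden_dim, word_dim))
V = rng.uniform(-0.01, 0.01, (word_dim, hidden_dim))
W = rng.uniform(-0.01, 0.01, (hidden_dim, hidden_dim))

def softmax(z):
    # Subtract the max for numerical stability; the result is unchanged.
    e = np.exp(z - np.max(z))
    return e / e.sum()

word_index = 42                      # hypothetical word id
x_t = np.zeros(word_dim)
x_t[word_index] = 1.0                # one-hot input vector
s_prev = np.zeros(hidden_dim)        # initial hidden state

s_t = np.tanh(U.dot(x_t) + W.dot(s_prev))
o_t = softmax(V.dot(s_t))

# Multiplying U by a one-hot vector just selects one column of U.
same_as_column = np.allclose(U.dot(x_t), U[:, word_index])
```

Because the input is one-hot, the product with $U$ reduces to a column lookup, which is why an implementation can index `U[:, x[t]]` instead of building the full vector.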

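Initialization

There are many ways to initialize the parameters. Initializing everything to 0 leads to "symmetric calculations": every unit in a layer computes the same value and receives the same gradient, so the units cannot learn different features. The right initialization depends on the activation function. We use tanh here, and one recommended approach is to draw the weights uniformly from $\left[-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}\right]$, where $n$ is the number of incoming connections from the previous layer.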
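The symmetry problem can be seen in a small experiment. The sketch below is only an illustration under assumed sizes: it uses a constant value of 0.1 instead of exactly 0 so that the hidden states are non-trivial, and compares that with the uniform scheme described above.

```python
import numpy as np

hidden_dim, word_dim = 4, 6          # tiny hypothetical sizes
x = [0, 3, 1]                        # a made-up sequence of word indices

def hidden_states(U, W, x):
    """Run s_t = tanh(U x_t + W s_{t-1}) and return all hidden states."""
    s_prev = np.zeros(U.shape[0])
    states = []
    for idx in x:
        s_prev = np.tanh(U[:, idx] + W.dot(s_prev))
        states.append(s_prev)
    return np.array(states)

# Identical weights: every hidden unit computes exactly the same value.
U_same = np.full((hidden_dim, word_dim), 0.1)
W_same = np.full((hidden_dim, hidden_dim), 0.1)
s_same = hidden_states(U_same, W_same, x)

# Uniform initialization in [-1/sqrt(n), 1/sqrt(n)] breaks the symmetry.
rng = np.random.default_rng(0)
U_rand = rng.uniform(-np.sqrt(1.0 / word_dim), np.sqrt(1.0 / word_dim),
                     (hidden_dim, word_dim))
W_rand = rng.uniform(-np.sqrt(1.0 / hidden_dim), np.sqrt(1.0 / hidden_dim),
                     (hidden_dim, hidden_dim))
s_rand = hidden_states(U_rand, W_rand, x)

# True when all hidden units carry the same value at every time step.
units_identical_same = np.allclose(s_same, s_same[:, :1])
units_identical_rand = np.allclose(s_rand, s_rand[:, :1])
```

With identical weights the four hidden units are copies of each other, so the extra units add no capacity; random initialization gives each unit its own starting point.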

 
import numpy as np

class RNNNumpy:

    def __init__(self, word_dim, hidden_dim=100, bptt_truncate=4):
        # Assign instance variables
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.bptt_truncate = bptt_truncate
        # Randomly initialize the network parameters, uniformly in
        # [-1/sqrt(n), 1/sqrt(n)] where n is the number of incoming connections
        self.U = np.random.uniform(-np.sqrt(1./word_dim), np.sqrt(1./word_dim), (hidden_dim, word_dim))
        self.V = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (word_dim, hidden_dim))
        self.W = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (hidden_dim, hidden_dim))

Forward Propagation

As with a traditional NN, this only takes a few matrix multiplications:

 
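def softmax(x):
    # softmax is not defined in this post; a numerically stable version is assumed here
    xt = np.exp(x - np.max(x))
    return xt / np.sum(xt)

def forward_propagation(self, x):
    # The total number of time steps
    T = len(x)
    # During forward propagation we save all hidden states in s because we need them later.
    # We add one additional element for the initial hidden state, which we set to 0.
    s = np.zeros((T + 1, self.hidden_dim))
    # s[-1] is the extra last row; it acts as the initial state because s[t-1] is s[-1] when t = 0
    s[-1] = np.zeros(self.hidden_dim)
    # The outputs at each time step. Again, we save them for later.
    o = np.zeros((T, self.word_dim))
    # For each time step...
    for t in np.arange(T):
        # Note that we are indexing U by x[t]. This is the same as multiplying U with a one-hot vector.
        s[t] = np.tanh(self.U[:,x[t]] + self.W.dot(s[t-1]))
        o[t] = softmax(self.V.dot(s[t]))
    return [o, s]

# The function takes self, so attach it to the class as a method
RNNNumpy.forward_propagation = forward_propagation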

The prediction function can be written as:

 
def predict(self, x):
    # Perform forward propagation and return the index of the highest score at each time step
    o, s = self.forward_propagation(x)
    return np.argmax(o, axis=1)

# Attach to the class as a method
RNNNumpy.predict = predict

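Loss Function

As with a traditional NN, we use cross-entropy as the loss. With $N$ training examples (each word in the corpus counts as one example), the loss can be written as:

$$L(y, o) = -\frac{1}{N} \sum_{n \in N} y_n \log o_n$$

The following two functions compute the loss: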

 
def calculate_total_loss(self, x, y):
    L = 0
    # For each sentence...
    for i in np.arange(len(y)):
        o, s = self.forward_propagation(x[i])
        # We only care about our prediction of the "correct" words
        correct_word_predictions = o[np.arange(len(y[i])), y[i]]
        # Add to the loss based on how off we were
        L += -1 * np.sum(np.log(correct_word_predictions))
    return L

def calculate_loss(self, x, y):
    # Divide the total loss by the number of training examples (words)
    N = sum(len(y_i) for y_i in y)
    return self.calculate_total_loss(x,y)/N

# Attach both functions to the class as methods
RNNNumpy.calculate_total_loss = calculate_total_loss
RNNNumpy.calculate_loss = calculate_loss
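A useful sanity check on this loss: a model that predicts every word with equal probability $1/C$ should score $\log C$, so an untrained network with $C = 8000$ is expected to start near that value. The sketch below illustrates this with made-up target indices; it does not use the RNN class itself.

```python
import numpy as np

vocabulary_size = 8000               # C from the text
num_words = 5                        # hypothetical sentence length
y = np.array([3, 17, 256, 4095, 7999])   # made-up target word indices

# A model that knows nothing predicts every word with probability 1/C.
o = np.full((num_words, vocabulary_size), 1.0 / vocabulary_size)

# Cross-entropy: pick out the probability assigned to each correct word.
correct_word_predictions = o[np.arange(num_words), y]
loss = -np.sum(np.log(correct_word_predictions)) / num_words

expected_loss = np.log(vocabulary_size)   # about 8.987 for C = 8000
```

If the loss of a freshly initialized model is far from this baseline, that usually points to a bug in the forward pass or in the loss computation.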

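Learning the Parameters with BPTT

BPTT (Backpropagation Through Time) is a very intuitive method. It is similar to standard backpropagation, except that the propagation path is a "loop" and the parameters along the path are shared.

The loss is cross-entropy and can be written as:

$$\begin{aligned} E_t(y_t, \hat{y}_t) &= -y_t \log \hat{y}_t \\ E(y, \hat{y}) &= \sum_t E_t(y_t, \hat{y}_t) = -\sum_t y_t \log \hat{y}_t \end{aligned}$$

where $y_t$ is the true value and $\hat{y}_t$ is the predicted value. Unrolled over time, each time step $t$ contributes its own error term $E_t$.

So the derivative of the total error with respect to $W$ is:

$$\frac{\partial E}{\partial W} = \sum_t \frac{\partial E_t}{\partial W}$$

Taking $E_3$ as an example and writing $z_3 = V s_3$, the gradient with respect to $V$ is:

$$\frac{\partial E_3}{\partial V} = \frac{\partial E_3}{\partial \hat{y}_3} \frac{\partial \hat{y}_3}{\partial V} = \frac{\partial E_3}{\partial \hat{y}_3} \frac{\partial \hat{y}_3}{\partial z_3} \frac{\partial z_3}{\partial V} = (\hat{y}_3 - y_3) \otimes s_3$$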
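The closed form $(\hat{y}_3 - y_3) \otimes s_3$ can be checked numerically. The following is a rough sketch with tiny, made-up sizes and a random hidden state; it compares the outer-product formula with a central-difference estimate of the same gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, word_dim = 3, 5          # tiny hypothetical sizes
V = rng.uniform(-0.5, 0.5, (word_dim, hidden_dim))
s3 = np.tanh(rng.normal(size=hidden_dim))   # a made-up hidden state
target = 2                           # index of the correct word
y3 = np.zeros(word_dim)
y3[target] = 1.0                     # one-hot true label

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def error(V):
    # E_3 = -y_3 . log(softmax(V s_3)); only the target entry is non-zero
    return -np.log(softmax(V.dot(s3))[target])

# Analytic gradient from the formula: outer product of (y_hat - y) and s.
y_hat3 = softmax(V.dot(s3))
analytic = np.outer(y_hat3 - y3, s3)

# Central-difference estimate, one entry of V at a time.
h = 1e-5
numeric = np.zeros_like(V)
for i in range(word_dim):
    for j in range(hidden_dim):
        V_plus, V_minus = V.copy(), V.copy()
        V_plus[i, j] += h
        V_minus[i, j] -= h
        numeric[i, j] = (error(V_plus) - error(V_minus)) / (2 * h)

max_abs_diff = np.max(np.abs(analytic - numeric))
```

The two gradients agree up to the finite-difference error, which is the expected outcome when the softmax and cross-entropy derivatives are combined correctly.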

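Using the chain rule and the fact that $W$ is shared across time steps in an RNN, we get:

$$\frac{\partial E_3}{\partial W} = \sum_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3} \frac{\partial \hat{y}_3}{\partial s_3} \frac{\partial s_3}{\partial s_k} \frac{\partial s_k}{\partial W}$$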
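The factor $\partial s_3 / \partial s_k$ in this sum is itself a chain of one-step Jacobians. Assuming the recurrence $s_j = \tanh(U x_j + W s_{j-1})$ used throughout this post, it can be expanded as follows, with the factors ordered from $j = 3$ down to $j = k+1$ (an empty product, for $k = 3$, is the identity):

```latex
\frac{\partial s_3}{\partial s_k}
  = \prod_{j=k+1}^{3} \frac{\partial s_j}{\partial s_{j-1}},
\qquad
\frac{\partial s_j}{\partial s_{j-1}}
  = \operatorname{diag}\!\left(1 - s_j^{2}\right) W
```

Each step back in time therefore multiplies the gradient by $W$ and by the tanh derivative $1 - s_j^2$, which is the per-step update applied to `delta_t` in the BPTT code.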

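This process can be pictured as the gradient of $E_3$ flowing backward through the hidden states $s_3, s_2, \dots, s_0$.

The code that computes the gradients with BPTT: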

 
def bptt(self, x, y):
    T = len(y)
    # Perform forward propagation
    o, s = self.forward_propagation(x)
    # We accumulate the gradients in these variables
    dLdU = np.zeros(self.U.shape)
    dLdV = np.zeros(self.V.shape)
    dLdW = np.zeros(self.W.shape)
    # delta_o = y_hat - y. This modifies o in place, which is fine because o is not used again.
    delta_o = o
    delta_o[np.arange(len(y)), y] -= 1.
    # For each output backwards...
    for t in np.arange(T)[::-1]:
        dLdV += np.outer(delta_o[t], s[t].T)
        # Initial delta calculation: dL/dz
        delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
        # Backpropagation through time (for at most self.bptt_truncate steps)
        for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:
            # Add to gradients at each previous step
            dLdW += np.outer(delta_t, s[bptt_step-1])
            dLdU[:,x[bptt_step]] += delta_t
            # Update delta for next step dL/dz at t-1
            delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)
    return [dLdU, dLdV, dLdW]

# Attach to the class as a method
RNNNumpy.bptt = bptt

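Vanishing Gradients

Looking at the tanh and sigmoid functions and their derivatives, we can see that the derivatives lie in [0, 1] (at most 1 for tanh and at most 0.25 for sigmoid). Applying the chain rule several times therefore shrinks the gradient exponentially, and after only a few steps the gradient becomes very weak. LSTM is a recently popular solution to this problem.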
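The bounds on the derivatives, and how quickly a product of such factors shrinks, can be checked with a few lines. This is only an illustration: the evaluation grid and the choice of $z = 2$ as a "moderately saturated" activation are arbitrary assumptions.

```python
import numpy as np

z = np.linspace(-6.0, 6.0, 1201)     # grid that includes z = 0

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

tanh_grad = 1.0 - np.tanh(z) ** 2                # d/dz tanh(z)
sigmoid_grad = sigmoid(z) * (1.0 - sigmoid(z))   # d/dz sigmoid(z)

max_tanh_grad = tanh_grad.max()        # largest value, reached at z = 0
max_sigmoid_grad = sigmoid_grad.max()  # largest value, reached at z = 0

# Chain ten derivative factors evaluated at a moderately saturated
# activation (z = 2 is an arbitrary choice for illustration).
factor = 1.0 - np.tanh(2.0) ** 2
shrunk_gradient = factor ** 10
```

Even though the tanh derivative can reach 1, it is far smaller once the unit saturates, so ten chained factors already reduce the gradient by many orders of magnitude.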

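Gradient Checking

Gradient checking is very useful. The idea is that the gradient at a point equals the slope at that point, and the slope can be estimated from the definition of the derivative as a limit:

$$\frac{\partial L}{\partial \theta} = \lim_{h \to 0} \frac{L(\theta + h) - L(\theta - h)}{2h}$$

In practice the limit is approximated with a small fixed $h$. By comparing this estimated slope with the gradient computed by backpropagation, we can tell whether the gradient calculation has a problem. Note that the check is quite expensive: the loss has to be evaluated twice for every parameter, and the number of parameters is on the order of millions.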
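Before applying the idea to the RNN, it may help to see it on a function whose gradient is known exactly. The sketch below is a toy example under stated assumptions: the cubic stand-in loss, the step size, and the function names are all hypothetical and are not part of the post's code.

```python
import numpy as np

def loss(theta):
    # A simple stand-in loss with a known gradient: L = sum(theta^3).
    return np.sum(theta ** 3)

def analytic_gradient(theta):
    return 3.0 * theta ** 2

def numerical_gradient(f, theta, h=1e-4):
    """Central-difference estimate of df/dtheta, one coordinate at a time."""
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        original = theta[i]
        theta[i] = original + h
        f_plus = f(theta)
        theta[i] = original - h
        f_minus = f(theta)
        theta[i] = original          # restore the parameter
        grad[i] = (f_plus - f_minus) / (2 * h)
    return grad

theta = np.array([0.5, -1.2, 2.0])
estimated = numerical_gradient(loss, theta)
exact = analytic_gradient(theta)

# Relative error |a - b| / (|a| + |b|) between the two gradients.
relative_error = np.abs(exact - estimated) / (np.abs(exact) + np.abs(estimated))
```

The relative error is tiny here because the analytic gradient is correct; a bug in `analytic_gradient` would show up as a relative error close to 1.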

The gradient-checking code:

 
import operator

def gradient_check(self, x, y, h=0.001, error_threshold=0.01):
    # Calculate the gradients using backpropagation. We want to check if these are correct.
    bptt_gradients = self.bptt(x, y)
    # List of all parameters we want to check.
    model_parameters = ['U', 'V', 'W']
    # Gradient check for each parameter
    for pidx, pname in enumerate(model_parameters):
        # Get the actual parameter value from the model, e.g. model.W
        parameter = operator.attrgetter(pname)(self)
        print("Performing gradient check for parameter %s with size %d." % (pname, np.prod(parameter.shape)))
        # Iterate over each element of the parameter matrix, e.g. (0,0), (0,1), ...
        it = np.nditer(parameter, flags=['multi_index'], op_flags=['readwrite'])
        while not it.finished:
            ix = it.multi_index
            # Save the original value so we can reset it later
            original_value = parameter[ix]
            # Estimate the gradient using (f(x+h) - f(x-h))/(2*h)
            parameter[ix] = original_value + h
            gradplus = self.calculate_total_loss([x],[y])
            parameter[ix] = original_value - h
            gradminus = self.calculate_total_loss([x],[y])
            estimated_gradient = (gradplus - gradminus)/(2*h)
            # Reset parameter to original value
            parameter[ix] = original_value
            # The gradient for this parameter calculated using backpropagation
            backprop_gradient = bptt_gradients[pidx][ix]
            # Calculate the relative error: (|x - y|/(|x| + |y|))
            relative_error = np.abs(backprop_gradient - estimated_gradient)/(np.abs(backprop_gradient) + np.abs(estimated_gradient))
            # If the error is too large, fail the gradient check
            if relative_error > error_threshold:
                print("Gradient Check ERROR: parameter=%s ix=%s" % (pname, ix))
                print("+h Loss: %f" % gradplus)
                print("-h Loss: %f" % gradminus)
                print("Estimated_gradient: %f" % estimated_gradient)
                print("Backpropagation gradient: %f" % backprop_gradient)
                print("Relative Error: %f" % relative_error)
                return
            it.iternext()
        print("Gradient check for parameter %s passed." % (pname))

# Attach to the class as a method
RNNNumpy.gradient_check = gradient_check

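SGD Implementation

This formula should be very familiar:

$$W = W - \lambda \Delta W$$

where $\Delta W$ is the gradient and $\lambda$ is the learning rate. The code: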

 
# Performs one step of SGD.
def numpy_sdg_step(self, x, y, learning_rate):
    # Calculate the gradients
    dLdU, dLdV, dLdW = self.bptt(x, y)
    # Change parameters according to gradients and learning rate
    self.U -= learning_rate * dLdU
    self.V -= learning_rate * dLdV
    self.W -= learning_rate * dLdW

# Attach to the class as a method
RNNNumpy.numpy_sdg_step = numpy_sdg_step
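A single SGD step is normally wrapped in an outer loop over the training set. The sketch below shows one possible shape for such a loop; `train_with_sgd`, its default values, and the `ToyModel` stand-in are hypothetical and exist only so the loop can be run without the full network. Any object exposing `numpy_sdg_step` and `calculate_loss` would fit.

```python
def train_with_sgd(model, X_train, y_train, learning_rate=0.005, nepoch=10):
    """Run SGD for nepoch passes and return the loss after each pass."""
    losses = []
    for epoch in range(nepoch):
        # One SGD step per training example (here: per sentence)
        for x, y in zip(X_train, y_train):
            model.numpy_sdg_step(x, y, learning_rate)
        losses.append(model.calculate_loss(X_train, y_train))
    return losses


class ToyModel:
    """Hypothetical stand-in exposing the same two methods as the RNN.

    It fits a single scalar w to scalar targets with squared error, which
    is enough to exercise the loop without the full network.
    """

    def __init__(self):
        self.w = 0.0

    def numpy_sdg_step(self, x, y, learning_rate):
        # Gradient of (w - y)^2 with respect to w is 2 * (w - y)
        self.w -= learning_rate * 2.0 * (self.w - y)

    def calculate_loss(self, X, Y):
        return sum((self.w - y) ** 2 for y in Y) / len(Y)


toy = ToyModel()
history = train_with_sgd(toy, X_train=[None, None], y_train=[1.0, 3.0],
                         learning_rate=0.1, nepoch=20)
```

Recording the loss after every pass makes it easy to spot a learning rate that is too large, since the loss then stops decreasing or starts to grow.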

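Generating Text

Generation is simply the trained model being applied: we repeatedly run the forward pass and sample the next word from the predicted distribution: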

 
def generate_sentence(model):
    # We start the sentence with the start token
    new_sentence = [word_to_index[sentence_start_token]]
    # Repeat until we get an end token
    while not new_sentence[-1] == word_to_index[sentence_end_token]:
        # forward_propagation returns [o, s]; the last row of o is the next-word distribution
        o, s = model.forward_propagation(new_sentence)
        next_word_probs = o[-1]
        sampled_word = word_to_index[unknown_token]
        # We don't want to sample unknown words
        while sampled_word == word_to_index[unknown_token]:
            samples = np.random.multinomial(1, next_word_probs)
            sampled_word = np.argmax(samples)
        new_sentence.append(sampled_word)
    sentence_str = [index_to_word[x] for x in new_sentence[1:-1]]
    return sentence_str

num_sentences = 10
senten_min_length = 7

for i in range(num_sentences):
    sent = []
    # We want long sentences, not sentences with one or two words
    while len(sent) < senten_min_length:
        sent = generate_sentence(model)
    print(" ".join(sent))
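The sampling step inside the generation loop can be looked at in isolation. The sketch below uses a made-up five-word vocabulary and a made-up probability vector in place of a real model output; `sample_known_word` is a hypothetical helper name.

```python
import numpy as np

# Hypothetical tiny vocabulary; index 3 plays the role of UNKNOWN_TOKEN.
index_to_word = ["SENTENCE_START", "hello", "world", "UNKNOWN_TOKEN", "SENTENCE_END"]
unknown_index = index_to_word.index("UNKNOWN_TOKEN")

# Made-up next-word distribution, as one row of the model output o would be.
next_word_probs = np.array([0.0, 0.3, 0.2, 0.4, 0.1])

def sample_known_word(probs, unknown_index, rng):
    """Draw word indices until the draw is not the unknown token."""
    while True:
        one_hot = rng.multinomial(1, probs)   # a single draw, as a one-hot count vector
        word = int(np.argmax(one_hot))
        if word != unknown_index:
            return word

rng = np.random.default_rng(0)
draws = [sample_known_word(next_word_probs, unknown_index, rng) for _ in range(200)]
```

Sampling from the distribution, rather than always taking the most likely word, is what lets the generator produce different sentences on each run.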

References

Recurrent Neural Networks Tutorial, Part 2 – Implementing a RNN with Python, Numpy and Theano

Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients
