随时间反向传播 (BackPropagation Through Time,BPTT)

符号注解:

  • \(K\):词汇表的大小
  • \(T\):句子的长度
  • \(H\):隐藏层单元数
  • \(E_t\):第t个时刻(第t个word)的损失函数,定义为交叉熵误差\(E_t=-y_t^Tlog(\hat{y}_t)\)
  • \(E\):一个句子的损失函数,由各个时刻(即每个word)的损失函数组成,\(E=\sum\limits_t^T E_t\)。

    注: 由于我们要推倒的是SGD算法, 更新梯度是相对于一个训练样例而言的, 因此我们一次只考虑一个句子的误差,而不是整个训练集的误差(对应BGD算法)
  • \(x_t\in\mathbb{R}^{K\times 1}\):第t个时刻RNN的输入,为one-hot vector,1表示一个单词的出现,0表示不出现
  • \(s_t\in\mathbb{R}^{H\times 1}\):第t个时刻RNN隐藏层的输入
  • \(h_t\in\mathbb{R}^{H\times 1}\):第t个时刻RNN隐藏层的输出
  • \(z_t\in\mathbb{R}^{K\times 1}\):输出层的汇集输入
  • \(\hat{y}_t\in\mathbb{R}^{K\times 1}\):输出层的输出,激活函数为softmax
  • \(y_t\in\mathbb{R}^{K\times 1}\):第t个时刻的监督信息,为一个one-hot vector
  • \(r_t=\hat{y}_t-y_t\):残差向量
  • \(W\in\mathbb{R}^{H\times K}\):从输入层到隐藏层的权值
  • \(U\in\mathbb{R}^{H\times H}\):隐藏层上一个时刻到当前时刻的权值
  • \(V\in\mathbb{R}^{K\times H}\):隐藏层到输出层的权值

他们之间的关系:

\[\left\{\begin{aligned}&s_t=Uh_{t-1}+Wx_t\\&h_t=\sigma(s_t)\\&z_t=Vh_t\\& \hat{y}_t=\mathrm{softmax}(z_t) \end{aligned}\right.
\]

其中,\(\sigma(\cdot)\)是sigmoid函数。由于\(x_t\)是one-hot向量,假设第\(j\)个词出现,则\(Wx_t\)相当于把\(W\)的第\(j\)列选出来,因此这一步是不用进行任何矩阵运算的,直接做下标操作即可,在matlab里就是\(W(:,x_t)\)。

BPTT与BP类似,是在时间上反传的梯度下降算法。RNN中,我们的目的是求得\(\frac{\partial E}{\partial U},\frac{\partial E}{\partial W},\frac{\partial E}{\partial V}\),根据这三个变化率来优化三个参数\(U,V,W\)

注意到\(\frac{\partial E}{\partial U}=\sum\limits_t \frac{\partial E_t}{\partial U}\),因此我们只要对每个时刻的损失函数求偏导数再加起来即可。

1.计算\(\frac{\partial E_t}{\partial V}\)

\[\begin{aligned}\frac{\partial E_t}{\partial V_{ij}}&=tr\bigg( \big( \frac{\partial E_t}{\partial z_t}\big)^T\cdot \frac{\partial z_t}{\partial V_{ij}}\bigg)\\&=tr\bigg((\hat{y}_t-y_t)^T\cdot\begin{bmatrix}0\\ \vdots \\ \frac{\partial z_{t}^{(i)}}{\partial V_{ij}}\\\vdots\\0\end{bmatrix}\bigg)\\&=r_t^{(i)} h_t^{(j)}\end{aligned}
\]

注:推导中用到了之前推导用到的结论。其中\(r_t^{(i)}=(\hat{y}_t-y_t)^{(i)}\)表示残差向量第i个分量,\(h_t^{(j)}\)表示\(h_t\)的第j个分量。

上述结果可以改写为:

\[\frac{\partial E_t}{\partial V}=(\hat{y}_t-y_t)\otimes h_t
\]

\[\frac{\partial E}{\partial V} = \sum_{k=0}^t (\hat{y}_k-y_k)\otimes h_k
\]

其中\(\otimes\)表示向量外积。

2.计算\(\frac{\partial E_t}{\partial U}\)

由于U是各个时刻共享的,所以t之前每个时刻U的变化都对\(E_t\)有贡献,反过来求偏导时,也要考虑之前每个时刻U对E的影响。我们以\(s_k\)为中间变量,应用链式法则:

\[\frac{\partial E_t}{\partial U} = \sum_{k=0}^t \frac{\partial s_k}{\partial U} \frac{\partial E_t}{\partial s_k}
\]

但由于\(\frac{\partial s_k}{\partial U}\)(分子向量,分母矩阵)以目前的数学发展水平是没办法求的,因此我们要求这个偏导,可以拆解为\(E_t\)对\(U_{ij}\)的偏导数:

\[\frac{\partial E_t}{\partial U_{ij}} = \sum_{k=0}^t tr[(\frac{\partial E_t}{\partial s_k})^T \frac{\partial s_k}{\partial U_{ij}}]= \sum_{k=0}^t tr[(\delta_k)^T\frac{\partial s_k}{\partial U_{ij}}]
\]

其中,\(\delta_k=\frac{\partial E_t}{\partial s_k}\),遵循

\[s_k\to h_k\to s_{k+1}\to ...\to E_t
\]

的传递关系,应用链式法则有:

\[\delta_k=\frac{\partial h_k}{\partial s_k}\frac{\partial s_{k+1}}{\partial h_k} \frac{\partial E_t}{\partial s_{k+1}}=diag(1-h_k\odot h_k)U^T\delta_{k+1}=(U^T\delta_{k+1})\odot (1-h_k\odot h_k)
\]

其中,\(\odot\)表示向量点乘。于是,我们得到了关于\(\delta\) 的递推关系式。由\(\delta_t\)出发,我们可以往前推出每一个\(\delta\),现在计算\(\delta_t\):

\begin{equation}\delta_t=\frac{\partial E_t}{\partial s_t}=\frac{\partial h_t}{\partial s_t}\frac{\partial z_t}{\partial h_t}\frac{\partial E_t}{\partial z_t}=diag(1-h_t\odot h_t)\cdot VT\cdot(\hat{y}_t-y_t)=(VT(\hat{y}t-y_t))\odot (1-h_t\odot h_t)\end{equation}

将\(\delta_0,...,\delta_t\)代入$ \frac{\partial E_t}{\partial U
{ij}} $有:

\[\frac{\partial E_t}{\partial U_{ij}} = \sum_{k=0}^t \delta_k^{(i)} h_{k-1}^{(j)}
\]

将上式写成矩阵形式:

\[\frac{\partial E_t}{\partial U} = \sum_{k=0}^t \delta_k \otimes h_{k-1}
\]

不失严谨性,定义\(h_{-1}\)为全0的向量。

3.计算\(\frac{\partial E_t}{\partial W}\)

按照上述思路,我们可以得到

\[\frac{\partial E_t}{\partial W} = \sum_{k=0}^t \delta_k \otimes x_{k}
\]

由于\(x_k\)是个one-hot vector,假设其第\(m\)个位置为1,那么我们在更新\(W\)时只需要更新\(W\)的第\(m\)列即可,计算\(\frac{\partial{E_t}}{\partial{W}}\)的伪代码如下:

delta_t = V.T.dot(residual[T]) * (1-h[T]**2)
for t from T to 0
dEdW[ :,x[t] ] += delta_t
#delta_t = W.T.dot(delta_t) * (1 - h[t-1]**2)
delta_t = U.T.dot(delta_t) * (1 - h[t-1]**2)

BPTT算法推导的更多相关文章

  1. Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 这是RNN教程的第三部分. 在前面的教程中,我们从头实现了一个循环 ...

  2. 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸

    网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...

  3. BP神经网络模型及算法推导

    一,什么是BP "BP(Back Propagation)网络是1986年由Rumelhart和McCelland为首的科学家小组提出,是一种按误差逆传播算法训练的多层前馈网络,是目前应用最 ...

  4. 带你找到五一最省的旅游路线【dijkstra算法推导详解】

    前言 五一快到了,小张准备去旅游了! 查了查到各地的机票 因为今年被扣工资扣得很惨,小张手头不是很宽裕,必须精打细算.他想弄清去各个城市的最低开销. [嗯,不用考虑回来的开销.小张准备找警察叔叔说自己 ...

  5. 1.XGBOOST算法推导

    最近因为实习的缘故,所以开始复习各种算法推导~~~就先拿这个xgboost练练手吧. (参考原作者ppt 链接:https://pan.baidu.com/s/1MN2eR-4BMY-jA5SIm6W ...

  6. BP神经网络算法推导及代码实现笔记zz

    一. 前言: 作为AI入门小白,参考了一些文章,想记点笔记加深印象,发出来是给有需求的童鞋学习共勉,大神轻拍! [毒鸡汤]:算法这东西,读完之后的状态多半是 --> “我是谁,我在哪?” 没事的 ...

  7. SVD在推荐系统中的应用详解以及算法推导

    SVD在推荐系统中的应用详解以及算法推导     出处http://blog.csdn.net/zhongkejingwang/article/details/43083603 前面文章SVD原理及推 ...

  8. 从乘法求导法则到BPTT算法

    本文为手稿,旨在搞清楚为什么BPTT算法会多路反向求导,而不是一个感性的认识. 假设我们要对E3求导(上图中的L3),那么则有: 所以S2是W的函数,也就是说,我们不能说: 因为WS2 = WS2(w ...

  9. Deep Learning基础--随时间反向传播 (BackPropagation Through Time,BPTT)推导

    1. 随时间反向传播BPTT(BackPropagation Through Time, BPTT) RNN(循环神经网络)是一种具有长时记忆能力的神经网络模型,被广泛用于序列标注问题.一个典型的RN ...

随机推荐

  1. 有一字符串,包含n个字符。写一函数,将此字符串中从第m个字符开始的全部字符复制成为另一个字符串。

    [提交][状态][讨论版] 题目描述 有一字符串,包含n个字符.写一函数,将此字符串中从第m个字符开始的全部字符复制成为另一个字符串. 输入 数字n 一行字符串 数字m 输出 从m开始的子串 样例输入 ...

  2. rabbitMQ学习(二)

    一端发送,多端消费 发送端: import java.io.IOException; import com.rabbitmq.client.ConnectionFactory; import com. ...

  3. 关于js作用域链,以及闭包中的坑

    eg:链式作用域,想在外部读取blogName的值得方法 <script>var authorName="山边小溪";function doSomething(){   ...

  4. Codeigniter

    最近准备接手改进一个别人用Codeigniter写的项目,虽然之前也有用过CI,但是是完全按着自己的意思写的,没按CI的一些套路.用在公众的项目,最好还是按框架规范来,所以还是总结一下,免得以后别人再 ...

  5. javascript冒泡算法

    var arr = [10, 10, 3, 2, 5 , 4, 8, 3]; function reSort(arr) { var temp = 0; var len = arr.length; fo ...

  6. Badboy使用数据源Excel进行脚本参数化

    1.首先新建一个Excel,这里示例我写得非常简单,由两由数据组成,第一行为表头.见下图: 2.录制脚本,见上一篇,录制一个非常简单的搜狗查询 3.添加数据源,在Tools面板中找到Data Sour ...

  7. WAP端 经验记录2

    1. LightboxV2 插件 点击A 应该关闭弹层的效果,但是 SAMSUNG 手机上原生浏览器上,看上去不会关闭却跳转了,但当点击回退按钮的时候就会看见弹层已经消失(其实之前的关闭效果已经记录了 ...

  8. --查询nvarchar(max)的表和字段

    --查询nvarchar(max)的表和字段 select 'insert into #tempTabelInfo select '''+d.name+''', '''+a.name+''', max ...

  9. Excel表格解析

    //add by yangwenpei WGCW-144 使用Excel表格导入纸票记录 20161212 start /** * @param fileInputStream * @param co ...

  10. oracle 返回第一个不为空的列的值

    ) from emp; 作用是返回函数coalesce参数中第一个不为null的值.