RNN 的 BP —— Back Propagation Through Time.

参考:零基础入门深度学习(5) - 循环神经网络知乎

 1   def backward(self, sensitivity_array,
activator):
'''
实现BPTT算法
'''
self.calc_delta(sensitivity_array, activator)
self.calc_gradient()
def calc_delta(self, sensitivity_array, activator):
self.delta_list = [] # 用来保存各个时刻的误差项
for i in range(self.times):
self.delta_list.append(np.zeros(
(self.state_width, 1)))
self.delta_list.append(sensitivity_array)
# 迭代计算每个时刻的误差项
for k in range(self.times - 1, 0, -1):
self.calc_delta_k(k, activator)
def calc_delta_k(self, k, activator):
'''
根据k+1时刻的delta计算k时刻的delta
'''
state = self.state_list[k+1].copy()
element_wise_op(self.state_list[k+1],
activator.backward)
self.delta_list[k] = np.dot(
np.dot(self.delta_list[k+1].T, self.W),
np.diag(state[:,0])).T
def calc_gradient(self):
self.gradient_list = [] # 保存各个时刻的权重梯度
for t in range(self.times + 1):
self.gradient_list.append(np.zeros(
(self.state_width, self.state_width)))
for t in range(self.times, 0, -1):
self.calc_gradient_t(t)
# 实际的梯度是各个时刻梯度之和
self.gradient = reduce(
lambda a, b: a + b, self.gradient_list,
self.gradient_list[0]) # [0]被初始化为0且没有被修改过
def calc_gradient_t(self, t):
'''
计算每个时刻t权重的梯度
'''
gradient = np.dot(self.delta_list[t],
self.state_list[t-1].T)
self.gradient_list[t] = gradient
 class RNN2(RNN1):
# 定义 Sigmoid 激活函数
def activate(self, x):
return 1 / (1 + np.exp(-x)) # 定义 Softmax 变换函数
def transform(self, x):
safe_exp = np.exp(x - np.max(x))
return safe_exp / np.sum(safe_exp) def bptt(self, x, y):
x, y, n = np.asarray(x), np.asarray(y), len(y)
# 获得各个输出,同时计算好各个 State
o = self.run(x)
# 照着公式敲即可 ( σ'ω')σ
dis = o - y
dv = dis.T.dot(self._states[:-1])
du = np.zeros_like(self._u)
dw = np.zeros_like(self._w)
for t in range(n-1, -1, -1):
st = self._states[t]
ds = self._v.T.dot(dis[t]) * st * (1 - st)
# 这里额外设定了最多往回看 10 步
for bptt_step in range(t, max(-1, t-10), -1):
du += np.outer(ds, x[bptt_step])
dw += np.outer(ds, self._states[bptt_step-1])
st = self._states[bptt_step-1]
ds = self._w.T.dot(ds) * st * (1 - st)
return du, dv, dw def loss(self, x, y):
o = self.run(x)
return np.sum(
-y * np.log(np.maximum(o, 1e-12)) -
(1 - y) * np.log(np.maximum(1 - o, 1e-12))
)

BPTT的更多相关文章

  1. BPTT算法推导

    随时间反向传播 (BackPropagation Through Time,BPTT) 符号注解: \(K\):词汇表的大小 \(T\):句子的长度 \(H\):隐藏层单元数 \(E_t\):第t个时 ...

  2. RNN 入门教程 Part 3 – 介绍 BPTT 算法和梯度消失问题

    转载 - Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradien ...

  3. Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 这是RNN教程的第三部分. 在前面的教程中,我们从头实现了一个循环 ...

  4. 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸

    网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...

  5. BPTT for multiple layers

    单层rnn的bptt: 每一个时间点的误差进行反向传播,然后将delta求和,更新本层weight. 多层时: 1.时间1:T 分层计算activation. 2.时间T:1 利用本时间点的误差,分层 ...

  6. 循环神经网络-极其详细的推导BPTT

    首先明确一下,本文需要对RNN有一定的了解,而且本文只针对标准的网络结构,旨在彻底搞清楚反向传播和BPTT. 反向传播形象描述 什么是反向传播?传播的是什么?传播的是误差,根据误差进行调整. 举个例子 ...

  7. LSTM简介以及数学推导(FULL BPTT)

    http://blog.csdn.net/a635661820/article/details/45390671 前段时间看了一些关于LSTM方面的论文,一直准备记录一下学习过程的,因为其他事儿,一直 ...

  8. Deep Learning基础--随时间反向传播 (BackPropagation Through Time,BPTT)推导

    1. 随时间反向传播BPTT(BackPropagation Through Time, BPTT) RNN(循环神经网络)是一种具有长时记忆能力的神经网络模型,被广泛用于序列标注问题.一个典型的RN ...

  9. Backpropagation Through Time (BPTT) 梯度消失与梯度爆炸

    Backpropagation Through Time (BPTT) 梯度消失与梯度爆炸 下面的图显示的是RNN的结果以及数据前向流动方向 假设有 \[ \begin{split} h_t & ...

随机推荐

  1. 【ABAP系列】SAP ABAP DATA - COMMON PART

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[ABAP系列]SAP ABAP DATA - ...

  2. 机器学习笔记——模型调参利器 GridSearchCV(网格搜索)参数的说明

    GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数.但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果.这个时候就是需要动脑筋了.数据量比较大 ...

  3. 华三F100系列防火墙 、华为USG6300系列防火 GRE 隧道配置

    GRE概述: 通用路由封装(GRE: Generic Routing Encapsulation)是通用路由封装协议,可以对某些网络层协议的数据报进行封装,使这些被封装的数据报能够在IPV4网络中传输 ...

  4. ambari 快速安装部署

    OS:Linux CPU消耗大,要准备5G以上,不然集群启动不了(我自己给它配了8G,启动整个集群是没问题,要用的话估计不够) 一.准备工作: 1.关闭防火墙:sudo ufw disable/ 2. ...

  5. Mysql安装后在服务里找不到和服务启动不起来的解决方法

    一,在安装完Mysql数据库后,发现在控制面板->管理->服务中找不到Mysql的服务启动 解决方法如下:开启命令行,按照如下步骤即可: 1.进入到mysql的安装包,在bin里执行:my ...

  6. 【C/C++开发】【VS开发】win32位与x64位下各类型长度对比

    64 位的优点:64 位的应用程序可以直接访问 4EB 的内存和文件大小最大达到4 EB(2 的 63 次幂):可以访问大型数据库.本文介绍的是64位下C语言开发程序注意事项. 1. 32 位和 64 ...

  7. [转帖]如何备份及恢复Linux文件权限

    如何备份及恢复Linux文件权限   http://embeddedlinux.org.cn/emb-linux/entry-level/201604/10-5337.html 三年前我就干过 chm ...

  8. c++学习笔记之引用

    引用是 C++ 的新增内容,在实际开发中会经常使用:C++ 用的引用就如同C语言的指针一样重要,但它比指针更加方便和易用,有时候甚至是不可或缺的. 同指针一样,引用能够减少数据的拷贝,提高数据的传递效 ...

  9. VUE-挂载点-实例成员-数据-过滤器-文本指令-事件指令-属性指令-表单指令-01

    目录 路飞项目 vue vue 导读 vue 的优势 渐进式框架 引入 vue 实例成员 - 挂载点 el js 对象(字典)补充 实例成员 - 数据 data 实例成员 - 过滤器 filters ...

  10. Web 开发和数据科学家仍是 Python 开发的两大主力

    由于 Python 2 即将退役,使用 Python 3 的开发者大约为 90%,Python 2 的使用量正在迅速减少.而去年仍有 1/4 的人使用 Python 2. Web 开发和数据科学家仍是 ...