LSTM推导
LSTM推导
forward propagation

def lstm_cell_forward(xt, a_prev, c_prev, parameters):
"""
Implement a single forward step of the LSTM-cell as described in Figure (4)
Arguments:
xt -- your input data at timestep "t", numpy array of shape (n_x, m).
a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
parameters -- python dictionary containing:
Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
Wi -- Weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
bi -- Bias of the save gate, numpy array of shape (n_a, 1)
Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
Wo -- Weight matrix of the focus gate, numpy array of shape (n_a, n_a + n_x)
bo -- Bias of the focus gate, numpy array of shape (n_a, 1)
Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)
Returns:
a_next -- next hidden state, of shape (n_a, m)
c_next -- next memory state, of shape (n_a, m)
yt_pred -- prediction at timestep "t", numpy array of shape (n_y, m)
cache -- tuple of values needed for the backward pass, contains (a_next, c_next, a_prev, c_prev, xt, parameters)
Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilda),
c stands for the memory value
"""
# Retrieve parameters from "parameters"
Wf = parameters["Wf"]
bf = parameters["bf"]
Wi = parameters["Wi"]
bi = parameters["bi"]
Wc = parameters["Wc"]
bc = parameters["bc"]
Wo = parameters["Wo"]
bo = parameters["bo"]
Wy = parameters["Wy"]
by = parameters["by"]
# Retrieve dimensions from shapes of xt and Wy
n_x, m = xt.shape
n_y, n_a = Wy.shape
# Concatenate a_prev and xt (≈3 lines)
concat = np.zeros((n_x+n_a,m))
concat[: n_a, :] = a_prev
concat[n_a :, :] = xt
# Compute values for ft, it, cct, c_next, ot, a_next using the formulas given figure (4) (≈6 lines)
ft = sigmoid(np.dot(Wf,concat)+bf)
it = sigmoid(np.dot(Wi,concat)+bi)
cct = np.tanh(np.dot(Wc,concat)+bc)
c_next = ft*c_prev + it*cct
ot = sigmoid(np.dot(Wo,concat)+bo)
a_next = ot*np.tanh(c_next)
# Compute prediction of the LSTM cell (≈1 line)
yt_pred = softmax(np.dot(Wy, a_next) + by)
# store values needed for backward propagation in cache
cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)
return a_next, c_next, yt_pred, cache
back propagation

def lstm_cell_backward(da_next, dc_next, cache):
"""
Implement the backward pass for the LSTM-cell (single time-step).
Arguments:
da_next -- Gradients of next hidden state, of shape (n_a, m)
dc_next -- Gradients of next cell state, of shape (n_a, m)
cache -- cache storing information from the forward pass
Returns:
gradients -- python dictionary containing:
dxt -- Gradient of input data at time-step t, of shape (n_x, m)
da_prev -- Gradient w.r.t. the previous hidden state, numpy array of shape (n_a, m)
dc_prev -- Gradient w.r.t. the previous memory state, of shape (n_a, m, T_x)
dWf -- Gradient w.r.t. the weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
dWi -- Gradient w.r.t. the weight matrix of the input gate, numpy array of shape (n_a, n_a + n_x)
dWc -- Gradient w.r.t. the weight matrix of the memory gate, numpy array of shape (n_a, n_a + n_x)
dWo -- Gradient w.r.t. the weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
dbf -- Gradient w.r.t. biases of the forget gate, of shape (n_a, 1)
dbi -- Gradient w.r.t. biases of the update gate, of shape (n_a, 1)
dbc -- Gradient w.r.t. biases of the memory gate, of shape (n_a, 1)
dbo -- Gradient w.r.t. biases of the save gate, of shape (n_a, 1)
"""
# Retrieve information from "cache"
(a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters) = cache
# Retrieve dimensions from xt's and a_next's shape (≈2 lines)
n_x, m = xt.shape
n_a, m = a_next.shape
# Compute gates related derivatives, you can find their values can be found by looking carefully at equations (7) to (10) (≈4 lines)
dot = da_next * np.tanh(c_next) * ot * (1 - ot)
dcct = (dc_next * it + ot * (1 - np.square(np.tanh(c_next))) * it * da_next) * (1 - np.square(cct))
dit = (dc_next * cct + ot * (1 - np.square(np.tanh(c_next))) * cct * da_next) * it * (1 - it)
dft = (dc_next * c_prev + ot *(1 - np.square(np.tanh(c_next))) * c_prev * da_next) * ft * (1 - ft)
# Compute parameters related derivatives. Use equations (11)-(14) (≈8 lines)
dWf = np.dot(dft,np.concatenate((a_prev, xt), axis=0).T)
dWi = np.dot(dit,np.concatenate((a_prev, xt), axis=0).T)
dWc = np.dot(dcct,np.concatenate((a_prev, xt), axis=0).T)
dWo = np.dot(dot,np.concatenate((a_prev, xt), axis=0).T)
dbf = np.sum(dft, axis=1 ,keepdims = True)
dbi = np.sum(dit, axis=1, keepdims = True)
dbc = np.sum(dcct, axis=1, keepdims = True)
dbo = np.sum(dot, axis=1, keepdims = True)
# Compute derivatives w.r.t previous hidden state, previous memory state and input. Use equations (15)-(17). (≈3 lines)
da_prev = np.dot(parameters['Wf'][:,:n_a].T,dft)+np.dot(parameters['Wi'][:,:n_a].T,dit)+np.dot(parameters['Wc'][:,:n_a].T,dcct)+np.dot(parameters['Wo'][:,:n_a].T,dot)
dc_prev = dc_next*ft+ot*(1-np.square(np.tanh(c_next)))*ft*da_next
dxt = np.dot(parameters['Wf'][:,n_a:].T,dft)+np.dot(parameters['Wi'][:,n_a:].T,dit)+np.dot(parameters['Wc'][:,n_a:].T,dcct)+np.dot(parameters['Wo'][:,n_a:].T,dot)
# parameters['Wf'][:, :n_a].T 每一行的 第 0 到 n_a-1 列的数据取出来
# parameters['Wf'][:, n_a:].T 每一行的 第 n_a 到最后列的数据取出来
# Save gradients in dictionary
gradients = {"dxt": dxt, "da_prev": da_prev, "dc_prev": dc_prev, "dWf": dWf,"dbf": dbf, "dWi": dWi,"dbi": dbi,
"dWc": dWc,"dbc": dbc, "dWo": dWo,"dbo": dbo}
return gradients
LSTM推导的更多相关文章
- 【Deep Learning】RNN LSTM 推导
http://blog.csdn.net/Dark_Scope/article/details/47056361 http://blog.csdn.net/hongmaodaxia/article/d ...
- 循环神经(LSTM)网络学习总结
摘要: 1.算法概述 2.算法要点与推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 6.适用场合 内容: 1.算法概述 长短期记忆网络(Long Short Term Memory ne ...
- 程序猿 tensorflow 入门开发及人工智能实战
tensorflow 中文文档: http://www.tensorfly.cn http://wiki.jikexueyuan.com/project/tensorflow-zh/ tensorfl ...
- 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸
网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...
- LSTM简介以及数学推导(FULL BPTT)
http://blog.csdn.net/a635661820/article/details/45390671 前段时间看了一些关于LSTM方面的论文,一直准备记录一下学习过程的,因为其他事儿,一直 ...
- 《神经网络的梯度推导与代码验证》之LSTM的前向传播和反向梯度推导
前言 在本篇章,我们将专门针对LSTM这种网络结构进行前向传播介绍和反向梯度推导. 关于LSTM的梯度推导,这一块确实挺不好掌握,原因有: 一些经典的deep learning 教程,例如花书缺乏相关 ...
- lstm bptt推导
深蓝 nlp 180429这个有详细的讲解
- GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现 RNN GRU matlab codes RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着 ...
- RNN求解过程推导与实现
RNN求解过程推导与实现 RNN LSTM BPTT matlab code opencv code BPTT,Back Propagation Through Time. 首先来看看怎么处理RNN. ...
- Theano:LSTM源码解析
最难读的Theano代码 这份LSTM代码的作者,感觉和前面Tutorial代码作者不是同一个人.对于Theano.Python的手法使用得非常娴熟. 尤其是在两重并行设计上: ①LSTM各个门之间并 ...
随机推荐
- ChatGPT 推出 iOS 应用,支持语音输入,使用体验如何?
最近,OpenAI 宣布推出官方 iOS 应用,允许用户随时随地访问其高人气 AI 聊天机器人,此举也打破了近几个月内苹果 App Store 上充斥似是而非的山寨服务的窘境. 该应用程序是 Chat ...
- springboot 静态资源导入
1.根据源码可以看到需要去webjars官网下载jquery的依赖 <dependency> <groupId>org.webjars</groupId> < ...
- hvv蓝初面试常见漏洞问题(上)
1.SQL注入 漏洞成因: 可控变量 变量会带入数据库查询 变量不存在过滤或者变量过滤不严格 注入流程 判断是否有注入点 order by 判断字段数量 union select 报错查看注入点 使用 ...
- 使用openresty替换线上nginx网关之openresty安装细节
背景 线上跑了多年的一个网关业务,随着部门的拆分,逐渐有了一个痛点.该网关业务主要处理app端请求,app端发起的请求,采用http协议,post方法,content-type采用applicatio ...
- 使用c#实现23种常见的设计模式
使用c#实现23种常见的设计模式 设计模式通常分为三个主要类别: 创建型模式 结构型模式 行为型模式. 这些模式是用于解决常见的对象导向设计问题的最佳实践. 以下是23种常见的设计模式并且提供c#代码 ...
- 【python爬虫实战】用python爬取爱奇艺电视剧十大榜单的全部数据!
目录 一.爬取目标 二.讲解代码 三.查看结果 四.视频演示 五.附完整源码 一.爬取目标 本次爬取的目标是,爱奇艺电视剧类目下的10个榜单:电视剧风云榜-爱奇艺风云榜 可以看到,这10个榜单包含 ...
- 第一章 : Linux入门
1. 概述 2. Linux 和 Windows 区别 3. Centos 下载地址 网易镜像:http://mirrors.163.com/centos/7/isos ...
- 记一次Native memory leak排查过程
1 问题现象 路由计算服务是路由系统的核心服务,负责运单路由计划的计算以及实操与计划的匹配.在运维过程中,发现在长期不重启的情况下,有TP99缓慢爬坡的现象.此外,在每周例行调度的试算过程中,能明显看 ...
- gRPC vs. HTTP:网络通信协议的对比
概述 gRPC 和 HTTP 是两种常见的网络通信协议,用于在客户端和服务器之间进行通信.它们具有不同的特点和适用场景,下面对它们进行详细比较. HTTP(Hypertext Transfer Pro ...
- 查看C语言程序对应的汇编代码
在终端输入 gcc -S main.c 命令的意思是 编译不汇编 mian.c 可以换成想要汇编的C语言程序 然后生成 main.s 使用文本编辑器查看即可