LSTM推导

forward propagation

def lstm_cell_forward(xt, a_prev, c_prev, parameters):
"""
Implement a single forward step of the LSTM-cell as described in Figure (4) Arguments:
xt -- your input data at timestep "t", numpy array of shape (n_x, m).
a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
parameters -- python dictionary containing:
Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
Wi -- Weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
bi -- Bias of the save gate, numpy array of shape (n_a, 1)
Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
Wo -- Weight matrix of the focus gate, numpy array of shape (n_a, n_a + n_x)
bo -- Bias of the focus gate, numpy array of shape (n_a, 1)
Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1) Returns:
a_next -- next hidden state, of shape (n_a, m)
c_next -- next memory state, of shape (n_a, m)
yt_pred -- prediction at timestep "t", numpy array of shape (n_y, m)
cache -- tuple of values needed for the backward pass, contains (a_next, c_next, a_prev, c_prev, xt, parameters) Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilda),
c stands for the memory value
""" # Retrieve parameters from "parameters"
Wf = parameters["Wf"]
bf = parameters["bf"]
Wi = parameters["Wi"]
bi = parameters["bi"]
Wc = parameters["Wc"]
bc = parameters["bc"]
Wo = parameters["Wo"]
bo = parameters["bo"]
Wy = parameters["Wy"]
by = parameters["by"] # Retrieve dimensions from shapes of xt and Wy
n_x, m = xt.shape
n_y, n_a = Wy.shape # Concatenate a_prev and xt (≈3 lines)
concat = np.zeros((n_x+n_a,m))
concat[: n_a, :] = a_prev
concat[n_a :, :] = xt # Compute values for ft, it, cct, c_next, ot, a_next using the formulas given figure (4) (≈6 lines)
ft = sigmoid(np.dot(Wf,concat)+bf)
it = sigmoid(np.dot(Wi,concat)+bi)
cct = np.tanh(np.dot(Wc,concat)+bc)
c_next = ft*c_prev + it*cct
ot = sigmoid(np.dot(Wo,concat)+bo)
a_next = ot*np.tanh(c_next) # Compute prediction of the LSTM cell (≈1 line)
yt_pred = softmax(np.dot(Wy, a_next) + by) # store values needed for backward propagation in cache
cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters) return a_next, c_next, yt_pred, cache

back propagation

def lstm_cell_backward(da_next, dc_next, cache):
"""
Implement the backward pass for the LSTM-cell (single time-step). Arguments:
da_next -- Gradients of next hidden state, of shape (n_a, m)
dc_next -- Gradients of next cell state, of shape (n_a, m)
cache -- cache storing information from the forward pass Returns:
gradients -- python dictionary containing:
dxt -- Gradient of input data at time-step t, of shape (n_x, m)
da_prev -- Gradient w.r.t. the previous hidden state, numpy array of shape (n_a, m)
dc_prev -- Gradient w.r.t. the previous memory state, of shape (n_a, m, T_x)
dWf -- Gradient w.r.t. the weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
dWi -- Gradient w.r.t. the weight matrix of the input gate, numpy array of shape (n_a, n_a + n_x)
dWc -- Gradient w.r.t. the weight matrix of the memory gate, numpy array of shape (n_a, n_a + n_x)
dWo -- Gradient w.r.t. the weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
dbf -- Gradient w.r.t. biases of the forget gate, of shape (n_a, 1)
dbi -- Gradient w.r.t. biases of the update gate, of shape (n_a, 1)
dbc -- Gradient w.r.t. biases of the memory gate, of shape (n_a, 1)
dbo -- Gradient w.r.t. biases of the save gate, of shape (n_a, 1)
""" # Retrieve information from "cache"
(a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters) = cache # Retrieve dimensions from xt's and a_next's shape (≈2 lines)
n_x, m = xt.shape
n_a, m = a_next.shape # Compute gates related derivatives, you can find their values can be found by looking carefully at equations (7) to (10) (≈4 lines)
dot = da_next * np.tanh(c_next) * ot * (1 - ot)
dcct = (dc_next * it + ot * (1 - np.square(np.tanh(c_next))) * it * da_next) * (1 - np.square(cct))
dit = (dc_next * cct + ot * (1 - np.square(np.tanh(c_next))) * cct * da_next) * it * (1 - it)
dft = (dc_next * c_prev + ot *(1 - np.square(np.tanh(c_next))) * c_prev * da_next) * ft * (1 - ft) # Compute parameters related derivatives. Use equations (11)-(14) (≈8 lines)
dWf = np.dot(dft,np.concatenate((a_prev, xt), axis=0).T)
dWi = np.dot(dit,np.concatenate((a_prev, xt), axis=0).T)
dWc = np.dot(dcct,np.concatenate((a_prev, xt), axis=0).T)
dWo = np.dot(dot,np.concatenate((a_prev, xt), axis=0).T)
dbf = np.sum(dft, axis=1 ,keepdims = True)
dbi = np.sum(dit, axis=1, keepdims = True)
dbc = np.sum(dcct, axis=1, keepdims = True)
dbo = np.sum(dot, axis=1, keepdims = True) # Compute derivatives w.r.t previous hidden state, previous memory state and input. Use equations (15)-(17). (≈3 lines)
da_prev = np.dot(parameters['Wf'][:,:n_a].T,dft)+np.dot(parameters['Wi'][:,:n_a].T,dit)+np.dot(parameters['Wc'][:,:n_a].T,dcct)+np.dot(parameters['Wo'][:,:n_a].T,dot)
dc_prev = dc_next*ft+ot*(1-np.square(np.tanh(c_next)))*ft*da_next
dxt = np.dot(parameters['Wf'][:,n_a:].T,dft)+np.dot(parameters['Wi'][:,n_a:].T,dit)+np.dot(parameters['Wc'][:,n_a:].T,dcct)+np.dot(parameters['Wo'][:,n_a:].T,dot)
# parameters['Wf'][:, :n_a].T 每一行的 第 0 到 n_a-1 列的数据取出来
# parameters['Wf'][:, n_a:].T 每一行的 第 n_a 到最后列的数据取出来 # Save gradients in dictionary
gradients = {"dxt": dxt, "da_prev": da_prev, "dc_prev": dc_prev, "dWf": dWf,"dbf": dbf, "dWi": dWi,"dbi": dbi,
"dWc": dWc,"dbc": dbc, "dWo": dWo,"dbo": dbo} return gradients

LSTM推导的更多相关文章

  1. 【Deep Learning】RNN LSTM 推导

    http://blog.csdn.net/Dark_Scope/article/details/47056361 http://blog.csdn.net/hongmaodaxia/article/d ...

  2. 循环神经(LSTM)网络学习总结

    摘要: 1.算法概述 2.算法要点与推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 6.适用场合 内容: 1.算法概述 长短期记忆网络(Long Short Term Memory ne ...

  3. 程序猿 tensorflow 入门开发及人工智能实战

    tensorflow 中文文档: http://www.tensorfly.cn http://wiki.jikexueyuan.com/project/tensorflow-zh/ tensorfl ...

  4. 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸

    网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...

  5. LSTM简介以及数学推导(FULL BPTT)

    http://blog.csdn.net/a635661820/article/details/45390671 前段时间看了一些关于LSTM方面的论文,一直准备记录一下学习过程的,因为其他事儿,一直 ...

  6. 《神经网络的梯度推导与代码验证》之LSTM的前向传播和反向梯度推导

    前言 在本篇章,我们将专门针对LSTM这种网络结构进行前向传播介绍和反向梯度推导. 关于LSTM的梯度推导,这一块确实挺不好掌握,原因有: 一些经典的deep learning 教程,例如花书缺乏相关 ...

  7. lstm bptt推导

    深蓝 nlp 180429这个有详细的讲解

  8. GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

    GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现 RNN GRU matlab codes RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着 ...

  9. RNN求解过程推导与实现

    RNN求解过程推导与实现 RNN LSTM BPTT matlab code opencv code BPTT,Back Propagation Through Time. 首先来看看怎么处理RNN. ...

  10. Theano:LSTM源码解析

    最难读的Theano代码 这份LSTM代码的作者,感觉和前面Tutorial代码作者不是同一个人.对于Theano.Python的手法使用得非常娴熟. 尤其是在两重并行设计上: ①LSTM各个门之间并 ...

随机推荐

  1. Java并发(九)----线程join、interrupt

    1.join 方法详解 1.1 为什么需要 join? 下面的代码执行,打印 r 是什么? static int r = 0; public static void main(String[] arg ...

  2. elment UI + EasyExcel 实现 导入

    前端组件 <hd-flex> <el-dialog v-model="isUploadDialog" width="50%" lock-scr ...

  3. 统信UOS系统开发笔记(三):从Qt源码编译安装之编译安装Qt5.12.8

    前言   上一篇,是使用Qt提供的安装包安装的,有些场景需要使用到自己编译的Qt,所以本篇如何在统信UOS系统上编译Qt5.12.8源码.   统信UOS系统版本   系统版本:   Qt源码下载   ...

  4. 我在 vscode 插件里接入了 ChatGPT,解决了代码变量命名的难题

    lowcode 插件 已经迭代了差不多3年.作为我的生产力工具,平常一些不需要动脑的搬砖活基本上都是用 lowcode 去完成,比如管理脚手架,生成 CURD 页面,根据接口文档生成 TS 类型,生成 ...

  5. 尚医通day13【预约挂号】(内附源码)

    页面预览 预约挂号 根据预约周期,展示可预约日期,根据有号.无号.约满等状态展示不同颜色,以示区分 可预约最后一个日期为即将放号日期 选择一个日期展示当天可预约列表 预约确认 第01章-预约挂号 接口 ...

  6. MySql InnoDB 存储引擎表优化

    一.InnoDB 表存储优化 1.OPTIMIZE TABLE 适时的使用 OPTIMIZE TABLE 语句来重组表,压缩浪费的表空间.这是在其它优化技术不可用的情况下最直接的方法.OPTIMIZE ...

  7. 前端vue自定义简单实用下拉筛选 下拉菜单

    前端vue自定义简单实用下拉筛选 下拉菜单, 下载完整代码请访问: https://ext.dcloud.net.cn/plugin?id=13020 效果图如下:     #### 使用方法 ``` ...

  8. Pinot2的无人机任务和数据处理实践

    目录 1. 引言 2. 技术原理及概念 2.1 基本概念解释 2.2 技术原理介绍 2.3 相关技术比较 3. 实现步骤与流程 4. 应用示例与代码实现讲解 4.1 应用场景介绍 4.2 应用实例分析 ...

  9. 好的,以下是我为您拟定的自然语言处理(NLP)领域的100篇热门博客文章标题,以逻辑清晰、结构紧凑、简单易懂的

    目录 1. 引言 2. 技术原理及概念 3. 实现步骤与流程 4. 应用示例与代码实现讲解 1. 机器翻译 2. 文本分类 3. 情感分析 5. 优化与改进 6. 结论与展望 好的,以下是我为您拟定的 ...

  10. Rust 过程宏 proc-macro 是个啥

    定义一个 procedural macro 新建一个 lib 类型的 crate: cargo new hello-macro --lib procedural macros 只能在 proc-mac ...