LSTM Derivation

Forward propagation
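
For reference, these are the single-step forward formulas that the code below implements, where [a_{t-1}; x_t] is the concatenation of the previous hidden state and the current input, σ is the logistic sigmoid, and ⊙ is the elementwise product:

$$
\begin{aligned}
f_t &= \sigma\left(W_f\,[a_{t-1}; x_t] + b_f\right) \\
i_t &= \sigma\left(W_i\,[a_{t-1}; x_t] + b_i\right) \\
\tilde{c}_t &= \tanh\left(W_c\,[a_{t-1}; x_t] + b_c\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma\left(W_o\,[a_{t-1}; x_t] + b_o\right) \\
a_t &= o_t \odot \tanh(c_t) \\
\hat{y}_t &= \mathrm{softmax}\left(W_y\,a_t + b_y\right)
\end{aligned}
$$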

import numpy as np

def sigmoid(x):
    # Logistic sigmoid; this helper is assumed by the code below.
    return 1 / (1 + np.exp(-x))

def softmax(x):
    # Column-wise softmax (classes along axis 0); helper assumed by the code below.
    e = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)

def lstm_cell_forward(xt, a_prev, c_prev, parameters):
    """
    Implement a single forward step of the LSTM cell as described in Figure (4).

    Arguments:
    xt -- your input data at timestep "t", numpy array of shape (n_x, m)
    a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
    c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
    parameters -- python dictionary containing:
        Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
        bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
        Wi -- Weight matrix of the update gate, numpy array of shape (n_a, n_a + n_x)
        bi -- Bias of the update gate, numpy array of shape (n_a, 1)
        Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
        bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
        Wo -- Weight matrix of the output gate, numpy array of shape (n_a, n_a + n_x)
        bo -- Bias of the output gate, numpy array of shape (n_a, 1)
        Wy -- Weight matrix relating the hidden state to the output, numpy array of shape (n_y, n_a)
        by -- Bias relating the hidden state to the output, numpy array of shape (n_y, 1)

    Returns:
    a_next -- next hidden state, of shape (n_a, m)
    c_next -- next memory state, of shape (n_a, m)
    yt_pred -- prediction at timestep "t", numpy array of shape (n_y, m)
    cache -- tuple of values needed for the backward pass, contains
             (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    Note: ft/it/ot stand for the forget/update/output gates, cct stands for
    the candidate value (c tilde), c stands for the memory value.
    """
    # Retrieve parameters from "parameters"
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]

    # Retrieve dimensions from the shapes of xt and Wy
    n_x, m = xt.shape
    n_y, n_a = Wy.shape

    # Concatenate a_prev and xt into a single (n_a + n_x, m) array
    concat = np.zeros((n_a + n_x, m))
    concat[:n_a, :] = a_prev
    concat[n_a:, :] = xt

    # Compute ft, it, cct, c_next, ot, a_next using the formulas above
    ft = sigmoid(np.dot(Wf, concat) + bf)
    it = sigmoid(np.dot(Wi, concat) + bi)
    cct = np.tanh(np.dot(Wc, concat) + bc)
    c_next = ft * c_prev + it * cct
    ot = sigmoid(np.dot(Wo, concat) + bo)
    a_next = ot * np.tanh(c_next)

    # Compute the prediction of the LSTM cell
    yt_pred = softmax(np.dot(Wy, a_next) + by)

    # Store values needed for backward propagation in cache
    cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    return a_next, c_next, yt_pred, cache
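
As a quick smoke test, the cell can be exercised with random tensors. This is a minimal sketch; the dimensions (n_x = 3, n_a = 5, n_y = 2, m = 10) are arbitrary choices for illustration, not part of the original derivation.

np.random.seed(1)
n_x, n_a, n_y, m = 3, 5, 2, 10
xt = np.random.randn(n_x, m)
a_prev = np.random.randn(n_a, m)
c_prev = np.random.randn(n_a, m)
parameters = {
    "Wf": np.random.randn(n_a, n_a + n_x), "bf": np.random.randn(n_a, 1),
    "Wi": np.random.randn(n_a, n_a + n_x), "bi": np.random.randn(n_a, 1),
    "Wc": np.random.randn(n_a, n_a + n_x), "bc": np.random.randn(n_a, 1),
    "Wo": np.random.randn(n_a, n_a + n_x), "bo": np.random.randn(n_a, 1),
    "Wy": np.random.randn(n_y, n_a), "by": np.random.randn(n_y, 1),
}
a_next, c_next, yt_pred, cache = lstm_cell_forward(xt, a_prev, c_prev, parameters)
print(a_next.shape, c_next.shape, yt_pred.shape)  # (5, 10) (5, 10) (2, 10)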

Backward propagation
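
Before reading the code, it helps to name the total gradient flowing into c_t. Writing dc for the sum of the direct upstream gradient dc_next and the contribution through a_t = o_t ⊙ tanh(c_t), differentiating the forward formulas gives:

$$
\begin{aligned}
dc &= dc_{\text{next}} + o_t \odot \left(1 - \tanh^2(c_t)\right) \odot da_{\text{next}} \\
do_t &= da_{\text{next}} \odot \tanh(c_t) \odot o_t \odot (1 - o_t) \\
d\tilde{c}_t &= dc \odot i_t \odot \left(1 - \tilde{c}_t^{\,2}\right) \\
di_t &= dc \odot \tilde{c}_t \odot i_t \odot (1 - i_t) \\
df_t &= dc \odot c_{t-1} \odot f_t \odot (1 - f_t) \\
dW_g &= dg_t \,[a_{t-1}; x_t]^\top, \qquad db_g = \textstyle\sum_{\text{batch}} dg_t \\
dc_{t-1} &= dc \odot f_t \\
da_{t-1} &= \sum_{g \in \{f,i,c,o\}} W_g[:, :n_a]^\top \, dg_t, \qquad
dx_t = \sum_{g \in \{f,i,c,o\}} W_g[:, n_a:]^\top \, dg_t
\end{aligned}
$$

The code below expands dc inline rather than factoring it out, but computes exactly these quantities. Note that it treats da_next and dc_next as given upstream gradients; the softmax/Wy path is not part of this cell-level backward step.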

def lstm_cell_backward(da_next, dc_next, cache):
    """
    Implement the backward pass for the LSTM cell (single time step).

    Arguments:
    da_next -- Gradient of the next hidden state, of shape (n_a, m)
    dc_next -- Gradient of the next cell state, of shape (n_a, m)
    cache -- cache storing information from the forward pass

    Returns:
    gradients -- python dictionary containing:
        dxt -- Gradient of the input data at time step t, of shape (n_x, m)
        da_prev -- Gradient w.r.t. the previous hidden state, numpy array of shape (n_a, m)
        dc_prev -- Gradient w.r.t. the previous memory state, of shape (n_a, m)
        dWf -- Gradient w.r.t. the weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
        dWi -- Gradient w.r.t. the weight matrix of the update gate, numpy array of shape (n_a, n_a + n_x)
        dWc -- Gradient w.r.t. the weight matrix of the memory gate, numpy array of shape (n_a, n_a + n_x)
        dWo -- Gradient w.r.t. the weight matrix of the output gate, numpy array of shape (n_a, n_a + n_x)
        dbf -- Gradient w.r.t. the biases of the forget gate, of shape (n_a, 1)
        dbi -- Gradient w.r.t. the biases of the update gate, of shape (n_a, 1)
        dbc -- Gradient w.r.t. the biases of the memory gate, of shape (n_a, 1)
        dbo -- Gradient w.r.t. the biases of the output gate, of shape (n_a, 1)
    """
    # Retrieve information from "cache"
    (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters) = cache

    # Retrieve dimensions from the shapes of xt and a_next
    n_x, m = xt.shape
    n_a, m = a_next.shape

    # Compute the gate-related derivatives by differentiating the forward
    # equations (see the formulas above)
    dot = da_next * np.tanh(c_next) * ot * (1 - ot)
    dcct = (dc_next * it + ot * (1 - np.square(np.tanh(c_next))) * it * da_next) * (1 - np.square(cct))
    dit = (dc_next * cct + ot * (1 - np.square(np.tanh(c_next))) * cct * da_next) * it * (1 - it)
    dft = (dc_next * c_prev + ot * (1 - np.square(np.tanh(c_next))) * c_prev * da_next) * ft * (1 - ft)

    # Compute the parameter-related derivatives
    concat = np.concatenate((a_prev, xt), axis=0)
    dWf = np.dot(dft, concat.T)
    dWi = np.dot(dit, concat.T)
    dWc = np.dot(dcct, concat.T)
    dWo = np.dot(dot, concat.T)
    dbf = np.sum(dft, axis=1, keepdims=True)
    dbi = np.sum(dit, axis=1, keepdims=True)
    dbc = np.sum(dcct, axis=1, keepdims=True)
    dbo = np.sum(dot, axis=1, keepdims=True)

    # Compute the derivatives w.r.t. the previous hidden state, previous
    # memory state, and input. Each gate's weight matrix has shape
    # (n_a, n_a + n_x): the first n_a columns multiply a_prev, so
    # W[:, :n_a].T routes gradients back to a_prev, while the remaining
    # columns multiply xt, so W[:, n_a:].T routes gradients back to xt.
    da_prev = (np.dot(parameters['Wf'][:, :n_a].T, dft)
               + np.dot(parameters['Wi'][:, :n_a].T, dit)
               + np.dot(parameters['Wc'][:, :n_a].T, dcct)
               + np.dot(parameters['Wo'][:, :n_a].T, dot))
    dc_prev = dc_next * ft + ot * (1 - np.square(np.tanh(c_next))) * ft * da_next
    dxt = (np.dot(parameters['Wf'][:, n_a:].T, dft)
           + np.dot(parameters['Wi'][:, n_a:].T, dit)
           + np.dot(parameters['Wc'][:, n_a:].T, dcct)
           + np.dot(parameters['Wo'][:, n_a:].T, dot))

    # Save gradients in a dictionary
    gradients = {"dxt": dxt, "da_prev": da_prev, "dc_prev": dc_prev,
                 "dWf": dWf, "dbf": dbf, "dWi": dWi, "dbi": dbi,
                 "dWc": dWc, "dbc": dbc, "dWo": dWo, "dbo": dbo}

    return gradients
