代码:

    def forward(self, x):
        '''
        根据式1-式6进行前向计算
        '''
        self.times += 1
        # 遗忘门
        fg = self.calc_gate(x, self.Wfx, self.Wfh,
                            self.bf, self.gate_activator)
        self.f_list.append(fg)
        # 输入门
        ig = self.calc_gate(x, self.Wix, self.Wih,
                            self.bi, self.gate_activator)
        self.i_list.append(ig)
        # 输出门
        og = self.calc_gate(x, self.Wox, self.Woh,
                            self.bo, self.gate_activator)
        self.o_list.append(og)
        # 即时状态
        ct = self.calc_gate(x, self.Wcx, self.Wch,
                            self.bc, self.output_activator)
        self.ct_list.append(ct)
        # 单元状态
        c = fg * self.c_list[self.times - 1] + ig * ct
        self.c_list.append(c)
        # 输出
        h = og * self.output_activator.forward(c)
        self.h_list.append(h)

    def calc_gate(self, x, Wx, Wh, b, activator):
        '''
        计算门
        '''
        h = self.h_list[self.times - 1]  # 上次的LSTM输出
        net = np.dot(Wh, h) + np.dot(Wx, x) + b
        gate = activator.forward(net)
        return gate

    def calc_delta_k(self, k):
        '''
        根据k时刻的delta_h,计算k时刻的delta_f、
        delta_i、delta_o、delta_ct,以及k-1时刻的delta_h
        '''
        # 获得k时刻前向计算的值
        ig = self.i_list[k]
        og = self.o_list[k]
        fg = self.f_list[k]
        ct = self.ct_list[k]
        c = self.c_list[k]
        c_prev = self.c_list[k - 1]
        tanh_c = self.output_activator.forward(c)
        delta_k = self.delta_h_list[k]

        # 根据式9计算delta_o
        delta_o = (delta_k * tanh_c *
                   self.gate_activator.backward(og))
        delta_f = (delta_k * og *
                   (1 - tanh_c * tanh_c) * c_prev *
                   self.gate_activator.backward(fg))
        delta_i = (delta_k * og *
                   (1 - tanh_c * tanh_c) * ct *
                   self.gate_activator.backward(ig))
        delta_ct = (delta_k * og *
                    (1 - tanh_c * tanh_c) * ig *
                    self.output_activator.backward(ct))
        delta_h_prev = (
                np.dot(delta_o.transpose(), self.Woh) +
                np.dot(delta_i.transpose(), self.Wih) +
                np.dot(delta_f.transpose(), self.Wfh) +
                np.dot(delta_ct.transpose(), self.Wch)
        ).transpose()

        # 保存全部delta值
        self.delta_h_list[k - 1] = delta_h_prev
        self.delta_f_list[k] = delta_f
        self.delta_i_list[k] = delta_i
        self.delta_o_list[k] = delta_o
        self.delta_ct_list[k] = delta_ct

    def calc_gradient_t(self, t):
        '''
        计算每个时刻t权重的梯度
        '''
        h_prev = self.h_list[t - 1].transpose()
        Wfh_grad = np.dot(self.delta_f_list[t], h_prev)
        bf_grad = self.delta_f_list[t]
        Wih_grad = np.dot(self.delta_i_list[t], h_prev)
        bi_grad = self.delta_f_list[t]
        Woh_grad = np.dot(self.delta_o_list[t], h_prev)
        bo_grad = self.delta_f_list[t]
        Wch_grad = np.dot(self.delta_ct_list[t], h_prev)
        bc_grad = self.delta_ct_list[t]
        return Wfh_grad, bf_grad, Wih_grad, bi_grad, \
               Woh_grad, bo_grad, Wch_grad, bc_grad

    def calc_gradient(self, x):
        # 初始化遗忘门权重梯度矩阵和偏置项
        self.Wfh_grad, self.Wfx_grad, self.bf_grad = (
            self.init_weight_gradient_mat())
        # 初始化输入门权重梯度矩阵和偏置项
        self.Wih_grad, self.Wix_grad, self.bi_grad = (
            self.init_weight_gradient_mat())
        # 初始化输出门权重梯度矩阵和偏置项
        self.Woh_grad, self.Wox_grad, self.bo_grad = (
            self.init_weight_gradient_mat())
        # 初始化单元状态权重梯度矩阵和偏置项
        self.Wch_grad, self.Wcx_grad, self.bc_grad = (
            self.init_weight_gradient_mat())

        # 计算对上一次输出h的权重梯度
        for t in range(self.times, 0, -1):
            # 计算各个时刻的梯度
            (Wfh_grad, bf_grad,
             Wih_grad, bi_grad,
             Woh_grad, bo_grad,
             Wch_grad, bc_grad) = (
                self.calc_gradient_t(t))
            # 实际梯度是各时刻梯度之和
            self.Wfh_grad += Wfh_grad
            self.bf_grad += bf_grad
            self.Wih_grad += Wih_grad
            self.bi_grad += bi_grad
            self.Woh_grad += Woh_grad
            self.bo_grad += bo_grad
            self.Wch_grad += Wch_grad
            self.bc_grad += bc_grad

        # 计算对本次输入x的权重梯度
        xt = x.transpose()
        self.Wfx_grad = np.dot(self.delta_f_list[-1], xt)
        self.Wix_grad = np.dot(self.delta_i_list[-1], xt)
        self.Wox_grad = np.dot(self.delta_o_list[-1], xt)
        self.Wcx_grad = np.dot(self.delta_ct_list[-1], xt)

参考:

https://zybuluo.com/hanbingtao/note/581764

https://www.cnblogs.com/ratels/p/11416515.html

零基础入门深度学习(6) - 长短时记忆网络(LSTM)的更多相关文章

  1. (转)零基础入门深度学习(6) - 长短时记忆网络(LSTM)

    无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的技术,会不会感觉马上就o ...

  2. C#区块链零基础入门,学习路线图 转

    C#区块链零基础入门,学习路线图 一.1分钟短视频<区块链100问>了解区块链基本概念 http://tech.sina.com.cn/zt_d/blockchain_100/ 二.C#区 ...

  3. 长短时记忆网络(LSTM)

    长短时记忆网络 循环神经网络很难训练的原因导致它的实际应用中很处理长距离的依赖.本文将介绍改进后的循环神经网络:长短时记忆网络(Long Short Term Memory Network, LSTM ...

  4. 【零基础学深度学习】动手学深度学习2.0--tensorboard可视化工具简单使用

    1 引言 老师让我将线性回归训练得出的loss值进行可视化,于是我使用了tensorboard将其应用到Pytorch中,用于Pytorch的可视化. 2 环境安装 本教程代码环境依赖: python ...

  5. 长短时记忆网络LSTM和条件随机场crf

    LSTM 原理 CRF 原理 给定一组输入随机变量条件下另一组输出随机变量的条件概率分布模型.假设输出随机变量构成马尔科夫随机场(概率无向图模型)在标注问题应用中,简化成线性链条件随机场,对数线性判别 ...

  6. 机器学习与Tensorflow(5)——循环神经网络、长短时记忆网络

    1.循环神经网络的标准模型 前馈神经网络能够用来建立数据之间的映射关系,但是不能用来分析过去信号的时间依赖关系,而且要求输入样本的长度固定 循环神经网络是一种在前馈神经网络中增加了分亏链接的神经网络, ...

  7. 函数:我的地盘听我的 - 零基础入门学习Python019

    函数:我的地盘听我的 让编程改变世界 Change the world by program 函数与过程 在小甲鱼另一个实践性超强的编程视频教学<零基础入门学习Delphi>中,我们谈到了 ...

  8. 【Python教程】《零基础入门学习Python》(小甲鱼)

    [Python教程]<零基础入门学习Python>(小甲鱼) 讲解通俗易懂,诙谐. 哈哈哈. https://www.bilibili.com/video/av27789609

  9. 《零基础入门学习Python》【第一版】视频课后答案第001讲

    测试题答案: 0. Python 是什么类型的语言? Python是脚本语言 脚本语言(Scripting language)是电脑编程语言,因此也能让开发者藉以编写出让电脑听命行事的程序.以简单的方 ...

随机推荐

  1. 【Python】蟒蛇绘制(三种方式+import用法)

    第一种方式不会出现函数重名问题,而第二种会.可以用第三种解决问题 方式一: #pythondraw.py import turtle #引用 绘制(海龟)库 turtle.setup(650,350, ...

  2. Http接口安全设计

    1.  完全开放 2. 基本验证 appid(企业唯一标识)+args(请求参数)->sign(摘要). 3. 时效控制 appid+args+timestamp(时间戳)->sign. ...

  3. 理解javaBean

    1:什么是JavaBean 组件?使用JavaBean 组件有什么优点?答案:现在软件开发都已经转向了基于组件的开发,目前具备代表性的组件技术有微软的COM.COM+,有Sun 的JavaBean 和 ...

  4. win api + ffmpeg 播放 mp3 音乐

    暂时记录,还有很多需要改善的地方一直没弄好,比如不知道怎么对上正确的播放速度. 有一些多余的代码,不用在意. #include <iostream> #include <fstrea ...

  5. DFA 简易正则表达式匹配

    一个只能匹配非常简单的(字母 . + *)共 4 种状态的正则表达式语法的自动机(注意,仅限 DFA,没考虑 NFA): 好久之前写的了,记得有个 bug 一直没解决... #include < ...

  6. 基于语音识别、音文同步、图像OCR的字幕解决方案HtwMedia介绍

    背景介绍 俗话说,“好记性不如乱笔头”,这充分说明了文字归档的重要性.如今随着微信.抖音等移动端app的使用越来越广,人们生产音.视频内容也越来越便捷.而相比语音和视频而言,文字具有易存档.易检索.易 ...

  7. Loppinha, the boy who likes sopinha Gym - 101875E (dp,记忆化搜索)

    https://vjudge.net/contest/299302#problem/E 题意:给出一个01 0101串,然后能量计算是连续的1就按1, 2, 3的能量加起来.然后给出起始的能量,求最少 ...

  8. 【StarUML】组件图

    架构设计中可视化地表达各个组件之间依赖关系以及组件的接口调用情况 1.元素 a.组件 b.接口 b1.组件暴露接口 暴露接口,需要先画一个接口 然后建立组件和接口的联系,这里是暴露接口,那么这个连线就 ...

  9. C++ 宏定义创建(销毁)单例

    Util.h: #define CREATE_SINGLETON_POINTER(CLASS,INSTANCE,MUTEX) if (NULL == INSTANCE) \ { \ MUTEX.loc ...

  10. power-plan如何定

    Power-Plan或者说PG如何打,这是一个仁者见仁智者见智的问题,没有一个标准的答案,因为有各种各样的影响因素.本文将列举一些可能的影响因素: 1.和design  相关 1) Utilizati ...