导读

本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks。

一 、Highway Networks 与 Deep Networks 的关系

深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破,然而,伴随着深度的增加,深层神经网络存在的问题也就越大,像大家所熟知的梯度消失问题,这也就造成了训练深层神经网络困难的难题。2015年由Rupesh Kumar Srivastava等人受到LSTM门机制的启发提出的网络结构(Highway Networks)很好的解决了训练深层神经网络的难题,Highway Networks 允许信息高速无阻碍的通过深层神经网络的各层,这样有效的减缓了梯度的问题,使深层神经网络不在仅仅具有浅层神经网络的效果。

二、Deep Networks 梯度消失/爆炸(vanishing and exploding gradient)问题

我们先来看一下简单的深层神经网络(仅仅几个隐层)

先把各个层的公式写出来

我们对W1求导:

W = W - lr * g(t)

以上公式仅仅是四个隐层的情况,当隐层的数量达到数十层甚至是数百层的情况下,一层一层的反向传播回去,当权值 < 1的时候,反向传播到某一层之后权值近乎不变,相当于输入x的映射,例如,g(t) =〖0.9〗^100已经是很小很小了,这就造成了只有前面几层能够正常的反向传播,后面的那些隐层仅仅相当于输入x的权重的映射,权重不进行更新。反过来,当权值 > 1的时候,会造成梯度爆炸,同样是仅仅前面的几层能更改正常学习,后面的隐层会变得很大。

三、Highway Networks Formula

  • Notation

    (.) 操作代表的是矩阵按位相乘

    sigmoid函数:

  • Highway Networks formula

    对于我们普通的神经网络,用非线性激活函数H将输入的x转换成y,公式1忽略了bias。但是,H不仅仅局限于激活函数,也采用其他的形式,像convolutional和recurrent。

    对于Highway Networks神经网络,增加了两个非线性转换层,一个是 T(transform gate) 和一个是 C(carry gate),通俗来讲,T表示输入信息经过convolutional或者是recurrent的信息被转换的部分,C表示的是原始输入信息x保留的部分 ,其中 T=sigmoid(wx + b)

    为了计算方便,这里定义了 C = 1 - T

    需要注意的是x,y, H, T的维度必须一致,要想保证其维度一致,可以采用sub-sampling或者zero-padding策略,也可以使用普通的线性层改变维度,使其一致。

    几个公式相比,公式3要比公式1灵活的多,可以考虑一下特殊的情况,T= 0的时候,y = x,原始输入信息全部保留,不做任何的改变,T = 1的时候,Y = H,原始信息全部转换,不在保留原始信息,仅仅相当于一个普通的神经网络。

四、Highway BiLSTM Networks

  • Highway BiLSTM Networks Structure Diagram

    下图是 Highway BiLSTM Networks 结构图:

    input:代表输入的词向量

    B:在本任务代表bidirection lstm,代表公式(2)中的 H

    T:代表公式(2)中的 T,是Highway Networks中的transform gate

    C:代表公式(2)中的 C,是Highway Networks中的carry gate

    Layer = n,代表Highway Networks中的第n层

    Highway:框出来的代表一层Highway Networks

    在这个结构图中,Highway Networks第 n - 1 层的输出作为第n层的输入

  • Highway BiLSTM Networks Demo

    pytorch搭建神经网络一般需要继承nn.Module这个类,然后实现里面的forward()函数,搭建Highway BiLSTM Networks写了两个类,并使用nn.ModuleList将两个类联系起来:

      class HBiLSTM(nn.Module):
    def __init__(self, args):
    super(HBiLSTM, self).__init__()
    ......
    def forward(self, x):
    # 实现Highway BiLSTM Networks的公式
    ......
      class HBiLSTM_model(nn.Module):
    def __init__(self, args):
    super(HBiLSTM_model, self).__init__()
    ......
    # args.layer_num_highway 代表Highway BiLSTM Networks有几层
    self.highway = nn.ModuleList([HBiLSTM(args) for _ in range(args.layer_num_highway)])
    ......
    def forward(self, x):
    ......
    # 调用HBiLSTM类的forward()函数
    for current_layer in self.highway:
    x, self.hidden = current_layer(x, self.hidden)

    HBiLSTM类的forward()函数里面我们实现Highway BiLSTM Networks的的公式

    首先我们先来计算H,上文已经说过,H可以是卷积或者是LSTM,在这里,normal_fc就是我们需要的H

       x, hidden = self.bilstm(x, hidden)
    # torch.transpose是转置操作
    normal_fc = torch.transpose(x, 0, 1)

    上文提及,x,y,H,T的维度必须保持一致,并且提供了两种策略,这里我们使用一个普通的Linear去转换维度

      source_x = source_x.contiguous()
    information_source = source_x.view(source_x.size(0) * source_x.size(1), source_x.size(2))
    information_source = self.gate_layer(information_source)
    information_source = information_source.view(source_x.size(0), source_x.size(1), information_source.size(1))

    也可以采用zero-padding的策略保证维度一致

      # you also can choose the strategy that zero-padding
    zeros = torch.zeros(source_x.size(0), source_x.size(1), carry_layer.size(2) - source_x.size(2))
    source_x = Variable(torch.cat((zeros, source_x.data), 2))

    维度一致之后我们就可以根据我们的公式来写代码了:

      # transformation gate layer in the formula is T
    transformation_layer = F.sigmoid(information_source)
    # carry gate layer in the formula is C
    carry_layer = 1 - transformation_layer
    # formula Y = H * T + x * C
    allow_transformation = torch.mul(normal_fc, transformation_layer)
    allow_carry = torch.mul(information_source, carry_layer)
    information_flow = torch.add(allow_transformation, allow_carry)

    最后的information_flow就是我们的输出,但是,还需要经过转换维度保证维度一致。

    更多的请参考Github: Highway Networks implement in pytorch

五、Highway BiLSTM Networks 实验结果

本次实验任务是使用Highway BiLSTM Networks 完成情感分类任务(一句话的态度分成积极或者是消极),数据来源于Twitter情感分类数据集,以下是数据集中的各个标签的句子个数:

下图是本次实验任务在2-class数据集中的测试结果。图中1-300在Highway BiLSTM Networks中表示Layer = 1,BiLSTM 隐层的维度是300维。

实验结果:从图中可以看出,简单的多层双向LSTM并没有带来情感分析性能的提升,尤其是是到了10层之后,效果有不如随机的猜测。当用上Highway Networks之后,虽然性能也在逐步的下降,但是下降的幅度有了明显的改善。

References

Highway Networks Pytorch的更多相关文章

  1. 基于pytorch实现HighWay Networks之Highway Networks详解

    (一)简述---承接上文---基于pytorch实现HighWay Networks之Train Deep Networks 上文已经介绍过Highway Netwotrks提出的目的就是解决深层神经 ...

  2. 基于pytorch实现HighWay Networks之Train Deep Networks

    (一)Highway Networks 与 Deep Networks 的关系 理论实践表明神经网络的深度是至关重要的,深层神经网络在很多方面都已经取得了很好的效果,例如,在1000-class Im ...

  3. Highway Networks

    一 .Highway Networks 与 Deep Networks 的关系 深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破 ...

  4. Highway Networks(高速路神经网络)

    Rupesh Kumar Srivastava (邮箱:RUPESH@IDSIA.CH)Klaus Greff (邮箱:KLAUS@IDSIA.CH)J¨ urgen Schmidhuber (邮箱: ...

  5. Paper | Highway Networks

    目录 1. 网络结构 2. 分析 解决的问题:在当时,人们认为 提高深度 是 提高精度 的法宝.但是网络训练也变得很困难.本文旨在解决深度网络训练难的问题,本质是解决梯度问题. 提出的网络:本文提出的 ...

  6. 【论文笔记】Training Very Deep Networks - Highway Networks

    目标: 怎么训练很深的神经网络 然而过深的神经网络会造成各种问题,梯度消失之类的,导致很难训练 作者利用了类似LSTM的方法,通过增加gate来控制transform前和transform后的数据的比 ...

  7. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  8. DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存

    论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...

  9. Residual Networks <2015 ICCV, ImageNet 图像分类Top1>

    本文介绍一下2015 ImageNet中分类任务的冠军——MSRA何凯明团队的Residual Networks.实际上,MSRA是今年Imagenet的大赢家,不单在分类任务,MSRA还用resid ...

随机推荐

  1. 线性代数-矩阵-【3】矩阵加减 C和C++实现

    点击这里可以跳转至 [1]矩阵汇总:http://www.cnblogs.com/HongYi-Liang/p/7287369.html [2]矩阵生成:http://www.cnblogs.com/ ...

  2. 7.7 WPF后台代码绑定如果是属性,必须指定一下数据上下文才能实现,而函数(click)就不用

    如: private bool _IsExportWithImage; /// <summary> /// 是否选择导出曲线图 /// </summary> public bo ...

  3. Day1 - 服务器硬件基础

    1.1 关于运维人员 1.1.1 运维的职责 1.保证服务器7*24小时 运行 2.保证数据不能丢 3.提高用户的体验(网站打开的速度) 1.1.2 运维原则 简单.易用.高效  === 简单.粗暴 ...

  4. Java GC 日志详解

    详见:http://blog.yemou.net/article/query/info/tytfjhfascvhzxcyt105 java GC日志可以通过 +PrintGCDetails开启 以Pa ...

  5. 【DDD】领域驱动设计实践 —— 限界上下文识别

    本文从战略层面街上DDD中关于限界上下文的相关知识,并以ECO系统为例子,介绍如何识别上下文.限界上下文(Bounded Context)定义了每个模型的应用范围,在每个Bounded Context ...

  6. 使用SVG基本操作API

    前面的话 本文将详细介绍SVG基本操作API,并使用这些API操作实例效果 基础API 在javascript中,可以使用一些基本的API来对SVG进行操作 [NS地址] 因为SVG定义在其自身的命令 ...

  7. JavaScript学习日志(六):事件

    这篇随笔,深恶痛绝,敲到快结束的时候,凌晨00:19,突然闪退,也不知道是Mac的原因还是chrome的原因,重新打开的时候,以为自动保存有效果,心想没关系,结果他么的只保存了四分之一,WTF?!!! ...

  8. 阿里云Maven地址

    GFW 呵呵呵 下载几个jar要几个小时.....伤透了 直接替换国内阿里云的maven镜像地址  速度嗖嗖嗖的.... 配置 修改maven根目录下的conf文件夹中的setting.xml文件,内 ...

  9. 历上最强的音乐播放器(jetA…

    原文地址:历上最强的音乐播放器(jetAudio-8.0.5.320-Plus-VX-完全汉化版)下载作者:盖世天星 历上最强的音乐播放器(jetAudio-8.0.5.320-Plus-VX-完全汉 ...

  10. MySQLzip archive版本(5.7.19)安装教程

    1.  从官网下载zip archive版本http://dev.mysql.com/downloads/mysql/ 2. 解压缩至相应目录,并配置环境变量(将*\bin添加进path中): 3. ...