GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着时间的推移,RNN单元就失去了对很久之前信息的保存和处理的能力,而且存在着gradient vanishing问题。
所以有些特殊类型的RNN网络相继被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)(Chao,et al. 2014).这里我主要推导一下GRU参数的迭代过程
GRU单元结构如下图所示

数据流过程如下

其中
表示Hadamard积,即对应元素乘积;下标表示节点的index,上标表示时刻;
表示隐层到输出层的参数矩阵,
分别是隐层和输出层的节点个数;
分别表示输入和上一时刻隐层到更新门z的连接矩阵,
表示输入数据的维度;
分别表示输入和上一时刻隐层到重置门r的连接矩阵;
分别表示输入和上一时刻的隐层到待选状态
的连接矩阵。
针对于时刻t,使用链式求导法则,计算参数矩阵的梯度,其中E是代价函数,首先计算对隐层输出的梯度,因为隐层输出牵涉到多个时刻

所以

其中
分别是对应激活函数的线性和部分
现在对参数计算梯度

令

则

将上面的式子矢量化(行向量)表示:


那接下来使用matlab来实现一个小例子,看看GRU的效果,同样是二进制相加的问题
- function error= GRUtest( )
- % 初始化训练数据
- uNum=16;%单元个数
- maxInt=2^uNum;
- % 初始化网络结构
- xdim=2;
- ydim=1;
- hdim=16;
- eta=0.1;
- %初始化网络参数
- Wy=rand(hdim,ydim)*2-1;
- Wr=rand(xdim,hdim)*2-1;
- Ur=rand(hdim,hdim)*2-1;
- W =rand(xdim,hdim)*2-1;
- U =rand(hdim,hdim)*2-1;
- Wz=rand(xdim,hdim)*2-1;
- Uz=rand(hdim,hdim)*2-1;
- rvalues=zeros(uNum+1,hdim);
- zvalues=zeros(uNum+1,hdim);
- hbarvalues=zeros(uNum,hdim);
- hvalues = zeros(uNum,hdim);
- yvalues=zeros(uNum,ydim);
- for p=1:10000
- aInt=randi(maxInt/2);
- bInt=randi(maxInt/2);
- cInt=aInt+bInt;
- at=dec2bin(aInt)-'0';
- bt=dec2bin(bInt)-'0';
- ct=dec2bin(cInt)-'0';
- a=zeros(1,uNum);
- b=zeros(1,uNum);
- c=zeros(1,uNum);
- a(1:size(at,2))=at(end:-1:1);
- b(1:size(bt,2))=bt(end:-1:1);
- c(1:size(ct,2))=ct(end:-1:1);
- xvalues=[a;b]';
- d=c';
- % 前向计算
- rvalues(1,:)=sigmoid(xvalues(1,:)*Wr);
- hbarvalues(1,:)=outTanh(xvalues(1,:)*W);
- zvalues(1,:)=sigmoid(xvalues(1,:)*Wz);
- hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:);
- yvalues(1,:)=sigmoid(hvalues(1,:)*Wy);
- for t=2:uNum
- rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur);
- hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U);
- zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz);
- hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:);
- yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);
- end
- % 误差反向传播
- delta_r_next=zeros(1,hdim);
- delta_z_next=zeros(1,hdim);
- delta_h_next=zeros(1,hdim);
- delta_next=zeros(1,hdim);
- dWy=zeros(hdim,ydim);
- dWr=zeros(xdim,hdim);
- dUr=zeros(hdim,hdim);
- dW=zeros(xdim,hdim);
- dU=zeros(hdim,hdim);
- dWz=zeros(xdim,hdim);
- dUz=zeros(hdim,hdim);
- for t=uNum:-1:2
- delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
- delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
- delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:));
- delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
- delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
- dWy=dWy+hvalues(t,:)'*delta_y;
- dWz=dWz+xvalues(t,:)'*delta_z;
- dUz=dUz+hvalues(t-1,:)'*delta_z;
- dW =dW+xvalues(t,:)'*delta;
- dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ;
- dWr=dWr+xvalues(t,:)'*delta_r;
- dUr=dUr+hvalues(t-1,:)'*delta_r;
- delta_r_next=delta_r;
- delta_z_next=delta_z;
- delta_h_next=delta_h;
- delta_next =delta;
- end
- t=1;
- delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:));
- delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:));
- delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:));
- delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:));
- delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:));
- dWy=dWy+hvalues(t,:)'*delta_y;
- dWz=dWz+xvalues(t,:)'*delta_z;
- dW =dW+xvalues(t,:)'*delta;
- dWr=dWr+xvalues(t,:)'*delta_r;
- Wy = Wy-eta*dWy;
- Wr = Wr-eta*dWr;
- Ur = Ur-eta*dUr;
- W = W -eta*dW;
- U = U-eta*dU;
- Wz = Wz-eta*dWz;
- Uz = Uz-eta*dUz;
- error = (norm(yvalues-d,2))/2.0;
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- if mod(p,500)==0
- fprintf('******************第%s次迭代****************\n',int2str(p));
- yvalues=round(yvalues(end:-1:1));
- y=bin2dec(int2str(yvalues'));
- fprintf('y=%d\n',y);
- fprintf('c=%d\n',cInt);
- fprintf('样本误差:e=%f\n',error);
- end
- end
- end
- function f=sigmoid(x)
- f=1./(1+exp(-x));
- end
- function fd = diffsigmoid(f)
- fd=f.*(1-f);
- end
- function g=outTanh(x)
- g=1-2./(1+exp(2*x));
- end
- function gd=diffoutTanh(g)
- gd=1-g.^2;
- end
部分实验结果

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现的更多相关文章
- Gated Recurrent Unit (GRU)
Gated Recurrent Unit (GRU) Outline Backgr ...
- Gated Recurrent Unit (GRU)公式简介
update gate $z_t$: defines how much of the previous memory to keep around. \[z_t = \sigma ( W^z x_t+ ...
- pytorch_SRU(Simple Recurrent Unit)
导读 本文讨论了最新爆款论文(Training RNNs as Fast as CNNs)提出的LSTM变种SRU(Simple Recurrent Unit),以及基于pytorch实现了SRU,并 ...
- Simple Recurrent Unit,单循环单元
SRU(Simple Recurrent Unit),单循环单元 src/nnet/nnet-recurrent.h 使用Tanh作为非线性单元 SRU不保留内部状态 训练时,每个训练序列以零向量开始 ...
- php网页,想弹出对话框, 消息框 简单代码
php网页,想弹出对话框, 消息框 简单代码 <?php echo "<script language=\"JavaScript\">alert(\&q ...
- C# 客服端上传文件与服务器器端接收 (简单代码)
简单代码: /*服务器端接收写入 可以实现断点续传*/ public string ConnectUpload(string newfilename,string filepath,byte[] fi ...
- Redis:安装、配置、操作和简单代码实例(C语言Client端)
Redis:安装.配置.操作和简单代码实例(C语言Client端) - hj19870806的专栏 - 博客频道 - CSDN.NET Redis:安装.配置.操作和简单代码实例(C语言Client端 ...
- 1 go 开发环境搭建与简单代码实现
什么是go语言 go是一门并发支持,垃圾回收的编译型 系统编程语言,旨在创造一门具有静态编译语言的高性能和动态语言的高效开发之间拥有一个良好平衡点 的一门编程语言. go有什么优点? 自动垃圾回收机制 ...
- 使用WinSCP进行简单代码文件同步
前言传输协议FTPFTPSSFTPSCP为什么使用WinSCP?CMD的FTP命令FileZillaPuTTYrsyncSublime的SFTP插件WinSCPWinSCP进行简单代码文件同步总结备注 ...
随机推荐
- PHP 位运算(&, |, ^, ~, <<, >>)及 PHP错误级别报告设置(error_reporting) 详解
位运算符允许对整型数中指定的位进行求值和操作. 位运算符 例子 名称 结果 $a & $b And(按位与) 将把 $a 和 $b 中都为 1 的位设为 1. $a | $b Or(按位或) ...
- 常用ubuntu命令
解压缩.7z sudo apt-get install p7zip-full 7z x PACKAGE.7z 查看图片 eog A.png 关闭打开触摸板(触点) sudo rmmod psmouse ...
- 从sum()求和引发的思考
sum()求和是一个非常简单的函数,以前我的写法是这样,我想大部分和我一样刚开始学习JS的同学写出来的也会是这样. function sum() { var total=null; for(var i ...
- iOS 强制退出程序APP代码
1.先po代码 UIAlertView* alert = [[UIAlertView alloc] initWithTitle:self.exitapplication message:@" ...
- 报错:1130-host ... is not allowed to connect to this MySql server
报错:1130-host ... is not allowed to connect to this MySql server 解决方法: 1. 改表法. 可能是你的帐号不允许从远程登陆,只能在l ...
- MySQL 5.7 Replication 相关新功能说明
背景: MySQL5.7在主从复制上面相对之前版本多了一些新特性,包括多源复制.基于组提交的并行复制.在线修改Replication Filter.GTID增强.半同步复制增强等.因为都是和复制相关, ...
- yii框架安装
YII安装: 下载最版本http://www.framework.com 下载高级的->yii with advanced APPlication template 解压至访问目录下 ...
- ModelMapper 中高级使用 java
ModelMapper 是一个java对象自动映射的第三方架包,用起来很方便,配合阿里的frstjson可以极大简化后台代码. 但是ModelMapper 中文使用说明很少,官网http://mode ...
- 顺序查找SequentialSearch
#include <stdio.h>int SequentialSearch(int *a,int n,int x);int main(void){ //num代表查找的数 int num ...
- Ubuntu安装Mysqlcluster集群
可参考:http://xuwensong.elastos.org/2014/01/13/ubuntu-%E4%B8%8Bmysql-cluster%E5%AE%89%E8%A3%85%E5%92%8C ...