之前我们介绍了Recurrent neural network (RNN) 的原理:

http://blog.csdn.net/matrix_space/article/details/53374040

http://blog.csdn.net/matrix_space/article/details/53376870  

这里,我们构建一个简单的RNN网络,激励函数我们用sigmoid 函数,利用这个网络,我们来测试二进制数的运算。网络重复模块的表达式是:

ht=σ(Wh⋅ht−1+Wi⋅Xt)
ot=σ(Wo⋅ht)
e=12(yt−ot)2
import copy, numpy as np

np.random.seed(0)

# compute sigmoid nonlinearity
# 定义sigmoid 函数
def sigmoid (x):
output = 1 / (1+np.exp(-x))
return output # convert output to sigmoid function to its derivative
# 定义sigmoid 函数的导数
def sigmoid_output_to_derivative(output):
return output*(1-output) # training dataset generation
# 生成训练集
int2binary = {}
binary_dim = 8
max_number = pow (2, binary_dim)
binary = np.unpackbits(np.array([range(max_number)], dtype=np.uint8).T, axis=1)
# np.unpackbits 是将一个uint8的数组元素都转换成0-1二进制形式,这里max_number 是256, 
# binary 里面一共存了0-255 共 256 个二进制数 for i in range(max_number):
int2binary[i]=binary[i] # input parameters
alpha = 0.1
input_dim = 2
hidden_dim = 16
output_dim = 1 # weight 的初始化
synapse_0 = 2 * np.random.random((input_dim, hidden_dim))-1
synapse_1 = 2 * np.random.random((hidden_dim, output_dim))-1
synapse_h = 2 * np.random.random((hidden_dim, hidden_dim)) -1 synapse_0_update = np.zeros_like(synapse_0)
synapse_1_update = np.zeros_like(synapse_1)
synapse_h_update = np.zeros_like(synapse_h) for j in range(10000):
# 生成一个0-128之间的随机数
# 获取这个数的二进制序列
a_int = np.random.randint(max_number/2)
a = int2binary [a_int]    b_int = np.random.randint(max_number/2)
b = int2binary [b_int] c_int = a_int + b_int
c = int2binary [c_int] d = np.zeros_like(c) overallError = 0 layer_2_deltas = list ()
layer_1_values = list ()
layer_1_values.append(np.zeros(hidden_dim)) # moving along the positions in the binary encoding
for position in range(binary_dim): # generate input and output
X = np.array([[a[binary_dim-position-1], b[binary_dim-position-1]]])
y = np.array([[c[binary_dim-position-1]]]).T      # 计算重复模块的隐含层的输入和输出
layer_1 = sigmoid(np.dot(X, synapse_0) + np.dot(layer_1_values[-1], synapse_h)) layer_2 = sigmoid(np.dot(layer_1, synapse_1))      # BP
layer_2_error = y-layer_2
layer_2_deltas.append((layer_2_error)*sigmoid_output_to_derivative(layer_2))
overallError += np.abs(layer_2_error[0]) d[binary_dim-position-1] = np.round(layer_2[0][0]) layer_1_values.append(copy.deepcopy(layer_1)) future_layer_1_delta = np.zeros(hidden_dim) for position in range (binary_dim): X = np.array([[a[position], b[position]]])
layer_1 = layer_1_values [-position-1]
pre_layer_1 = layer_1_values[-position-2] layer_2_delta = layer_2_deltas[-position-1] layer_1_delta = (future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(
synapse_1.T)) * sigmoid_output_to_derivative(layer_1) # weight update
synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)
synapse_h_update += np.atleast_2d(pre_layer_1).T.dot(layer_1_delta)
synapse_0_update += X.T.dot(layer_1_delta) future_layer_1_delta = layer_1_delta synapse_0 += synapse_0_update * alpha
synapse_1 += synapse_1_update * alpha
synapse_h += synapse_h_update * alpha synapse_0_update *= 0
synapse_1_update *= 0
synapse_h_update *= 0 # print out progress
if (j % 500 == 0):
print ("Error: ", str(overallError))
print ("Pred:", str(d))
print ("True:", str(c))
out = 0
for index, x in enumerate(reversed(d)):
out += x*pow(2, index)
print (str(a_int) + "+" + str(b_int) + "=" + str(out))
print ("---------------")

运行结果:

('Error: ', '[ 3.45638663]')
('Pred:', '[0 0 0 0 0 0 0 1]')
('True:', '[0 1 0 0 0 1 0 1]')
9+60=1
---------------
('Error: ', '[ 4.02253884]')
('Pred:', '[0 1 1 0 1 0 1 1]')
('True:', '[1 0 0 0 0 0 0 1]')
112+17=107
---------------
('Error: ', '[ 3.63389116]')
('Pred:', '[1 1 1 1 1 1 1 1]')
('True:', '[0 0 1 1 1 1 1 1]')
28+35=255
---------------
('Error: ', '[ 3.99234598]')
('Pred:', '[1 1 0 1 1 0 1 0]')
('True:', '[1 0 1 1 0 0 1 1]')
78+101=218
---------------
('Error: ', '[ 3.91366595]')
('Pred:', '[0 1 0 0 1 0 0 0]')
('True:', '[1 0 1 0 0 0 0 0]')
116+44=72
---------------
('Error: ', '[ 3.65154804]')
('Pred:', '[1 1 0 1 1 0 1 0]')
('True:', '[1 1 0 1 1 1 1 0]')
122+100=218
---------------
('Error: ', '[ 3.72191702]')
('Pred:', '[1 1 0 1 1 1 1 1]')
('True:', '[0 1 0 0 1 1 0 1]')
4+73=223
---------------
('Error: ', '[ 3.35048888]')
('Pred:', '[1 0 0 1 1 0 0 1]')
('True:', '[1 0 0 1 0 0 0 1]')
76+69=153
---------------
('Error: ', '[ 3.5852713]')
('Pred:', '[0 0 0 0 1 0 0 0]')
('True:', '[0 1 0 1 0 0 1 0]')
71+11=8
---------------
('Error: ', '[ 2.43239777]')
('Pred:', '[0 1 1 0 1 0 1 1]')
('True:', '[0 1 1 0 1 0 1 1]')
72+35=107
---------------
('Error: ', '[ 2.53352328]')
('Pred:', '[1 0 1 0 0 0 1 0]')
('True:', '[1 1 0 0 0 0 1 0]')
81+113=162
---------------
('Error: ', '[ 1.87382863]')
('Pred:', '[0 1 1 0 0 0 1 0]')
('True:', '[0 1 1 0 0 0 1 0]')
21+77=98
---------------
('Error: ', '[ 0.57691441]')
('Pred:', '[0 1 0 1 0 0 0 1]')
('True:', '[0 1 0 1 0 0 0 1]')
81+0=81
---------------
('Error: ', '[ 0.75100965]')
('Pred:', '[0 0 1 1 1 1 0 0]')
('True:', '[0 0 1 1 1 1 0 0]')
49+11=60
---------------
('Error: ', '[ 1.42589952]')
('Pred:', '[1 0 0 0 0 0 0 1]')
('True:', '[1 0 0 0 0 0 0 1]')
4+125=129
---------------
('Error: ', '[ 0.6594703]')
('Pred:', '[0 1 1 0 1 1 0 0]')
('True:', '[0 1 1 0 1 1 0 0]')
80+28=108
---------------
('Error: ', '[ 0.47477457]')
('Pred:', '[0 0 1 1 1 0 0 0]')
('True:', '[0 0 1 1 1 0 0 0]')
39+17=56
---------------
('Error: ', '[ 0.7200904]')
('Pred:', '[1 0 1 0 1 0 0 0]')
('True:', '[1 0 1 0 1 0 0 0]')
123+45=168
---------------
('Error: ', '[ 0.21595037]')
('Pred:', '[0 0 0 0 1 1 1 0]')
('True:', '[0 0 0 0 1 1 1 0]')
11+3=14
---------------
('Error: ', '[ 0.52112049]')
('Pred:', '[1 0 1 0 1 0 1 1]')
('True:', '[1 0 1 0 1 0 1 1]')
71+100=171
---------------

参考来源:

https://github.com/llSourcell/recurrent_neural_net_demo

机器学习: Python with Recurrent Neural Network的更多相关文章

  1. Recurrent Neural Network系列2--利用Python,Theano实现RNN

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  2. Recurrent Neural Network系列4--利用Python,Theano实现GRU或LSTM

    yi作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORK ...

  3. Recurrent Neural Network系列1--RNN(循环神经网络)概述

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  4. 课程五(Sequence Models),第一 周(Recurrent Neural Networks) —— 1.Programming assignments:Building a recurrent neural network - step by step

    Building your Recurrent Neural Network - Step by Step Welcome to Course 5's first assignment! In thi ...

  5. Recurrent Neural Network(递归神经网络)

    递归神经网络(RNN),是两种人工神经网络的总称,一种是时间递归神经网络(recurrent neural network),另一种是结构递归神经网络(recursive neural network ...

  6. Sequence Models Week 1 Building a recurrent neural network - step by step

    Building your Recurrent Neural Network - Step by Step Welcome to Course 5's first assignment! In thi ...

  7. Recurrent Neural Network(循环神经网络)

    Reference:   Alex Graves的[Supervised Sequence Labelling with RecurrentNeural Networks] Alex是RNN最著名变种 ...

  8. Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 这是RNN教程的第三部分. 在前面的教程中,我们从头实现了一个循环 ...

  9. 循环神经网络(Recurrent Neural Network,RNN)

    为什么使用序列模型(sequence model)?标准的全连接神经网络(fully connected neural network)处理序列会有两个问题:1)全连接神经网络输入层和输出层长度固定, ...

随机推荐

  1. PJSIP开源库详解

    PJSIP是一个包含了SIP.SDP.RTP.RTCP.STUN.ICE等协议实现的开源库.它把基于信令协议SIP的多媒体框架和NAT穿透功能整合成高层次.抽象的多媒体通信API,这套API能够很容易 ...

  2. word中公式的排版及标题列表

    1.首先建好你的标题,如标题1,标题2等等,你能够依次改变它们的字体,段落等格式,新建格式例如以下图所看到的 红圈处即建立新的格式,你能够建立不论什么你想要的格式,非常方便: 2.当你建立好了多个标题 ...

  3. JAVA Concurrent包 中的并发集合类

    我们平时写程序需要经常用到集合类,比如ArrayList.HashMap等,但是这些集合不能够实现并发运行机制,这样在服务器上运行时就会非常的消耗资源和浪费时间,并且对这些集合进行迭代的过程中不能进行 ...

  4. linux下FAT32格式u盘只读的问题及解决方法

    以下是网上看到的解决办法:http://blog.csdn.net/heqiuya/article/details/7870554 其实是掉电保护,之前挂在的SD变成了制度文件,只需要将SD卡重新挂载 ...

  5. [Git] How to rename your remote branch

    Rename your local foo branch with bar: git branch -m foo bar Remember this will add the new branch w ...

  6. C语言学习笔记:12_变量的存储方式和生存期

    /* * 12_变量的存储方式和生存期.c * * Created on: 2015年7月5日 * Author: zhong */ #include <stdio.h> #include ...

  7. php中读取文件内容的几种方法。(file_get_contents:将文件内容读入一个字符串)

    php中读取文件内容的几种方法.(file_get_contents:将文件内容读入一个字符串) 一.总结 php中读取文件内容的几种方法(file_get_contents:将文件内容读入一个字符串 ...

  8. Mysql用户本机登陆不成功的解决

    mysql新建一个用户,本机不能登陆,但是远程能够登陆,不知什么原因,最后查阅 http://blog.itpub.net/12679300/viewspace-1453490/ 这篇文章得以解决,进 ...

  9. [Ramda] Refactor to Point Free Functions with Ramda using compose and converge

    In this lesson we'll take some existing code and refactor it using some functions from the Ramda lib ...

  10. DOM常用的四大对象是什么?

    DOM常用的四大对象是什么? 一.总结 一句话总结: 1.关注结构,关注主干 2.从主干处着手的话,可以发现dom就是四个东西,document(文档),element,attribute,event ...