http://blog.csdn.net/zjm750617105/article/details/51321889

本文是初学keras这两天来,自己仿照addition_rnn.py,写的一个实例,数据处理稍微有些不同,但是准确性相比addition_rnn.py 差一点,下面直接贴代码,
解释和注释都在代码里边。

  1. <span style="font-family: Arial, Helvetica, sans-serif;">#coding:utf-8</span>
    1. from keras.models import  Sequential
    2. from keras.layers.recurrent import LSTM
    3. from utils import  log
    4. from numpy import random
    5. import numpy as np
    6. from  keras.layers.core import RepeatVector, TimeDistributedDense, Activation
    7. '''''
    8. 先用lstm实现一个计算加法的keras版本, 根据addition_rnn.py改写
    9. size: 500
    10. 10次: test_acu = 0.3050  base_acu= 0.3600
    11. 30次: rest_acu = 0.3300  base_acu= 0.4250
    12. size: 50000
    13. 10次: test_acu: loss: 0.4749 - acc: 0.8502 - val_loss: 0.4601 - val_acc: 0.8539
    14. base_acu: loss: 0.3707 - acc: 0.9008 - val_loss: 0.3327 - val_acc: 0.9135
    15. 20次: test_acu: loss: 0.1536 - acc: 0.9505 - val_loss: 0.1314 - val_acc: 0.9584
    16. base_acu: loss: 0.0538 - acc: 0.9891 - val_loss: 0.0454 - val_acc: 0.9919
    17. 30次: test_acu: loss: 0.0671 - acc: 0.9809 - val_loss: 0.0728 - val_acc: 0.9766
    18. base_acu: loss: 0.0139 - acc: 0.9980 - val_loss: 0.0502 - val_acc: 0.9839
    19. '''
    20. log = log()
    21. #defination the global variable
    22. training_size = 50000
    23. hidden_size = 128
    24. batch_size = 128
    25. layers = 1
    26. maxlen = 7
    27. single_digit = 3
    28. def generate_data():
    29. log.info("generate the questions and answers")
    30. questions = []
    31. expected = []
    32. seen = set()
    33. while len(seen) < training_size:
    34. num1 = random.randint(1, 999) #generate a num [1,999]
    35. num2 = random.randint(1, 999)
    36. #用set来存储又有排序,来保证只有不同数据和结果
    37. key  = tuple(sorted((num1,num2)))
    38. if key in seen:
    39. continue
    40. seen.add(key)
    41. q = '{}+{}'.format(num1,num2)
    42. query = q + ' ' * (maxlen - len(q))
    43. ans = str(num1 + num2)
    44. ans = ans + ' ' * (single_digit + 1 - len(ans))
    45. questions.append(query)
    46. expected.append(ans)
    47. return questions, expected
    48. class CharacterTable():
    49. '''''
    50. encode: 将一个str转化为一个n维数组
    51. decode: 将一个n为数组转化为一个str
    52. 输入输出分别为
    53. character_table =  [' ', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    54. 如果一个question = [' 123+23']
    55. 那个改question对应的数组就是(7,12):
    56. 同样expected最大是一个四位数[' 146']:
    57. 那么ans对应的数组就是[4,12]
    58. '''
    59. def __init__(self, chars, maxlen):
    60. self.chars = sorted(set(chars))
    61. '''''
    62. >>> b = [(c, i) for i, c in enumerate(a)]
    63. >>> dict(b)
    64. {' ': 0, '+': 1, '1': 3, '0': 2, '3': 5, '2': 4, '5': 7, '4': 6, '7': 9, '6': 8, '9': 11, '8': 10}
    65. 得出的结果是无序的,但是下面这种方式得出的结果是有序的
    66. '''
    67. self.char_index = dict((c, i) for i, c in enumerate(self.chars))
    68. self.index_char = dict((i, c) for i, c in enumerate(self.chars))
    69. self.maxlen = maxlen
    70. def encode(self, C, maxlen):
    71. X = np.zeros((maxlen, len(self.chars)))
    72. for i, c in enumerate(C):
    73. X[i, self.char_index[c]] = 1
    74. return X
    75. def decode(self, X, calc_argmax=True):
    76. if calc_argmax:
    77. X = X.argmax(axis=-1)
    78. return ''.join(self.index_char[x] for x in X)
    79. chars = '0123456789 +'
    80. character_table = CharacterTable(chars,len(chars))
    81. questions , expected = generate_data()
    82. log.info('Vectorization...') #失量化
    83. inputs = np.zeros((len(questions), maxlen, len(chars))) #(5000, 7, 12)
    84. labels = np.zeros((len(expected), single_digit+1, len(chars))) #(5000, 4, 12)
    85. log.info("encoding the questions and get inputs")
    86. for i, sentence in enumerate(questions):
    87. inputs[i] = character_table.encode(sentence, maxlen=len(sentence))
    88. #print("questions is ", questions[0])
    89. #print("X is ", inputs[0])
    90. log.info("encoding the expected and get labels")
    91. for i, sentence in enumerate(expected):
    92. labels[i] = character_table.encode(sentence, maxlen=len(sentence))
    93. #print("expected is ", expected[0])
    94. #print("y is ", labels[0])
    95. log.info("total inputs is %s"%str(inputs.shape))
    96. log.info("total labels is %s"%str(labels.shape))
    97. log.info("build model")
    98. model = Sequential()
    99. '''''
    100. LSTM(output_dim, init='glorot_uniform', inner_init='orthogonal',
    101. forget_bias_init='one', activation='tanh',
    102. inner_activation='hard_sigmoid',
    103. W_regularizer=None, U_regularizer=None, b_regularizer=None,
    104. dropout_W=0., dropout_U=0., **kwargs)
    105. output_dim: 输出层的维数,或者可以用output_shape
    106. init:
    107. uniform(scale=0.05) :均匀分布,最常用的。Scale就是均匀分布的每个数据在-scale~scale之间。此处就是-0.05~0.05。scale默认值是0.05;
    108. lecun_uniform:是在LeCun在98年发表的论文中基于uniform的一种方法。区别就是lecun_uniform的scale=sqrt(3/f_in)。f_in就是待初始化权值矩阵的行。
    109. normal:正态分布(高斯分布)。
    110. Identity :用于2维方阵,返回一个单位阵.
    111. Orthogonal:用于2维方阵,返回一个正交矩阵. lstm默认
    112. Zero:产生一个全0矩阵。
    113. glorot_normal:基于normal分布,normal的默认 sigma^2=scale=0.05,而此处sigma^2=scale=sqrt(2 / (f_in+ f_out)),其中,f_in和f_out是待初始化矩阵的行和列。
    114. glorot_uniform:基于uniform分布,uniform的默认scale=0.05,而此处scale=sqrt( 6 / (f_in +f_out)) ,其中,f_in和f_out是待初始化矩阵的行和列。
    115. W_regularizer , b_regularizer  and activity_regularizer:
    116. 官方文档: http://keras.io/regularizers/
    117. from keras.regularizers import l2, activity_l2
    118. model.add(Dense(64, input_dim=64, W_regularizer=l2(0.01), activity_regularizer=activity_l2(0.01)))
    119. 加入规则项主要是为了在小样本数据下过拟合现象的发生,我们都知道,一半在训练过程中解决过拟合现象的方法主要中两种,一种是加入规则项(权值衰减), 第二种是加大数据量
    120. 很显然,加大数据量一般是不容易的,而加入规则项则比较容易,所以在发生过拟合的情况下,我们一般都采用加入规则项来解决这个问题.
    121. '''
    122. model.add(LSTM(hidden_size, input_shape=(maxlen, len(chars)))) #(7,12) 输入层
    123. '''''
    124. keras.layers.core.RepeatVector(n)
    125. 把1维的输入重复n次。假设输入维度为(nb_samples, dim),那么输出shape就是(nb_samples, n, dim)
    126. inputshape: 任意。当把这层作为某个模型的第一层时,需要用到该参数(元组,不包含样本轴)。
    127. outputshape:(nb_samples,nb_input_units)
    128. '''
    129. model.add(RepeatVector(single_digit + 1))
    130. #表示有多少个隐含层
    131. for _ in range(layers):
    132. model.add(LSTM(hidden_size, return_sequences=True))
    133. '''''
    134. TimeDistributedDense:
    135. 官方文档:http://keras.io/layers/core/#timedistributeddense
    136. keras.layers.core.TimeDistributedDense(output_dim,init='glorot_uniform', activation='linear', weights=None
    137. W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None,
    138. input_dim=None, input_length=None)
    139. 这是一个基于时间维度的全连接层。主要就是用来构建RNN(递归神经网络)的,但是在构建RNN时需要设置return_sequences=True。
    140. for example:
    141. # input shape: (nb_samples, timesteps,10)
    142. model.add(LSTM(5, return_sequences=True, input_dim=10)) # output shape: (nb_samples, timesteps, 5)
    143. model.add(TimeDistributedDense(15)) # output shape:(nb_samples, timesteps, 15)
    144. W_constraint:
    145. from keras.constraints import maxnorm
    146. model.add(Dense(64, W_constraint =maxnorm(2))) #限制权值的各个参数不能大于2
    147. '''
    148. model.add(TimeDistributedDense(len(chars)))
    149. model.add(Activation('softmax'))
    150. '''''
    151. 关于目标函数和优化函数,参考另外一片博文: http://blog.csdn.net/zjm750617105/article/details/51321915
    152. '''
    153. model.compile(loss='categorical_crossentropy',
    154. optimizer='adam',
    155. metrics=['accuracy'])
    156. # Train the model each generation and show predictions against the validation dataset
    157. for iteration in range(1, 3):
    158. print()
    159. print('-' * 50)
    160. print('Iteration', iteration)
    161. model.fit(inputs, labels, batch_size=batch_size, nb_epoch=2,
    162. validation_split = 0.1)
    163. # Select 10 samples from the validation set at random so we can visualize errors
    164. model.get_config()

详细解读简单的lstm的实例的更多相关文章

  1. Paxos协议超级详细解释+简单实例

    转载自:  https://blog.csdn.net/cnh294141800/article/details/53768464 Paxos协议超级详细解释+简单实例   Basic-Paxos算法 ...

  2. MemCache超详细解读

    MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于动态Web应用以减轻数据库的负载.它通过在内存中缓存数据和对象来减少读取数据库的次数,从而提高 ...

  3. MemCache超详细解读 图

    http://www.cnblogs.com/xrq730/p/4948707.html   MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于 ...

  4. MemCache详细解读

    MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于动态Web应用以减轻数据库的负载.它通过在内存中缓存数据和对象来减少读取数据库的次数,从而提高 ...

  5. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  6. SpringMVC 原理 - 设计原理、启动过程、请求处理详细解读

    SpringMVC 原理 - 设计原理.启动过程.请求处理详细解读 目录 一. 设计原理 二. 启动过程 三. 请求处理 一. 设计原理 Servlet 规范 SpringMVC 是基于 Servle ...

  7. NLP突破性成果 BERT 模型详细解读 bert参数微调

    https://zhuanlan.zhihu.com/p/46997268 NLP突破性成果 BERT 模型详细解读 章鱼小丸子 不懂算法的产品经理不是好的程序员 ​关注她 82 人赞了该文章 Goo ...

  8. springmvc 项目完整示例01 需求与数据库表设计 简单的springmvc应用实例 web项目

    一个简单的用户登录系统 用户有账号密码,登录ip,登录时间 打开登录页面,输入用户名密码 登录日志,可以记录登陆的时间,登陆的ip 成功登陆了的话,就更新用户的最后登入时间和ip,同时记录一条登录记录 ...

  9. Android BLE蓝牙详细解读

    代码地址如下:http://www.demodashi.com/demo/15062.html 随着物联网时代的到来,越来越多的智能硬件设备开始流行起来,比如智能手环.心率检测仪.以及各式各样的智能家 ...

随机推荐

  1. Jmeter自定义编写Java代码调用socket通信

    一.前言 最近需要测试一款手机游戏的性能,找不到啥录制脚本的工具,然后,另外想办法.性能测试实际上就是对服务器的承载能力的测试,和各种类型的手机客户端没有啥多大关系,手机再好,服务器负载不了,也不能够 ...

  2. python及扩展程序安装

    安装 从官方网站下载python程序,我下载的是python-3.3.2.msi 然后下载python扩展程序,我下载的是pywin32-218.win32-py3.3.exe 最后下载wmi插件,我 ...

  3. Memory access Tracing/Profiling

    https://mahmoudhatem.wordpress.com/2017/03/22/workaround-for-linux-perf-probes-issue-for-oracle-trac ...

  4. Snmp学习总结系列——开篇

    进入公司以来,一直参与到公司的产品研发工作当中去,在产品研发中有一个监控远程服务器CPU使用率,内存使用情况,硬盘的需求,技术总监提出了使用Snmp协议作为远程监控的技术解决方案,头一次听说Snmp这 ...

  5. 在使用SQLServer时忘记sa账号密码解决办法

    先以windows 身份验证方式登录SQLServer数据库,如下图所示: 打开查询分析器,运行如下代码: sp_password Null,'新密码','sa' 即可把原来的密码修改成新密码 例如: ...

  6. 关于Maven项目build时出现No compiler is provided in this environment的处理(转)

    本文转自https://blog.csdn.net/lslk9898/article/details/73836745 近日有同事遇到在编译Maven项目时出现[ERROR] No compiler ...

  7. Android项目更换开发环境时出现的 java.lang.VerifyError 异常解决办法

    from://http://blog.csdn.net/wudiwo/article/details/7548451 项目是从同事的电脑上直接拷贝过来的,项目里面的jar包是在项目跟下libs里面存放 ...

  8. skb的两个函数pskb_copy和skb_copy

    转自:http://blog.csdn.net/farmwang/article/details/54235252 skb的两个函数pskb_copy和skb_copy 前者仅仅是将sk_buff的结 ...

  9. Unity5中新的Shader体系简析

    一.Unity5中新的Shader体系简析 Unity5和之前的书写模式有了一定的改变.Unity5时代的Shader Reference官方文档也进一步地变得丰满. 主要需要了解到的是,在原来的Un ...

  10. 多个类定义attr属性重复的问题:Attribute "xxx" has already been defined

    有时候做自定义控件时就会遇到命名冲突,改变有冲突的名字自然是最直接有效的方式,但是感觉很傻.我搜了下别人的解决方案,觉得很值得借鉴.就是把重名的属性,独立出来写一下,然后在定义时直接写属性名字即可. ...