#!/usr/bin/env python
# coding=utf-8 from keras.models import Sequential
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
import numpy as np
import string
import random class CharacterTable(object): def __init__(self, maxlen):
self.chars = string.digits + '+ '
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
self.indice_chars = dict((i, c) for i, c in enumerate(self.chars))
self.maxlen = maxlen def encode(self, strs, maxlen=None):
maxlen = maxlen if maxlen else self.maxlen
vec = np.zeros((maxlen, len(self.chars)))
for i, c in enumerate(strs):
vec[i, self.char_indices[c]] = 1
return vec def decode(self, vec, calc_argmax=True):
if calc_argmax:
vec = vec.argmax(axis=-1)
return ''.join(self.indice_chars[x] for x in vec) def gen_num():
nums = random.sample('', random.randint(1, 3))
return int(''.join(nums)) MAXLEN = 7 # 3+3+1
ctable = CharacterTable(MAXLEN) questions, expected = [], []
seen = set()
i = 0
while i < 50000:
a, b = gen_num(), gen_num()
key = tuple(sorted((a, b)))
if key in seen:
continue
seen.add(key)
q = '{}+{}'.format(a, b)
query = q + ' '*(7-len(q))
ans = str(a+b)
ans += ' ' * (4-len(ans)) questions.append(query)
expected.append(ans)
i += 1
print('total questions', len(questions)) X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=np.bool)
y = np.zeros((len(questions), 4, len(ctable.chars)), dtype=np.bool) for i, sent in enumerate(questions):
X[i] = ctable.encode(sent) for i, sent in enumerate(expected):
y[i] = ctable.encode(sent, 4) model = Sequential()
model.add(recurrent.LSTM(128, input_shape=(7, len(ctable.chars))))
model.add(RepeatVector(4))
model.add(recurrent.LSTM(128, return_sequences=True))
model.add(recurrent.LSTM(128, return_sequences=True)) model.add(TimeDistributed(Dense(len(ctable.chars))))
model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy']) model.fit(X, y, batch_size=64, nb_epoch=20, validation_split=0.02, verbose=2) # 测试看看
for i in range(10):
ind = np.random.randint(0, len(questions)-5)
x_test, y_test = X[ind:ind+5], y[ind:ind+5]
y_preds = model.predict_classes(x_test, verbose=0)
print('Q', ctable.decode(x_test[0]))
print('T', ctable.decode(y_test[0]))
print('Pred', ctable.decode(y_preds[0], calc_argmax=False)) json_string = model.to_json()
with open('rnn_add_model.json', 'wb') as fw:
fw.write(json_string)
model.save_weights('rnn_add_model.h5')

基本是模仿官网例子,精简了一点,训练约1h, 准确率99.6%

rnn实现三位数加法的训练的更多相关文章

  1. GDUFE-OJ 1203x的y次方的最后三位数 快速幂

    嘿嘿今天学了快速幂也~~ Problem Description: 求x的y次方的最后三位数 . Input: 一个两位数x和一个两位数y. Output: 输出x的y次方的后三位数. Sample ...

  2. 【python】题目:有1、2、3、4个数字,能组成多少个互不相同且无重复数字的三位数?都是多少?

    # encoding:utf-8 # p001_1234threeNums.py def threeNums(): '''题目:有1.2.3.4个数字,能组成多少个互不相同且无重复数字的三位数?都是多 ...

  3. 程序设计入门——C语言 第1周编程练习 1逆序的三位数(5分)

    第1周编程练习 查看帮助 返回   第1周编程练习题,直到课程结束之前随时可以来做.在自己的IDE或编辑器中完成作业后,将源代码的全部内容拷贝.粘贴到题目的代码区,就可以提交,然后可以查看在线编译和运 ...

  4. 题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身。例如:153是一个 "水仙花 数 ",因为153=1的三次方+5的三次方+3的三次方。

    题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身.例如:153是一个 "水仙花 数 ", ...

  5. C++判断对称三位数素数

    题目内容:判断一个数是否为对称三位数素数.所谓“对称”是指一个数,倒过来还是该数.例如,375不是对称数,因为倒过来变成了573. 输入描述:输入数据含有不多于50个的正整数(0<n<23 ...

  6. HDU_2035——求A^B的最后三位数

    Problem Description 求A^B的最后三位数表示的整数.说明:A^B的含义是“A的B次方”   Input 输入数据包含多个测试实例,每个实例占一行,由两个正整数A和B组成(1< ...

  7. 网易云课堂_程序设计入门-C语言_第一周:简单的计算程序_1逆序的三位数

    1 逆序的三位数(5分) 题目内容: 程序每次读入一个正三位数,然后输出逆序的数字.注意,当输入的数字含有结尾的0时,输出不应带有前导的0.比如输入700,输出应该是7. 输入格式: 每个测试是一个3 ...

  8. js求三位数的和

    例如输入508就输出5+0+8的和13: <!DOCTYPE html> <html lang="en"> <head> <meta ch ...

  9. Java求555 555的约数中最大的三位数。

    package org.llh.test; /** * 求555 555的约数中最大的三位数 * @author llh * */ public class Car { //整数j除以整数i(i≠0) ...

随机推荐

  1. 【NOI2014】【BZOJ3669】【UOJ#3】魔法森林

    我学会lct辣 原题: 为了得到书法大家的真传,小E同学下定决心去拜访住在魔法森林中的隐士.魔法森林可以被看成一个包含个N节点M条边的无向图,节点标号为 1…n1…n,边标号为1…m1…m.初始时小E ...

  2. 【liunx】Linux下的压缩和解压缩命令——jar

    原文链接:http://blog.chinaunix.net/uid-692788-id-2681136.html JAR包是Java中所特有一种压缩文档,其实大家就可以把它理解为.zip包.当然也是 ...

  3. nginx 优化(突破十万并发)

    一般来说nginx配置文件中对优化比较有作用的为以下几项: worker_processes 8; nginx进程数,建议按照cpu数目来指定,一般为它的倍数. worker_cpu_affinity ...

  4. 剑指offer-在数组中查找两个数,是的他们的和正好是S(一次性跑通)(时间复杂度还可以降低)

    /*对于一个递增的序列,存在2个数字的和相等,要想这2个数字的乘积最小,则这2个数字的距离最远*/ /*思想:j指向最后一个元素,然后i从前扫描看sum-a[j]在这个序列中吗?若不在j--*/ im ...

  5. 写一个小程序实现win系统定时锁屏

    貌似很久没写程序了,随便用C语言实现吧 #include<stdio.h> #include<stdlib.h> int main(){ system("rundll ...

  6. MySQL--批量插入导致自增跳号问题

    对于批量插入数据的操作,MySQL申请自增的策略为: 在批量插入语句执行过程中,申请策略: .第一次申请自增值时,会分配1个 .在N次申请自增值时,会分配上一次(第N-1次)的2倍. 测试Demo: ...

  7. tile38 一款开源的geo 数据库

    tile38 是基于golang 编写的geo 数据库,支持地理空间索引.实时地理围栏,同时也支持leader-flower 的部署模型 备注: 下边测试一个简单的地理围栏功能 环境准备 docker ...

  8. DevExpress的DateEdit控件正确显示日期的周名称

    DevExpress 的控件相当好看而且很好用,但 DateEdit 在是显示周名时,只能显示一个“星”字. 以下是解决方法,此解决方法不需修改其源码,所以免去了重新编译的必要,可直接使用其发布的标准 ...

  9. js里面的全局属性 全局对象 全局函数

    1)全局属性 Infinity   typeof Infinity        //number NaN typeof NaN           //number undefined       ...

  10. 持续集成--Jenkins--1

    持续集成之Jenkins安装部署   1.安装JDK Jenkins是Java编写的,所以需要先安装JDK,这里采用yum安装,如果对版本有需求,可以直接在Oracle官网下载JDK. [root@l ...