使用tensorflow实现了简单的rnn网络用来学习加法运算。

tensorflow 版本:1.1

import tensorflow as tf
from tensorflow.contrib import rnn class RNN():
def __init__(self, input_dim , hidden_dim , step_num , class_num,learning_rate):
# # tf Graph input
self.x = tf.placeholder("float", [None, step_num, input_dim])
self.y = tf.placeholder("float", [None, class_num])
# Define weights
weights = {
'out': tf.Variable(tf.random_normal([hidden_dim, hidden_dim])),
'sigout':tf.Variable(tf.random_normal([hidden_dim , class_num]))
}
biases = {
'out': tf.Variable(tf.random_normal([hidden_dim])),
'sigout':tf.Variable(tf.random_normal([class_num]))
}
# Unstack to get a list of 'step_num' tensors of shape (batch_size, input_dim)
x_unstack = tf.unstack(self.x, step_num, 1)
# Define a lstm cell with tensorflow
lstm_cell = rnn.BasicLSTMCell(hidden_dim, forget_bias=1.0) # Get lstm cell output
outputs, states = rnn.static_rnn(lstm_cell, x_unstack, dtype=tf.float32) # Linear activation, using rnn inner loop last output
pred = tf.matmul(outputs[-1], weights['out']) + biases['out']
sigmodout = tf.matmul(tf.nn.sigmoid(pred , 'relu') , weights['sigout']) + biases['sigout']
self.predict = tf.round(sigmodout)
correct_pred = tf.equal(self.predict , self.y)
self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
cost = tf.reduce_sum(tf.abs(tf.subtract(sigmodout , self.y)))
self.loss = cost
tf.summary.scalar('acc' , self.accuracy)
tf.summary.scalar('cost' , cost)
self.merge_all = tf.summary.merge_all()
import random
import numpy as np
class DataProcess():
def __init__(self):
self.point = 0
self.max_integer = 10000
self.max_len = len(self.filter(bin(self.max_integer))) + 1
def nextBatch(self,batch_size , fake = False ,is_test = False):
x_batch = []
y_batch = []
t_batch = []
for i in range( 0 , batch_size):
a = random.randint(0,self.max_integer)
b = random.randint(0,self.max_integer)
if fake:
a = 1
b = 1
c = a + b
abin = self.process(a)
bbin = self.process(b)
cbin = self.process(c)
xa = np.array(abin)
xb = np.array(bbin)
x_ = np.concatenate((xa , xb),axis= 0)
y_ = np.array(cbin)
x_batch.append(x_.reshape(2,self.max_len))
y_batch.append(cbin)
if is_test:
temp = []
temp.append(a)
temp.append(b)
t_batch.append(temp)
# if not is_test:
return np.array(x_batch) , np.array(y_batch),np.array(t_batch)
# else:
# return np.array(x_batch) , np.array(y_batch) def process(self , str):
bstr = bin(str)
bstr = self.filter(bstr)
bstr = self.completion(bstr)
return bstr[::-1] def filter(self , bstr):
return bstr.replace('0b' , '') def completion(self , bstr):
lst = []
for num in list(bstr):
lst.append(int(num))
length = len(bstr)
for i in range(length , self.max_len):
lst.insert(0,0) return[ (float)(i) for i in lst]
import tensorflow as tf
import time
import os
from src.bitplus.Process import DataProcess
from src.bitplus.rnn import RNN dp = DataProcess()
tf.app.flags.DEFINE_float("learning_rate",0.01,'learning rate')
tf.app.flags.DEFINE_integer("training_iters",500000,'')
tf.app.flags.DEFINE_integer("batch_size",100,'')
tf.app.flags.DEFINE_integer("display_step",10,'')
tf.app.flags.DEFINE_integer("input_dim" , 2,'')
tf.app.flags.DEFINE_integer("steps_num" , dp.max_len,'')
tf.app.flags.DEFINE_integer("hidden_dim", dp.max_len * 2 ,'')
tf.app.flags.DEFINE_integer("class_num" , dp.max_len,'')
tf.app.flags.DEFINE_string("rnn_model",'rnn.model','')
FLAGS = tf.app.flags.FLAGS logfolder = time.strftime("%Y%m%d_%H%M%S", time.localtime()) def train():
rnn = RNN(FLAGS.input_dim , FLAGS.hidden_dim ,FLAGS.steps_num , FLAGS.class_num , FLAGS.learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(rnn.loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
train_writer = tf.summary.FileWriter('../LOGS/'+logfolder,sess.graph)
step = 1
while step * FLAGS.batch_size < FLAGS.training_iters:
batch_x, batch_y ,_= dp.nextBatch(FLAGS.batch_size , False)
batch_x = batch_x.transpose((0, 2, 1))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={rnn.x: batch_x, rnn.y: batch_y})
if step % FLAGS.display_step == 0:
# Calculate batch loss
merge_all , loss , acc,predict= sess.run([rnn.merge_all ,rnn.loss ,rnn.accuracy,rnn.predict],feed_dict={rnn.x: batch_x, rnn.y: batch_y})
print("Iter " + str(step*FLAGS.batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
train_writer.add_summary(merge_all , step)
step += 1 print("Optimization Finished!")
train_writer.close()
saver = tf.train.Saver(tf.all_variables())
if not os.path.exists('model'):
os.mkdir('model')
saver.save(sess, './model/' + FLAGS.rnn_model)
def tModel(model):
rnn = RNN(FLAGS.input_dim , FLAGS.hidden_dim ,FLAGS.steps_num , FLAGS.class_num , FLAGS.learning_rate)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
saver.restore(sess,model)
batch_x , batch_y , batch_test = dp.nextBatch(100 , False,True)
batch_x = batch_x.transpose((0, 2, 1))
predict = sess.run([rnn.predict],feed_dict={rnn.x: batch_x, rnn.y: batch_y}) for ndx in range(len(predict[0])):
print('%s + %s = %d(%d)'%(batch_test[ndx][0] , batch_test[ndx][1] , batch_test[ndx][0] + batch_test[ndx][1],bin2Ten(predict[0][ndx])))
print('pause!')
def bin2Ten(bin):
lst = []
for i in bin[::-1]:
lst.append(str(int(i)))
return int(''.join(lst) , base=2)
def main(_):
print('start...')
print('max len:' , dp.max_len)
# train()
tModel('./model/' + FLAGS.rnn_model)
if __name__ == '__main__':
tf.app.run()

使用tensorflow 构建rnn网络的更多相关文章

  1. TensorFlow之RNN:堆叠RNN、LSTM、GRU及双向LSTM

    RNN(Recurrent Neural Networks,循环神经网络)是一种具有短期记忆能力的神经网络模型,可以处理任意长度的序列,在自然语言处理中的应用非常广泛,比如机器翻译.文本生成.问答系统 ...

  2. TensorFlow 实现 RNN 入门教程

    转子:https://www.leiphone.com/news/201705/zW49Eo8YfYu9K03J.html 最近在看RNN模型,为简单起见,本篇就以简单的二进制序列作为训练数据,而不实 ...

  3. 第二十二节,TensorFlow中RNN实现一些其它知识补充

    一 初始化RNN 上一节中介绍了 通过cell类构建RNN的函数,其中有一个参数initial_state,即cell初始状态参数,TensorFlow中封装了对其初始化的方法. 1.初始化为0 对于 ...

  4. 第二十节,使用RNN网络拟合回声信号序列

    这一节使用TensorFlow中的函数搭建一个简单的RNN网络,使用一串随机的模拟数据作为原始信号,让RNN网络来拟合其对应的回声信号. 样本数据为一串随机的由0,1组成的数字,将其当成发射出去的一串 ...

  5. 深度学习原理与框架-递归神经网络-RNN_exmaple(代码) 1.rnn.BasicLSTMCell(构造基本网络) 2.tf.nn.dynamic_rnn(执行rnn网络) 3.tf.expand_dim(增加输入数据的维度) 4.tf.tile(在某个维度上按照倍数进行平铺迭代) 5.tf.squeeze(去除维度上为1的维度)

    1. rnn.BasicLSTMCell(num_hidden) #  构造单层的lstm网络结构 参数说明:num_hidden表示隐藏层的个数 2.tf.nn.dynamic_rnn(cell, ...

  6. 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)

    问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse)  # 构建 ...

  7. 解读tensorflow之rnn

    from: http://lan2720.github.io/2016/07/16/%E8%A7%A3%E8%AF%BBtensorflow%E4%B9%8Brnn/ 这两天想搞清楚用tensorfl ...

  8. 解读tensorflow之rnn 的示例 ptb_word_lm.py

    这两天想搞清楚用tensorflow来实现rnn/lstm如何做,但是google了半天,发现tf在rnn方面的实现代码或者教程都太少了,仅有的几个教程讲的又过于简单.没办法,只能亲自动手一步步研究官 ...

  9. language model ——tensorflow 之RNN

    代码结构 tf的代码看多了之后就知道其实官方代码的这个结构并不好: graph的构建和训练部分放在了一个文件中,至少也应该分开成model.py和train.py两个文件,model.py中只有一个P ...

随机推荐

  1. Linux - 静默安装oracle数据库总结

    Web服务器上面的Linux一般是不会有图形界面的,所有通过图形界面来安装Linux的方式在没有图形界面的Linux上面是行不通的,我们要使用的安装方式叫做Linux的静默安装.即在没有图形界面的Li ...

  2. EasyUI 表单 tree

    第一步:创建HTML标记 <divid="dlg"style="padding:20px;">     <h2>Account Info ...

  3. 010杰信-创建购销合同Excel报表系列-3-新增合同货物:这里涉及到子表的新增(合同货物表是购销合同表的子表)

    效果说明: 前面分析过购销合同的Excel报表需要四张表,这篇讲的是合同货物表. 这个合同货物表是购销合同的子表,是一个购销合同有多个合同货物的关系.在合同货物表中有购销合同的主键作为外键.所以这张表 ...

  4. lmbench

    lmbench作为性能检测工具的一种,提供内存,网络,内核等多方面的测试工具.是benchmark众多功能测试软件中的一种.几天了解了下,记录于此. 参考链接 http://www.bitmover. ...

  5. 学习shader之前必须知道的东西之计算机图形学(一)渲染管线

    引言 shader到底是干什么用的?shader的工作原理是什么? 其实当我们对这个问题还很懵懂的时候,就已经开始急不可耐的要四处搜寻有关shader的资料,恨不得立刻上手写一个出来.但看了一些资料甚 ...

  6. SSH开发环境整合搭建

    1.建立动态web工程,加入必要的jar包. antlr-2.7.7.jar asm-3.3.jar asm-commons-3.3.jar asm-tree-3.3.jar c3p0-0.9.1.2 ...

  7. 基于JEECG的代码模板自动生成

    1.基于JEECG3.5.2,提供多种数据源的代码生成,目前支持Oracle良好: 2.可动态配置数据源: 可动态配置模板集合,基于freemarker的模板文件: 可选择需要生成的数据表: 可导入一 ...

  8. DBA面试题及解答

    一:SQL tuning 类 1:列举几种表连接方式答:merge join,hash join,nested loop Hash join散列连接是CBO 做大数据集连接时的常用方式,优化器使用两个 ...

  9. AWS系列-EC2默认限制说明

    Amazon EC2 提供您可以使用的不同资源,例如实例和卷. 在您创建 AWS 账户时,AWS 会针对每个区域中的这些资源设置限制.此页面列出您在 亚太区域 (东京) 中的 EC2 服务限制. 1. ...

  10. 自动化测试环境准备robotframework

    (一)针对python2.7版本的自动化环境准备: python 下载地址: https://www.python.org/downloads/ 这里选择Python2.7系列的,后面涉及到wxPyt ...