lstm的前向结构,不迭代

最基本的lstm结构。不涉及损失值和bp过程

import tensorflow as tf
import numpy as np inputs = tf.placeholder(np.float32, shape=(32,40,5)) # 32 是 batch_size
lstm_cell_1 = tf.nn.rnn_cell.LSTMCell(num_units=128) #实例话一个lstm单元,输出是128单元 print("output_size:",lstm_cell_1.output_size)
print("state_size:",lstm_cell_1.state_size)
print(lstm_cell_1.state_size.h)
print(lstm_cell_1.state_size.c) output,state=tf.nn.dynamic_rnn(
cell=lstm_cell_1,
inputs=inputs,
dtype=tf.float32
)
# 根据inputs输入的维度迭代rnn,并将输出和隐层态,push进output和state里面。
(inputs是三个维度,第一维,是batch_size,第二维:数据切片为面,第三维:切片面的具体数据) print("第一个输入的最后一个序列的预测输出:",output[1,-1,:])
print("output.shape:",output.shape)
print("len of state tuple",len(state))
print("state.h.shape:",state.h.shape)
print("state.c.shape:",state.c.shape) #>>>
output_size: 128
state_size: LSTMStateTuple(c=128, h=128)
128
128
第一个输入的最后一个序列的预测输出: Tensor("strided_slice:0", shape=(128,), dtype=float32)
output.shape: (32, 40, 128)
len of state tuple 2
state.h.shape: (32, 128)
state.c.shape: (32, 128)

用lstm对mnist数据分类

#引包和加载mnist数据

import tensorflow as tf
import input_data
import numpy as np
import matplotlib.pyplot as plt mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
diminput=28
dimhidden=128
dimoutput=nclasses
nsteps=28
weights={
'hidden':tf.Variable(tf.random_normal([diminput,dimhidden])),
'out':tf.Variable(tf.random_normal([dimhidden,dimoutput]))
}
biases={
'hidden':tf.Variable(tf.random_normal([dimhidden])),
'out':tf.Variable(tf.random_normal([dimoutput]))
}
def RNN(X,W,B,nsteps,name):
print(X.shape,'---')
X=tf.reshape(X,[-1,diminput])
X = tf.matmul(X, W['hidden']) + B['hidden']
X=tf.reshape(X,[-1,diminput,dimhidden])
print(X.shape)
with tf.variable_scope(name) as scope:
#scope.reuse_variables()
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0)
lstm_o,lstm_s=tf.nn.dynamic_rnn(cell=lstm_cell,inputs=X,dtype=tf.float32)
resultOut=tf.matmul(lstm_o[:,-1,:],W['out'])+B['out']
return {
'X':X,
'lstm_o':lstm_o,'lstm_s':lstm_s,'resultOut':resultOut
}
learning_rate=0.001
x=tf.placeholder('float',[None,nsteps,diminput]) y=tf.placeholder('float',[None,dimoutput]) myrnn=RNN(x,weights,biases,nsteps,'basic')
pred=myrnn['resultOut']
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optm=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
accr=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1),tf.argmax(y,1)),tf.float32))
init=tf.global_variables_initializer()
training_epochs=33
batch_size=16
display_step=1
sess=tf.Session()
sess.run(init) for epoch in range(training_epochs):
avg_cost=100
total_batch=100
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
Epoch: 000/033 cost: 101.797383542
Training accuracy: 0.688
Test accuracy: 0.461
Epoch: 001/033 cost: 101.269138204
Training accuracy: 0.438
Test accuracy: 0.549
Epoch: 002/033 cost: 101.139203327
Training accuracy: 0.688
Test accuracy: 0.614
Epoch: 003/033 cost: 100.965362185
Training accuracy: 0.938
Test accuracy: 0.619
Epoch: 004/033 cost: 100.914383653
Training accuracy: 0.875
Test accuracy: 0.648
Epoch: 005/033 cost: 100.813317066
Training accuracy: 0.625
Test accuracy: 0.656
Epoch: 006/033 cost: 100.781623098
Training accuracy: 0.875
Test accuracy: 0.708
Epoch: 007/033 cost: 100.710710035
Training accuracy: 1.000
Test accuracy: 0.716
Epoch: 008/033 cost: 100.684573339
Training accuracy: 1.000
Test accuracy: 0.745
Epoch: 009/033 cost: 100.635698693
Training accuracy: 0.875
Test accuracy: 0.751
Epoch: 010/033 cost: 100.622099145
Training accuracy: 0.938
Test accuracy: 0.763
Epoch: 011/033 cost: 100.562925613
Training accuracy: 0.750
Test accuracy: 0.763
Epoch: 012/033 cost: 100.592214927
Training accuracy: 0.812
Test accuracy: 0.771
Epoch: 013/033 cost: 100.544024273
Training accuracy: 0.938
Test accuracy: 0.769
Epoch: 014/033 cost: 100.516522627
Training accuracy: 0.875
Test accuracy: 0.791
Epoch: 015/033 cost: 100.479632292
Training accuracy: 0.938
Test accuracy: 0.801
Epoch: 016/033 cost: 100.471150137
Training accuracy: 0.938
Test accuracy: 0.816
Epoch: 017/033 cost: 100.431061392
Training accuracy: 0.875
Test accuracy: 0.807
Epoch: 018/033 cost: 100.464853102
Training accuracy: 0.812
Test accuracy: 0.798
Epoch: 019/033 cost: 100.445183915
Training accuracy: 0.750
Test accuracy: 0.828
Epoch: 020/033 cost: 100.399013084
Training accuracy: 1.000
Test accuracy: 0.804
Epoch: 021/033 cost: 100.393008129
Training accuracy: 0.938
Test accuracy: 0.833
Epoch: 022/033 cost: 100.413909222
Training accuracy: 0.812
Test accuracy: 0.815

RNN(二)——基于tensorflow的LSTM的实现的更多相关文章

  1. 如何基于TensorFlow使用LSTM和CNN实现时序分类任务

    https://www.jiqizhixin.com/articles/2017-09-12-5 By 蒋思源2017年9月12日 09:54 时序数据经常出现在很多领域中,如金融.信号处理.语音识别 ...

  2. 学习Tensorflow的LSTM的RNN例子

    学习Tensorflow的LSTM的RNN例子 基于TensorFlow一次简单的RNN实现 极客学院-递归神经网络 如何使用TensorFlow构建.训练和改进循环神经网络

  3. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  4. 个基于TensorFlow的简单故事生成案例:带你了解LSTM

    https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537 在深度学习中,循环神经网络( ...

  5. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人。

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 两种开源聊天机器人的性能测试(二)——基于tensorflow的chatbot

    http://blog.csdn.net/hfutdog/article/details/78155676 开源项目链接:https://github.com/dennybritz/chatbot-r ...

  8. 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(二)

    前言 已完成数据预处理工作,具体参照: 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(一) 设置配置文件 新建目录face_faster_rcn ...

  9. 基于Tensorflow + Opencv 实现CNN自定义图像分类

    摘要:本篇文章主要通过Tensorflow+Opencv实现CNN自定义图像分类案例,它能解决我们现实论文或实践中的图像分类问题,并与机器学习的图像分类算法进行对比实验. 本文分享自华为云社区< ...

随机推荐

  1. Linux 编译kernel有关Kconfig文件详解

    ref : https://blog.csdn.net/Ultraman_hs/article/details/52984929 Kconfig的格式 下面截取/drivers/net下的Kconfi ...

  2. Destination高级特性

    一.组合队列 Composite Destinations 组合队列允许用一个虚拟的destination代表多个destinations.这样就可以通过composite destinations在 ...

  3. java 框架-缓冲-Redis 1概述

    https://www.jianshu.com/p/56999f2b8e3b Redis 概述 在我们日常的Java Web开发中,无不都是使用数据库来进行数据的存储,由于一般的系统任务中通常不会存在 ...

  4. SpringMVC的理论

    围绕Handler开发 数据Model 页面View SpringMVC的运行流程: 1.用户发送一个请求,所有的请求都会映射到DispatcherServlet(中央控制器的servlet,该ser ...

  5. 使用了frame的页面如何整体进行跳转,而不是仅frame跳转

    使用了frame的页面如何整体进行跳转,而不是仅frame跳转 js window.parent.location.href="你的地址"; php echo "&quo ...

  6. 关闭mysql严格模式

    配置文件my.ini sql-mode="STRICT_TRANS_TABLES,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION" 修改为 s ...

  7. IPC——管道

    概述 管道通信分为无名管道.有名管道 管道通信的本质 不管是有名管道,还是无名管道,它们的本质其实都是一样的,它们都是内核所开辟的一段缓存空间.进程间通过管道通信时,本质上就是通过共享操作这段缓存来实 ...

  8. doesn't declare an explicit app_label and isn't in an application in INSTALLED_APPS.

    在settings.py中增加 INSTALLED_APPS = [ ... 'django.contrib.sites', ] 问题就解决了.什么原因.——不知道.. 具体请看: https://s ...

  9. VUE 单选下拉框Select中动态加载 默认选中第一个

    <lable>分类情况</lable> <select v-model="content.tid"> <option v-for=&quo ...

  10. ping加上时间信息

    一.linux系统ping加时间戳信息 1.ping 加时间信息,然后还要实时保存到一个文件中,那么就与awk结合 ping 115.239.211.112 -c 10 | awk '{ print ...