LSTM_Model
#!/usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
class TRNNConfig(object):
"""RNN配置参数"""
# 模型参数
embedding_dim = 64 # 词向量维度
seq_length = 600 # 序列长度
num_classes = 10 # 类别数
vocab_size = 5000 # 词汇表达小
num_layers= 2 # 隐藏层层数
hidden_dim = 128 # 隐藏层神经元
rnn = 'gru' # lstm 或 gru
dropout_keep_prob = 0.8 # dropout保留比例
learning_rate = 1e-3 # 学习率
batch_size = 128 # 每批训练大小
num_epochs = 10 # 总迭代轮次
print_per_batch = 100 # 每多少轮输出一次结果
save_per_batch = 10 # 每多少轮存入tensorboard
class TextRNN(object):
"""文本分类,RNN模型"""
def __init__(self, config):
self.config = config
# 三个待输入的数据
self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
self.rnn()
def rnn(self):
"""rnn模型"""
def lstm_cell(): # lstm核
return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)
def gru_cell(): # gru核
return tf.contrib.rnn.GRUCell(self.config.hidden_dim)
def dropout(): # 为每一个rnn核后面加一个dropout层
if (self.config.rnn == 'lstm'):
cell = lstm_cell()
else:
cell = gru_cell()
return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)
# 动作映射
with tf.device('/cpu:0'):
embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)
with tf.name_scope("rnn"):
# 多层rnn网络
cells = [dropout() for _ in range(self.config.num_layers)]
rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
_outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
last = _outputs[:, -1, :] # 取最后一个时序输出作为结果
with tf.name_scope("score"):
# 全连接层,后面接dropout以及relu激活
fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
fc = tf.nn.relu(fc)
# 分类器
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
# 预测类别
self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)
with tf.name_scope("optimize"):
# 损失函数,交叉熵
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
#求输入的所有行的预测值的均值
self.loss = tf.reduce_mean(cross_entropy)
# 优化器
self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)
with tf.name_scope("accuracy"):
# 准确率 其中 self.y_pred_cls为预测的类别
correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
#
self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
LSTM_Model的更多相关文章
- 基于双向BiLstm神经网络的中文分词详解及源码
基于双向BiLstm神经网络的中文分词详解及源码 基于双向BiLstm神经网络的中文分词详解及源码 1 标注序列 2 训练网络 3 Viterbi算法求解最优路径 4 keras代码讲解 最后 源代码 ...
- (转) Using the latest advancements in AI to predict stock market movements
Using the latest advancements in AI to predict stock market movements 2019-01-13 21:31:18 This blog ...
- NLP入门(五)用深度学习实现命名实体识别(NER)
前言 在文章:NLP入门(四)命名实体识别(NER)中,笔者介绍了两个实现命名实体识别的工具--NLTK和Stanford NLP.在本文中,我们将会学习到如何使用深度学习工具来自己一步步地实现N ...
- Reading | 《TensorFlow:实战Google深度学习框架》
目录 三.TensorFlow入门 1. TensorFlow计算模型--计算图 I. 计算图的概念 II. 计算图的使用 2.TensorFlow数据类型--张量 I. 张量的概念 II. 张量的使 ...
- Tensorflow[LSTM]
0.背景 通过对<tensorflow machine learning cookbook>第9章第3节"implementing_lstm"进行阅读,发现如下形式可以 ...
- 使用TensorFlow的递归神经网络(LSTM)进行序列预测
本篇文章介绍使用TensorFlow的递归神经网络(LSTM)进行序列预测.作者在网上找到的使用LSTM模型的案例都是解决自然语言处理的问题,而没有一个是来预测连续值的. 所以呢,这里是基于历史观察数 ...
- Tensorflow LSTM实现
Tensorflow[LSTM] 0.背景 通过对<tensorflow machine learning cookbook>第9章第3节"implementing_lstm ...
- 在TensorFlow中基于lstm构建分词系统笔记
在TensorFlow中基于lstm构建分词系统笔记(一) https://www.jianshu.com/p/ccb805b9f014 前言 我打算基于lstm构建一个分词系统,通过这个例子来学习下 ...
- 『TensotFlow』RNN/LSTM古诗生成
往期RNN相关工程实践文章 『TensotFlow』基础RNN网络分类问题 『TensotFlow』RNN中文文本_上 『TensotFlow』基础RNN网络回归问题 『TensotFlow』RNN中 ...
随机推荐
- 怎么制作电脑系统安装U盘?
现如今U盘安装电脑系统已经是非常普遍的一种方式,这种方式简单好用,能应对大多数情况,受到很多用户的欢迎. 雨后清风U盘启动是一款可将普通U盘制作为系统引导启动工具的软件,其制作的U盘启动盘融合了雨后清 ...
- Linux行编辑器——ed
实验文件test.txt内容 root:x:::root:/root:/bin/bash bin:x:::bin:/bin:/sbin/nologin daemon:x:::daemon:/sbin: ...
- linux基础--命令使用
rpm命令 rpm -qa 包 查看包是否安装 rpm qa 列出系统安装的所有包 rpm -ql 包 查看软件包安装的位置及配置的目录 rpm -ivh 包 安装rpm包或强制安装包 rpm -Uv ...
- #Python语言程序设计Demo - 七段数码管绘制
Python设计七段数码管绘制 单个数码管效果: 设计总数码管效果: Pyhton 编程: #七段数码管绘制 import turtle as t import time as T def drawG ...
- java应用本地缓存
在java应用中,对于访问频率比较高,又不怎么变化的数据,常用的解决方案是把这些数据加入缓存.相比DB,缓存的读取效率快好不少.java应用缓存一般分两种,一是进程内缓存,就是使用ja ...
- linux实操_进程管理
1.显示系统执行的进程 说明:查看进程使用的的指令时 ps ,一般来说使用的参数时ps -aux ps -a:显示当前终端的所有进程信息 ps -u:以用户的格式显示进程星系 ps -x:显示后台进程 ...
- 场效应管种类-场效应管N、P沟道与增强、耗尽型工作原理等知识详解 如何选用晶体三极管与场效应管的技巧
http://www.kiaic.com/article/detail/1308.html 场效应管种类场效应管 场效应晶体管(Field Effect Transistor缩写(FET))简称场效应 ...
- PageHelper使用中出现的问题
PageHelper在分页查询的时候功能强大,内部使用拦截器实现.这边文章做了详细的介绍. https://www.cnblogs.com/ljdblog/p/6725094.html https:/ ...
- redis在应用中使用连接不释放问题解决
今天测试,发现redis使用的时候,调用的链接一直不释放.后查阅蛮多资料,才发现一个配置导致的.并不是他们说的服务没有启动导致的. 1)配置文件 #redis连接配置================= ...
- 014_linux驱动之_信号符号名、描述和它们的信号值
符号名 信号值 描述 是否符合POSIX SIGHUP 1 在控制终端上检测到挂断或控制线程死 亡 是 SIGINT 2 交互注意信号 是 SIGQUIT 3 交 互中止信号 是 SIGILL 4 检 ...