这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎:

class TRNNConfig(object):
"""RNN配置参数""" # 模型参数
embedding_dim = 100 # 词向量维度
seq_length = 100 # 序列长度
num_classes = 2 # 类别数
vocab_size = 10000 # 词汇表达小 num_layers= 2 # 隐藏层层数
hidden_dim = 128 # 隐藏层神经元
rnn = 'lstm' # lstm 或 gru dropout_keep_prob = 0.8 # dropout保留比例
learning_rate = 1e-3 # 学习率 batch_size = 128 # 每批训练大小
num_epochs = 5 # 总迭代轮次 print_per_batch = 20 # 每多少轮输出一次结果
save_per_batch = 10 # 每多少轮存入tensorboard class TextRNN(object):
"""文本分类,RNN模型"""
def __init__(self, config):
self.config = config # 三个待输入的数据
self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
self.keep_prob = tf.placeholder(tf.float32, name='keep_prob') self.rnn() def rnn(self):
"""rnn模型""" def lstm_cell(): # lstm核
return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True) def gru_cell(): # gru核
return tf.contrib.rnn.GRUCell(self.config.hidden_dim) def dropout(): # 为每一个rnn核后面加一个dropout层
if (self.config.rnn == 'lstm'):
cell = lstm_cell()
else:
cell = gru_cell()
return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) # 词向量映射
with tf.device('/cpu:0'):
embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x) with tf.name_scope("rnn"):
# 多层rnn网络
cells = [dropout() for _ in range(self.config.num_layers)]
rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True) _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
last = _outputs[:, -1, :] # 取最后一个时序输出作为结果
# last = _outputs # 取最后一个时序输出作为结果 with tf.name_scope("score"):
# 全连接层,后面接dropout以及relu激活
fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
fc = tf.nn.relu(fc) # 分类器
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1) # 预测类别 with tf.name_scope("optimize"):
# 损失函数,交叉熵
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
self.loss = tf.reduce_mean(cross_entropy)
# 优化器
self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss) with tf.name_scope("accuracy"):
# 准确率
correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题的更多相关文章

  1. tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be replaced by tf.contrib.rnn.BasicLSTMCell.

    For Tensorflow 1.2 and Keras 2.0, the line tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be repl ...

  2. tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

    tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别 https://blog.csdn.net/u014365862/article/details/78238 ...

  3. tensorflow教程:tf.contrib.rnn.DropoutWrapper

    tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...

  4. 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)

    问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse)  # 构建 ...

  5. TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用

    一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...

  6. TensorFlow——tf.contrib.layers库中的相关API

    在TensorFlow中封装好了一个高级库,tf.contrib.layers库封装了很多的函数,使用这个高级库来开发将会提高效率,卷积函数使用tf.contrib.layers.conv2d,池化函 ...

  7. tf.contrib.rnn.LSTMCell 里面参数的意义

    num_units:LSTM cell中的单元数量,即隐藏层神经元数量.use_peepholes:布尔类型,设置为True则能够使用peephole连接cell_clip:可选参数,float类型, ...

  8. module 'tensorflow.contrib.rnn' has no attribute 'core_rnn_cell'

    #tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size) tf.contrib.rnn.BasicLSTMCell(lstm_size)

  9. TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu

    import tensorflow as tf # 22 scope (name_scope/variable_scope) from __future__ import print_function ...

随机推荐

  1. Java学习总结一 数据类型

    @Java300 学习总结 一.Java 基本数据类型分类如下: 整型变量:byte.short.int.long 浮点型变量:float.double 字符型变量:char 布尔型变量:boolea ...

  2. 少儿编程Scratch第三讲:宇宙大战.枪战游戏

    小朋友这周的表现还算不错.周末多数时间都由我陪(bi)着(zhe)做课本上的数学题,后来还学了英语.任重道远啊,语数外都还得加强,还远不到自己就能取得好成绩的阶段. 上周说好这周要做一个发射炮弹的游戏 ...

  3. configure,make和make install关系

    linux编译安装中configure.make和make install各自的作用 ./configure是用来检测你的安装平台的目标特征的.configure根据给定的参数和系统环境会生成Make ...

  4. BFS以及hash表判重的应用~

    主要还是讲下hash判重的问题吧 这道题目用的是除法求余散列方式 前几天看了下算法导论 由于我们用的是线性再寻址的方式来解决冲突问题 所以hash表的大小(余数的范围)要包含我们要求的范围 对mod的 ...

  5. 轻松搭建CAS 5.x系列(1)-使用cas overlay搭建SSO SERVER服务端

    概要说明 cas的服务端搭建有两种常用的方式:   1. 基于源码的基础上构建出来的   2. 使用WAR overlay的方式来安装 官方推荐使用第二种,配置管理方便,以后升级也容易.本文就是使用第 ...

  6. mac 在finder上面显示完成路径

    打开终端,输入以下命令并回车: defaults write com.apple.finder _FXShowPosixPathInTitle -bool YES 然后再把finder关了再打开,你会 ...

  7. docker第一篇 容器技术入门

    Container 容器是一种基础工具,泛指任何可以容纳其它物品的工具. Linux Namespaces (docker容器技术主要是通过6个隔离技术来实现) namespace    系统调用参数 ...

  8. mybatis generator代码生成器的使用

    一.有关mybatis generator的使用可以查看如下网址:http://www.mybatis.org/generator/index.html 二.如下是我自己整理的学习步骤: <1& ...

  9. QT版本下载链接

    http://download.qt.io/archive/qt/

  10. 【Struts2】 国际化

    一.概述 二.Struts2中国际化: 2.1 问题1 全局 局部 2.2 问题2 2.3 问题3 2.4 问题4 在Action中怎样使用 在JSP页面上怎样使用 一.概述 同一款软件 可以为不同用 ...