import tensorflow as tf
# 22 scope (name_scope/variable_scope)
from __future__ import print_function class TrainConfig:
batch_size = 20
time_steps = 20
input_size = 10
output_size = 2
cell_size = 11
learning_rate = 0.01 class TestConfig(TrainConfig):
time_steps = 1 class RNN(object): def __init__(self, config):
self._batch_size = config.batch_size
self._time_steps = config.time_steps
self._input_size = config.input_size
self._output_size = config.output_size
self._cell_size = config.cell_size
self._lr = config.learning_rate
self._built_RNN() def _built_RNN(self):
with tf.variable_scope('inputs'):
self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
with tf.name_scope('RNN'):
with tf.variable_scope('input_layer'):
l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D') # (batch*n_step, in_size)
# Ws (in_size, cell_size)
Wi = self._weight_variable([self._input_size, self._cell_size])
print(Wi.name)
# bs (cell_size, )
bi = self._bias_variable([self._cell_size, ])
# l_in_y = (batch * n_steps, cell_size)
with tf.name_scope('Wx_plus_b'):
l_in_y = tf.matmul(l_in_x, Wi) + bi
l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D') with tf.variable_scope('cell'):
cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
with tf.name_scope('initial_state'):
self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32) self.cell_outputs = []
cell_state = self._cell_initial_state
for t in range(self._time_steps):
if t > 0: tf.get_variable_scope().reuse_variables()
cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
self.cell_outputs.append(cell_output)
self._cell_final_state = cell_state with tf.variable_scope('output_layer'):
# cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
Wo = self._weight_variable((self._cell_size, self._output_size))
bo = self._bias_variable((self._output_size,))
product = tf.matmul(cell_outputs_reshaped, Wo) + bo
# _pred shape (batch*time_step, output_size)
self._pred = tf.nn.relu(product) # for displacement with tf.name_scope('cost'):
_pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
mse = self.ms_error(_pred, self._ys)
mse_ave_across_batch = tf.reduce_mean(mse, 0)
mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
self._cost = mse_sum_across_time
self._cost_ave_time = self._cost / self._time_steps with tf.variable_scope('trian'):
self._lr = tf.convert_to_tensor(self._lr)
self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost) @staticmethod
def ms_error(y_target, y_pre):
return tf.square(tf.subtract(y_target, y_pre)) @staticmethod
def _weight_variable(shape, name='weights'):
initializer = tf.random_normal_initializer(mean=0., stddev=0.5, )
return tf.get_variable(shape=shape, initializer=initializer, name=name) @staticmethod
def _bias_variable(shape, name='biases'):
initializer = tf.constant_initializer(0.1)
return tf.get_variable(name=name, shape=shape, initializer=initializer) if __name__ == '__main__':
train_config = TrainConfig() #定义train_config
test_config = TestConfig() # # the wrong method to reuse parameters in train rnn
# with tf.variable_scope('train_rnn'):
# train_rnn1 = RNN(train_config)
# with tf.variable_scope('test_rnn'):
# test_rnn1 = RNN(test_config) # the right method to reuse parameters in train rnn
#目的使train的RNN调用参数,然后利用variable_scope方法共享RNN,让test的RNN再次调用一样的参数,
with tf.variable_scope('rnn') as scope:
sess = tf.Session()
train_rnn2 = RNN(train_config)
scope.reuse_variables() #告诉TF想重复利用RNN的参数
test_rnn2 = RNN(test_config)
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
init = tf.initialize_all_variables()
else:
init = tf.global_variables_initializer()
sess.run(init)

  

TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu的更多相关文章

  1. TF之RNN:TF的RNN中的常用的两种定义scope的方式get_variable和Variable—Jason niu

    # tensorflow中的两种定义scope(命名变量)的方式tf.get_variable和tf.Variable.Tensorflow当中有两种途径生成变量 variable import te ...

  2. 深度学习原理与框架-递归神经网络-RNN_exmaple(代码) 1.rnn.BasicLSTMCell(构造基本网络) 2.tf.nn.dynamic_rnn(执行rnn网络) 3.tf.expand_dim(增加输入数据的维度) 4.tf.tile(在某个维度上按照倍数进行平铺迭代) 5.tf.squeeze(去除维度上为1的维度)

    1. rnn.BasicLSTMCell(num_hidden) #  构造单层的lstm网络结构 参数说明:num_hidden表示隐藏层的个数 2.tf.nn.dynamic_rnn(cell, ...

  3. TF之RNN:matplotlib动态演示之基于顺序的RNN回归案例实现高效学习逐步逼近余弦曲线—Jason niu

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...

  4. TF之RNN:基于顺序的RNN分类案例对手写数字图片mnist数据集实现高精度预测—Jason niu

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  5. TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架

    TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html

  6. TF之RNN:TensorBoard可视化之基于顺序的RNN回归案例实现蓝色正弦虚线预测红色余弦实线—Jason niu

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...

  7. TF:利用sklearn自带数据集使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线—Jason niu

    import tensorflow as tf from sklearn.datasets import load_digits #from sklearn.cross_validation impo ...

  8. TF:Tensorflow结构简单应用,随机生成100个数,利用Tensorflow训练使其逼近已知线性直线的效率和截距—Jason niu

    import os os.environ[' import tensorflow as tf import numpy as np x_data = np.random.rand(100).astyp ...

  9. 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在

    1. tf.nn.moments(x, axes=[0, 1, 2])  # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...

随机推荐

  1. 注册InstallShield

    安装InstallShield 下载installshield limitededition版本,这个版本是免费的 注册 安装打开后会给一个网址要求进行注册 其中,国籍是必填项但是下拉菜单中没有内容, ...

  2. C语言学习及应用笔记之二:C语言static关键字及其使用

    C语言有很多关键字,大多关键字使用起来是很明确的,但有一些关键字却要相对复杂一些.我们这里要说明的static关键字就是如此,它的功能很强大,相应的使用也就更复杂. 一般来说static关键字的常见用 ...

  3. checkstyle.xml Code Style for Eclipse

    1. Code Templates [下载 Code Templates] 打开 Eclipse -> Window -> Preferences -> Java -> Cod ...

  4. Zookeeper安装(本地,伪分布式,集群)

    概述 ZooKeeper是一个分布式开源框架,提供了协调分布式应用的基本服务,它向外部应用暴露一组通用服务——分布式同步(Distributed Synchronization).命名服务(Namin ...

  5. LeetCode(87):扰乱字符串

    Hard! 题目描述: 给定一个字符串 s1,我们可以把它递归地分割成两个非空子字符串,从而将其表示为二叉树. 下图是字符串 s1 = "great" 的一种可能的表示形式. gr ...

  6. 性能测试五十:Jmeter+Influxdb+Grafana实时数据展示系统搭建

    如果用生成jtl文件再分析结果的方式的话,每一次请求就会往jtl里面写一条数据,在进行长时间的稳定性测试的时候,特别是当TPS很高的时候,写入的数据会非常的大,这个时候等稳定性测试完成,再对jtl进行 ...

  7. 性能测试四十二:sql案例之联合索引最左前缀

    联合索引:一个索引同时作用于多个字段 联合索引的最左前缀: A.B.C3个字段--联合索引 这个时候,可以使用的查询条件有:A.A+B.A+C.A+B+C,唯独不能使用B+C,即最左侧那个字段必须匹配 ...

  8. mysql下载源码方法

    方法一 进入mysql官网:http://dev.mysql.com/downloads/mysql/ 选择相关的平台下载: 3.选择Source Code 选型后,拉倒网页下方,选择要下载的源码包 ...

  9. Java接口自动化测试之TestNG测试报告ExtentReports的应用(三)

    pom.xml导入包 <?xml version="1.0" encoding="UTF-8"?> <project xmlns=" ...

  10. Selenium+PhantomJS使用时报错原因及解决方案

    问题 今天在使用selenium+PhantomJS动态抓取网页时,出现如下报错信息: UserWarning: Selenium support for PhantomJS has been dep ...