TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu
import tensorflow as tf
# 22 scope (name_scope/variable_scope)
from __future__ import print_function class TrainConfig:
batch_size = 20
time_steps = 20
input_size = 10
output_size = 2
cell_size = 11
learning_rate = 0.01 class TestConfig(TrainConfig):
time_steps = 1 class RNN(object): def __init__(self, config):
self._batch_size = config.batch_size
self._time_steps = config.time_steps
self._input_size = config.input_size
self._output_size = config.output_size
self._cell_size = config.cell_size
self._lr = config.learning_rate
self._built_RNN() def _built_RNN(self):
with tf.variable_scope('inputs'):
self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
with tf.name_scope('RNN'):
with tf.variable_scope('input_layer'):
l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D') # (batch*n_step, in_size)
# Ws (in_size, cell_size)
Wi = self._weight_variable([self._input_size, self._cell_size])
print(Wi.name)
# bs (cell_size, )
bi = self._bias_variable([self._cell_size, ])
# l_in_y = (batch * n_steps, cell_size)
with tf.name_scope('Wx_plus_b'):
l_in_y = tf.matmul(l_in_x, Wi) + bi
l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D') with tf.variable_scope('cell'):
cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
with tf.name_scope('initial_state'):
self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32) self.cell_outputs = []
cell_state = self._cell_initial_state
for t in range(self._time_steps):
if t > 0: tf.get_variable_scope().reuse_variables()
cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
self.cell_outputs.append(cell_output)
self._cell_final_state = cell_state with tf.variable_scope('output_layer'):
# cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
Wo = self._weight_variable((self._cell_size, self._output_size))
bo = self._bias_variable((self._output_size,))
product = tf.matmul(cell_outputs_reshaped, Wo) + bo
# _pred shape (batch*time_step, output_size)
self._pred = tf.nn.relu(product) # for displacement with tf.name_scope('cost'):
_pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
mse = self.ms_error(_pred, self._ys)
mse_ave_across_batch = tf.reduce_mean(mse, 0)
mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
self._cost = mse_sum_across_time
self._cost_ave_time = self._cost / self._time_steps with tf.variable_scope('trian'):
self._lr = tf.convert_to_tensor(self._lr)
self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost) @staticmethod
def ms_error(y_target, y_pre):
return tf.square(tf.subtract(y_target, y_pre)) @staticmethod
def _weight_variable(shape, name='weights'):
initializer = tf.random_normal_initializer(mean=0., stddev=0.5, )
return tf.get_variable(shape=shape, initializer=initializer, name=name) @staticmethod
def _bias_variable(shape, name='biases'):
initializer = tf.constant_initializer(0.1)
return tf.get_variable(name=name, shape=shape, initializer=initializer) if __name__ == '__main__':
train_config = TrainConfig() #定义train_config
test_config = TestConfig() # # the wrong method to reuse parameters in train rnn
# with tf.variable_scope('train_rnn'):
# train_rnn1 = RNN(train_config)
# with tf.variable_scope('test_rnn'):
# test_rnn1 = RNN(test_config) # the right method to reuse parameters in train rnn
#目的使train的RNN调用参数,然后利用variable_scope方法共享RNN,让test的RNN再次调用一样的参数,
with tf.variable_scope('rnn') as scope:
sess = tf.Session()
train_rnn2 = RNN(train_config)
scope.reuse_variables() #告诉TF想重复利用RNN的参数
test_rnn2 = RNN(test_config)
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
init = tf.initialize_all_variables()
else:
init = tf.global_variables_initializer()
sess.run(init)
TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu的更多相关文章
- TF之RNN:TF的RNN中的常用的两种定义scope的方式get_variable和Variable—Jason niu
# tensorflow中的两种定义scope(命名变量)的方式tf.get_variable和tf.Variable.Tensorflow当中有两种途径生成变量 variable import te ...
- 深度学习原理与框架-递归神经网络-RNN_exmaple(代码) 1.rnn.BasicLSTMCell(构造基本网络) 2.tf.nn.dynamic_rnn(执行rnn网络) 3.tf.expand_dim(增加输入数据的维度) 4.tf.tile(在某个维度上按照倍数进行平铺迭代) 5.tf.squeeze(去除维度上为1的维度)
1. rnn.BasicLSTMCell(num_hidden) # 构造单层的lstm网络结构 参数说明:num_hidden表示隐藏层的个数 2.tf.nn.dynamic_rnn(cell, ...
- TF之RNN:matplotlib动态演示之基于顺序的RNN回归案例实现高效学习逐步逼近余弦曲线—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...
- TF之RNN:基于顺序的RNN分类案例对手写数字图片mnist数据集实现高精度预测—Jason niu
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
- TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架
TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html
- TF之RNN:TensorBoard可视化之基于顺序的RNN回归案例实现蓝色正弦虚线预测红色余弦实线—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...
- TF:利用sklearn自带数据集使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线—Jason niu
import tensorflow as tf from sklearn.datasets import load_digits #from sklearn.cross_validation impo ...
- TF:Tensorflow结构简单应用,随机生成100个数,利用Tensorflow训练使其逼近已知线性直线的效率和截距—Jason niu
import os os.environ[' import tensorflow as tf import numpy as np x_data = np.random.rand(100).astyp ...
- 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在
1. tf.nn.moments(x, axes=[0, 1, 2]) # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...
随机推荐
- Go 开源博客平台 Pipe 1.0.0 发布!
这是 Pipe 博客平台的第一个正式版,欢迎大家使用和反馈建议! 简介 Pipe 是一款小而美的开源博客平台,通过黑客派账号登录即可使用. 动机 产品层面: 市面上缺乏支持多独立博客的平台级系统 实现 ...
- Confluence 6 重构索引缓慢
你的索引构建是否需要很长时间?索引构建需要的时间是由下面的一些因素确定的: 你 Confluence 安装实例中的页面数量. 附件的数量,类型和大小. Confluence 安装实例可用的内存大小. ...
- 安装lrzsz 实现windows与linux之间文件互传
环境:CentOS7.4 执行命令安装: [root@linuxhg01 www]# yum install lrzsz rz // Windows 上传到 linux [root@linuxhg01 ...
- 闭包&装饰器
闭包 1.函数引用 def test(): print('--test--') # 调用函数 test() # 引用函数 ret = test print(id(ret)) print(id(test ...
- kali linux宿主机和虚拟机互访实现方案
1.攻防模拟中,将DVWA安装到自己的宿主机中,在kali Linux中通过sqlmap和其他工具启动嗅探攻击,需要配置网络.虚拟机采用桥接方式,并复制Mac地址状况. 2.查看各自系统下的IP地址. ...
- 广工赛-hdu6468构造十叉树
是个以前没见过的模板题.. 我用比较复杂度方式过掉了.. 构造一个十叉树(有点trie的味道)来存数字,然后字典序就是先序遍历的结果 #include<bits/stdc++.h> usi ...
- springboot linux启动方式
手动启动 java -Xms128m -Xmx256m -Xdebug -Xrunjdwp:server=y,transport=dt_socket,address=8081,suspend=n -j ...
- C++ Primer 笔记——union
1.union是一种特殊的类.一个union可以有多个数据成员,但是在任意时刻,只有一个数据成员可以有值.当我们给union的某个成员赋值之后,该union的其他成员就变成未定义的状态了.分配给一个u ...
- 俺也会刷机啦--windows7下刷android
刷机很多人都会,本文只为像我这种入门的朋友而写的. 风险提示: 1. SD卡数据极可能会丢失(我这次就全丢了). 2. 升级失败. (俺的)环境说明: windows7 专业版64位 cmd命令行工具 ...
- WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Exception in thread "main" java.io.IOException: No FileSystem for sc F
1.执行脚本程序报如下所示的错误: [hadoop@slaver1 script_hadoop]$ hadoop jar web_click_mr_hive.jar com.bie.hive.mr.C ...