import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt # 定义RNN的参数。
HIDDEN_SIZE = 30 # LSTM中隐藏节点的个数。
NUM_LAYERS = 2 # LSTM的层数。
TIMESTEPS = 10 # 循环神经网络的训练序列长度。
TRAINING_STEPS = 10000 # 训练轮数。
BATCH_SIZE = 32 # batch大小。
TRAINING_EXAMPLES = 10000 # 训练数据个数。
TESTING_EXAMPLES = 1000 # 测试数据个数。
SAMPLE_GAP = 0.01 # 采样间隔。
# 产生正弦数据。
def generate_data(seq):
X = []
y = []
# 序列的第i项和后面的TIMESTEPS-1项合在一起作为输入;第i + TIMESTEPS项作为输
# 出。即用sin函数前面的TIMESTEPS个点的信息,预测第i + TIMESTEPS个点的函数值。
for i in range(len(seq) - TIMESTEPS):
X.append([seq[i: i + TIMESTEPS]])
y.append([seq[i + TIMESTEPS]])
return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32) # 用正弦函数生成训练和测试数据集合。
test_start = (TRAINING_EXAMPLES + TIMESTEPS) * SAMPLE_GAP
test_end = test_start + (TESTING_EXAMPLES + TIMESTEPS) * SAMPLE_GAP
train_X, train_y = generate_data(np.sin(np.linspace(0, test_start, TRAINING_EXAMPLES + TIMESTEPS, dtype=np.float32)))
test_X, test_y = generate_data(np.sin(np.linspace(test_start, test_end, TESTING_EXAMPLES + TIMESTEPS, dtype=np.float32)))
#  定义网络结构和优化步骤。
def lstm_model(X, y, is_training):
# 使用多层的LSTM结构。
cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(NUM_LAYERS)]) # 使用TensorFlow接口将多层的LSTM结构连接成RNN网络并计算其前向传播结果。
outputs, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
output = outputs[:, -1, :] # 对LSTM网络的输出再做加一层全链接层并计算损失。注意这里默认的损失为平均
# 平方差损失函数。
predictions = tf.contrib.layers.fully_connected(output, 1, activation_fn=None) # 只在训练时计算损失函数和优化步骤。测试时直接返回预测结果。
if not is_training:
return predictions, None, None # 计算损失函数。
loss = tf.losses.mean_squared_error(labels=y, predictions=predictions) # 创建模型优化器并得到优化步骤。
train_op = tf.contrib.layers.optimize_loss(loss, tf.train.get_global_step(),optimizer="Adagrad", learning_rate=0.1)
return predictions, loss, train_op
# 定义测试方法。
def run_eval(sess, test_X, test_y):
# 将测试数据以数据集的方式提供给计算图。
ds = tf.data.Dataset.from_tensor_slices((test_X, test_y))
ds = ds.batch(1)
X, y = ds.make_one_shot_iterator().get_next() # 调用模型得到计算结果。这里不需要输入真实的y值。
with tf.variable_scope("model", reuse=True):
prediction, _, _ = lstm_model(X, [0.0], False) # 将预测结果存入一个数组。
predictions = []
labels = []
for i in range(TESTING_EXAMPLES):
p, l = sess.run([prediction, y])
predictions.append(p)
labels.append(l) # 计算rmse作为评价指标。
predictions = np.array(predictions).squeeze()
labels = np.array(labels).squeeze()
rmse = np.sqrt(((predictions - labels) ** 2).mean(axis=0))
print("Root Mean Square Error is: %f" % rmse) #对预测的sin函数曲线进行绘图。
plt.figure()
plt.plot(predictions, label='predictions')
plt.plot(labels, label='real_sin')
plt.legend()
plt.show()
#  执行训练和测试。
# 将训练数据以数据集的方式提供给计算图。
ds = tf.data.Dataset.from_tensor_slices((train_X, train_y))
ds = ds.repeat().shuffle(1000).batch(BATCH_SIZE)
X, y = ds.make_one_shot_iterator().get_next() # 定义模型,得到预测结果、损失函数,和训练操作。
with tf.variable_scope("model"):
_, loss, train_op = lstm_model(X, y, True) with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # 测试在训练之前的模型效果。
print("Evaluate model before training.")
run_eval(sess, test_X, test_y) # 训练模型。
for i in range(TRAINING_STEPS):
_, l = sess.run([train_op, loss])
if i % 1000 == 0:
print("train step: " + str(i) + ", loss: " + str(l)) # 使用训练好的模型对测试数据进行预测。
print("Evaluate model after training.")
run_eval(sess, test_X, test_y)

吴裕雄--天生自然 pythonTensorFlow图形数据处理:循环神经网络预测正弦函数的更多相关文章

  1. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:数据集高层操作

    import tempfile import tensorflow as tf # 1. 列举输入文件. # 输入数据生成的训练和测试数据. train_files = tf.train.match_ ...

  2. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:输入数据处理框架

    import tensorflow as tf # 1. 创建文件列表,通过文件列表创建输入文件队列 files = tf.train.match_filenames_once("F:\\o ...

  3. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:数据集基本使用方法

    import tempfile import tensorflow as tf # 1. 从数组创建数据集. input_data = [1, 2, 3, 5, 8] dataset = tf.dat ...

  4. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:输入文件队列

    import tensorflow as tf # 1. 生成文件存储样例数据. def _int64_feature(value): return tf.train.Feature(int64_li ...

  5. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:多线程队列操作

    import tensorflow as tf #1. 定义队列及其操作. queue = tf.FIFOQueue(100,"float") enqueue_op = queue ...

  6. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:队列操作

    import tensorflow as tf #1. 创建队列,并操作里面的元素. q = tf.FIFOQueue(2, "int32") init = q.enqueue_m ...

  7. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:图像预处理完整样例

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt #随机调整图片的色彩,定义两种顺序. def di ...

  8. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:TensorFlow图像处理函数

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt #读取图片 image_raw_data = tf ...

  9. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:读取MNIST手写图片数据写入的TFRecord文件

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

随机推荐

  1. 完美的代价(swap成回文串、贪心)

    Description 回文串,是一种特殊的字符串,它从左往右读和从右往左读是一样的.小龙龙认为回文串才是完美的. 现在给你一个串,它不一定是回文的,请你计算最少的交换次数使得该串变成一个完美的回文串 ...

  2. UVALive 4043 转化最佳完美匹配

    首先黑点和白点是组成一个二分图这毫无疑问 关键是题目中要求的所有黑白配的线不能交叉...一开始我也没想到这个怎么转化为二分图里面的算法. 后来看书才知道,如果两两交叉,则可以把两根线当四边形的对角线, ...

  3. python脚本下载 Google Driver 文件

    使用python脚本下载 Google Driver 文件 import yaml import sys import requests import os import re import tarf ...

  4. 一个例子搞清楚Java程序执行顺序

    当我们new一个GirlFriend时,我们都做了什么? 一个例子搞懂Java程序运行顺序 public class Girl { Person person = new Person("G ...

  5. redis(四)----发布订阅

    发布订阅(pub/sub)是一种消息通信模式,主要的目的是解耦消息发布者和消息订阅者之间的耦合.pub /sub不仅仅解决发布者和订阅者直接代码级别耦合,也解决两者在物理部署上的耦合.废话不多说,直接 ...

  6. POJ - 1742 Coins(dp---多重背包)

    题意:给定n种硬币的价值和数量,问能组成1~m中多少种面值. 分析: 1.dp[j]表示当前用了前i种硬币的情况下,可以组成面值j. 2.eg: 3 10 1 3 4 2 3 1 (1)使用第1种硬币 ...

  7. POJ 1325 && 1274:Machine Schedule 匈牙利算法模板题

    Machine Schedule Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 12976   Accepted: 5529 ...

  8. trove module使用说明

    原文来自:https://github.com/openstack/openstack-manuals/blob/master/doc/user-guide/source/database-modul ...

  9. iOS部分页面横屏显示

    在iOS系统支持横屏顺序默认读取plist里面设置的方向(优先级最高)等同于Xcode Geneal设置里面勾选application window设置的级别次之 然后是UINavigationcon ...

  10. 当切换用户时出现-bash-4.1$

    问题重现 [root@localhost ~]# su - yh -bash-4.1$ -bash-4.1$ -bash-4.1$ -bash-4.1$ -bash-4.1$ cd /home -ba ...