Tensorflow实例:利用LSTM预测股票每日最高价(一)
RNN与LSTM
这一部分主要涉及循环神经网络的理论,讲的可能会比较简略。
什么是RNN
RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的。在传统的神经网络模型中,从输入层到隐含层再到输出层,层与层之间是全连接的,每层之间的节点是无连接的。但是这种普通的神经网络对于很多关于时间序列的问题却无能无力。例如,你要预测句子的下一个单词是什么,一般需要用到前面的单词,因为一个句子中前后单词并不是独立的。RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面时刻的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。
说了这么多,用一张图表示,就是这个样子。
传统的神经网络中,数据从输入层输入,在隐藏层加工,从输出层输出。RNN不同的就是在隐藏层的加工方法不一样,后一个节点不仅受输入层输入的影响,还包受上一个节点的影响。
展开来就是这个样子:
图中的xt−1 ,xt , xt+1就是不同时刻的输入,每个x都具有input layer的n维特征,依次进入循环神经网络以后,隐藏层输出st受到上一时刻st−1的隐藏层输出以及此刻输入层输入xt 的两方影响。
如果要更详细地了解tensorflow对RNN的解释,清戳官方tensorflow.RNN
另外推荐的学习资料:WildML
什么是LSTM
LSTM全称长短期记忆人工神经网络(Long-Short Term Memory),是对RNN的变种。举个例子,假设我们试着去预测“I grew up in France… 中间隔了好多好多字……I speak fluent __”下划线的词。我们拍脑瓜子想这个词应该是French。对于循环神经网络来说,当前的信息建议下一个词可能是一种语言的名字,但是如果需要弄清楚是什么语言,我们是需要离当前下划线位置很远的“France” 这个词信息。相关信息和当前预测位置之间的间隔变得相当的大,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。
这个时候就需要LSTM登场了。在LSTM中,我们可以控制丢弃什么信息,存放什么信息。
具体的理论这里就不多说了,推荐一篇博文Understanding LSTM Networks,里面有对LSTM详细的介绍,有网友作出的翻译请戳[译] 理解 LSTM 网络
股票预测
在对理论有理解的基础上,我们使用LSTM对股票每日最高价进行预测。在本例中,仅使用一维特征。
数据格式如下:
本例取每日最高价作为输入特征[x],后一天的最高价最为标签[y]
获取数据,请戳stock_dataset.csv,密码:md9l
导入数据:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow
f=open('stock_dataset.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
plt.figure()
plt.plot(data)
plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度
#———————————————————形成训练集—————————————————————
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist())
出来的train_x就是像这个样子:
[[[-1.59618],……中间还有18个……, [-1.56340]]
……
[[-1.59202] [-1.58244]]]
是一个shape为[-1,time_step,input__size]的矩阵
定义神经网络变量
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签 #输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
}
定义lstm网络
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states
训练模型
def train_lstm():
global batch_size
pred,_=rnn(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'stock.model'))
step+=1
预测模型
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint(base_path+'module2/')
saver.restore(sess, module_file)
#取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show()
代码
这一讲只有把最高价作为特征,去预测之后的最高价趋势,下一讲会增加输入的特征维度,把最低价、开盘价、收盘价、交易额等作为输入的特征对之后的最高价进行预测。
注:本文在介绍RNN和LSTM的部分,出处若涉及版权问题或原文链接错误,请指正,必会马上修改。
Tensorflow实例:利用LSTM预测股票每日最高价(一)的更多相关文章
- AI金融:利用LSTM预测股票每日最高价
第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...
- AI金融:LSTM预测股票
第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...
- 20岁少年小伙利用Python_SVM预测股票趋势月入十万!
在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...
- 矩池云 | 利用LSTM框架实时预测比特币价格
温馨提示:本案例只作为学习研究用途,不构成投资建议. 比特币的价格数据是基于时间序列的,因此比特币的价格预测大多采用LSTM模型来实现. 长期短期记忆(LSTM)是一种特别适用于时间序列数据(或具有时 ...
- 作为深度学习最强框架的TensorFlow如何进行时序预测!(转)
作为深度学习最强框架的TensorFlow如何进行时序预测! BigQuant 2 个月前 摘要: 2017年深度学习框架关注度排名tensorflow以绝对的优势占领榜首,本文通过一个小例子介绍了T ...
- 利用JFreeChart绘制股票K线图完整解决方案
http://blog.sina.com.cn/s/blog_4ad042e50100q7d9.html 利用JFreeChart绘制股票K线图完整解决方案 (2011-04-30 13:27:17) ...
- 深度神经网络在量化交易里的应用 之二 -- 用深度网络(LSTM)预测5日收盘价格
距离上一篇文章,正好两个星期. 这边文章9月15日 16:30 开始写. 可能几个小时后就写完了.用一句粗俗的话说, "当你怀孕的时候,别人都知道你怀孕了, 但不知道你被日了多少回 ...
- tensorflow实现基于LSTM的文本分类方法
tensorflow实现基于LSTM的文本分类方法 作者:u010223750 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实 ...
- (转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)
干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==& ...
随机推荐
- Angular基础(四) 创建Angular应用
应用(Application)是由组件构成的树.树的根部是最顶层的组件即应用本身,启动的时候,浏览器会最先渲染顶层组件,然后根据树形结构,迭代渲染子组件.组件是可装配的,可以互相组合以构成更大的组件. ...
- 喜闻乐见-Activity生命周期
Activity的生命周期,对于Android开发者来说,再熟悉不过了.但是我们接触到的资料,绝大部分都只是谈了一些表面上的东西,例如各个回调的顺序等等.本文试图换个角度来讲解,也希望对各位读者有所帮 ...
- 为什么 APM 能提升 IT 团队工作质量?
“有必要吗?”这是很多 IT 专业人员在尝试向团队内部推荐应用程序性能管理价值时所面临的问题.APM(应用程序性能管理)能为公司节约成本,提高内部工作效率,并真实了解用户对公司的系统和产品是否满意.除 ...
- sysfs_create_group创建sysfs接口
在调试驱动,可能需要对驱动里的某些变量进行读写,或函数调用.可通过sysfs接口创建驱动对应的属性,使得可以在用户空间通过sysfs接口的show和store函数与硬件交互: Syss接口可通过sys ...
- phoneGap使用 (MAC)
一.安装 ①先安装NodeJS(如果有的就不用安装了) http://nodejs.org/ ②.sudo npm install -g phonegap 需要等待安装完成 ③.检测是否安装成功 no ...
- Codeforces gym 101343 A. On The Way to Lucky Plaza【概率+逆元+精度问题】
2017 JUST Programming Contest 2.0 题目链接:http://codeforces.com/gym/101343/problem/A A. On The Way to ...
- 【TJOJI\HEOI2016】求和
[TJOI/HEOI2016]求和 这题好难啊!! 斯特林数+NTT. 首先我们将第二类斯特林数用容斥展开,具体原理不解释了. \(\displaystyle S(i,j)=\frac{1}{j!}\ ...
- 解决MySQL Workbench导出乱码问题
1.导出数据 2.默认CSV格式 3.乱码 4.解决 文件->另存为,会发现编码为UTF-8,正是MySQL表的编码方式.我们选择编码方式为ANSI,保存类型为所有,覆盖源文件
- javascript中DOM0,DOM2,DOM3级事件模型解析
DOM 即 文档对象模型. 文档对象模型是一种与编程语言及平台无关的API(Application programming Interface),借助于它,程序能够动态地访问和修改文档内容.结构或显示 ...
- CF1045G:AI robots(CDQ分治)
Description 火星上有$n$个机器人排成一行,第$i$个机器人的位置为$x_i$,视野为$r_i$,智商为$q_i$.我们认为第$i$个机器人可以看到的位置是$[x_i−r_i,x_i+ ...