import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf #——————————————————导入数据——————————————————————
f=open('./dataset/dataset_1.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
# plt.figure()
# plt.plot(data)
# plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度 #生成训练集
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist()) #——————————————————定义神经网络变量——————————————————
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
} #——————————————————定义神经网络变量——————————————————
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states #——————————————————训练模型——————————————————
def train_lstm():
global batch_size
pred,_=lstm(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'./module2/stock.model'))
step+=1 #————————————————预测模型————————————————————
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint('./module2/')
saver.restore(sess, module_file) #取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show() if __name__ == '__main__': # train_lstm()
prediction()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

数据集格式:

时间           最高价
2015/12/11 3455.55
2015/12/10 3503.65
2015/12/9 3495.7
2015/12/8 3518.65
2015/12/7 3543.95
2015/12/4 3568.97
2015/12/3 3591.73
2015/12/2 3538.85
2015/12/1 3483.41
2015/11/30 3470.37
2015/11/27 3621.9
2015/11/26 3668.38
2015/11/25 3648.37
2015/11/24 3616.48
2015/11/23 3654.75
2015/11/20 3640.53
2015/11/19 3618.21
2015/11/18 3617.07
2015/11/17 3678.27

Tensflow预测股票实例的更多相关文章

  1. 基于Spark Streaming预测股票走势的例子(一)

    最近学习Spark Streaming,不知道是不是我搜索的姿势不对,总找不到具体的.完整的例子,一怒之下就决定自己写一个出来.下面以预测股票走势为例,总结了用Spark Streaming开发的具体 ...

  2. 通过机器学习的线性回归算法预测股票走势(用Python实现)

    在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...

  3. Tensorflow实例:利用LSTM预测股票每日最高价(一)

    RNN与LSTM 这一部分主要涉及循环神经网络的理论,讲的可能会比较简略. 什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神 ...

  4. 20岁少年小伙利用Python_SVM预测股票趋势月入十万!

      在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...

  5. 基于Spark Streaming预测股票走势的例子(二)

    上一篇博客中,已经对股票预测的例子做了简单的讲解,下面对其中的几个关键的技术点再作一些总结. 1.updateStateByKey 由于在1.6版本中有一个替代函数,据说效率比较高,所以作者就顺便研究 ...

  6. AI金融:LSTM预测股票

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  7. AI金融:利用LSTM预测股票每日最高价

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  8. 如何预测股票分析--先知(Prophet)

    在上一篇中,我们探讨了自动ARIMA,但是好像表现的还是不够完善,接下来看看先知的力量! 先知(Prophet) 有许多时间序列技术可以用在股票预测数据集上,但是大多数技术在拟合模型之前需要大量的数据 ...

  9. 《BI那点儿事》Microsoft 逻辑回归算法——预测股票的涨跌

    数据准备:一组股票历史成交数据(股票代码:601106 中国一重),起止日期:2011-01-04至今,其中变量有“开盘”.“最高”.“最低”.“收盘”.“总手”.“金额”.“涨跌”等 UPDATE ...

随机推荐

  1. SQL SERVER数据库性能优化之SQL语句篇

    (引用自重明鸟的博客,方便学习和查看) 1. 按需索取字段,跟“SELECT *”说拜拜 字段的提取一定要按照“用多少提多少”的原则,避免使用“SELECT *”这样的操作.做了这样一个实验,表tbl ...

  2. rbac之 权限粒度控制到按钮级别

    rbac之 权限粒度控制到按钮级别:  这里的意思就是 如果当前用户,没有这个权限. 那么这个相对应的这个按钮的权限, 就不应该展示.看都不能给看到. 思路: 为每一个权限,设置一个别名.  这里是这 ...

  3. springboot报错Unable to start EmbeddedWebApplicationContext due to missing EmbeddedServletContainerFactory bean

    springboot项目在启动时需要把servlet容器编译进去,这时候如果你的maven依赖里面没有配置jetty或者tomcat相关依赖就会报错. 解决方法: jetty添加 <depend ...

  4. python string tuple list dict 相互转换的方法

    dict = {'name': 'Zara', 'age': 7, 'class': 'First'}# 字典转为字符串,返回:<type 'str'> {'age': 7, 'name' ...

  5. jquery 进阶 bootstrap

    . 样式操作 . 操作class . 操作CSS属性的 .css("color") .css("color", "green") .css( ...

  6. git 分支强制删除

    添加一个新功能时,你肯定不希望因为一些实验性质的代码,把主分支搞乱了,所以,每添加一个新功能,最好新建一个feature分支,在上面开发,完成后,合并,最后,删除该feature分支. 现在,你终于接 ...

  7. 为什么CPU的主频止步于4GHz?

    你对CPU的认识大概还停留在奔腾4年代吧……奔腾4最终止步于3.8GHz,原计划推出的4GHz奔腾4处理器也被胎死腹中.英特尔意识到处理器研发道路上走入了“唯主频论”的误区,2004年10月,英特尔总 ...

  8. 去掉tableView空白区域的分割线

    //把多余的分割线去掉 UIView * footerView = [[UIView alloc] initWithFrame:CGRectZero]; self.tableView.tableFoo ...

  9. 测试 Open Live Writer

    我要试试. 看看图片如何: 这是从电脑端上传的一个例子,如果编辑器里可以支持复制粘贴图片就好了. Open Live Writer 发布以后,还可在保存在本地,想起来的时候就修改一下. 再美化一下. ...

  10. Spring 系列教程之默认标签的解析

    Spring 系列教程之默认标签的解析 之前提到过 Spring 中的标签包括默认标签和自定义标签两种,而两种标签的用法以及解析方式存在着很大的不同,本章节重点带领读者详细分析默认标签的解析过程. 默 ...