import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf #——————————————————导入数据——————————————————————
f=open('./dataset/dataset_1.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
# plt.figure()
# plt.plot(data)
# plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度 #生成训练集
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist()) #——————————————————定义神经网络变量——————————————————
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
} #——————————————————定义神经网络变量——————————————————
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states #——————————————————训练模型——————————————————
def train_lstm():
global batch_size
pred,_=lstm(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'./module2/stock.model'))
step+=1 #————————————————预测模型————————————————————
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint('./module2/')
saver.restore(sess, module_file) #取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show() if __name__ == '__main__': # train_lstm()
prediction()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

数据集格式:

时间           最高价
2015/12/11 3455.55
2015/12/10 3503.65
2015/12/9 3495.7
2015/12/8 3518.65
2015/12/7 3543.95
2015/12/4 3568.97
2015/12/3 3591.73
2015/12/2 3538.85
2015/12/1 3483.41
2015/11/30 3470.37
2015/11/27 3621.9
2015/11/26 3668.38
2015/11/25 3648.37
2015/11/24 3616.48
2015/11/23 3654.75
2015/11/20 3640.53
2015/11/19 3618.21
2015/11/18 3617.07
2015/11/17 3678.27

Tensflow预测股票实例的更多相关文章

  1. 基于Spark Streaming预测股票走势的例子(一)

    最近学习Spark Streaming,不知道是不是我搜索的姿势不对,总找不到具体的.完整的例子,一怒之下就决定自己写一个出来.下面以预测股票走势为例,总结了用Spark Streaming开发的具体 ...

  2. 通过机器学习的线性回归算法预测股票走势(用Python实现)

    在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...

  3. Tensorflow实例:利用LSTM预测股票每日最高价(一)

    RNN与LSTM 这一部分主要涉及循环神经网络的理论,讲的可能会比较简略. 什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神 ...

  4. 20岁少年小伙利用Python_SVM预测股票趋势月入十万!

      在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...

  5. 基于Spark Streaming预测股票走势的例子(二)

    上一篇博客中,已经对股票预测的例子做了简单的讲解,下面对其中的几个关键的技术点再作一些总结. 1.updateStateByKey 由于在1.6版本中有一个替代函数,据说效率比较高,所以作者就顺便研究 ...

  6. AI金融:LSTM预测股票

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  7. AI金融:利用LSTM预测股票每日最高价

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  8. 如何预测股票分析--先知(Prophet)

    在上一篇中,我们探讨了自动ARIMA,但是好像表现的还是不够完善,接下来看看先知的力量! 先知(Prophet) 有许多时间序列技术可以用在股票预测数据集上,但是大多数技术在拟合模型之前需要大量的数据 ...

  9. 《BI那点儿事》Microsoft 逻辑回归算法——预测股票的涨跌

    数据准备:一组股票历史成交数据(股票代码:601106 中国一重),起止日期:2011-01-04至今,其中变量有“开盘”.“最高”.“最低”.“收盘”.“总手”.“金额”.“涨跌”等 UPDATE ...

随机推荐

  1. 手动获取被spring管理的bean对象工具

       在netty handler开发中,我们无法将spring的依赖注入到Handler中,无法进行数据库的操作,这时候我们就需要手动获取被spring管理的bean对象:    创建一个  imp ...

  2. linux查看本服务端口开放情况

    在Linux使用过程中,需要了解当前系统开放了哪些端口,并且要查看开放这些端口的具体进程和用户,可以通过netstat命令进行简单查询. 1.netstat命令各个参数说明如下: -t : 指明显示T ...

  3. how2j网站前端项目——天猫前端(第一次)学习笔记8

    其他页面的学习 这些页面有1.查询结果页 2.支付页面 3.支付成功页面 4.确认收货页面上 5.确认收货页面下 6.收获成功页面 7.评价页面上 8.评价页面下 9.登陆页面 10.注册页面 1.查 ...

  4. vue路由权限之访问权限(meta控制是否有访问权限)

    首先登录那权限表 router.beforeEach((to, from, next) => { if(to.path === '/login') { next(); }else{ if(!st ...

  5. 页面练习my blog day51

    html端: <!DOCTYPE html> <html lang="en"> <head> <meta charset="UT ...

  6. Oracle_SQL(1) 基本查询

    1.oracle的安装与卸载 2.PL/SQL Developer的安装 3.登陆PL/SQL Developer 4.SCOTT用户下表的介绍 5.基本查询语句 查询雇员的所有信息: select ...

  7. accept与epoll惊群 转载

    今天打开 OneNote,发现里面躺着一篇很久以前写的笔记,现在将它贴出来. 1. 什么叫惊群现象 首先,我们看看维基百科对惊群的定义: The thundering herd problem occ ...

  8. VS2010插件 VS.PHP 调试开发php程序

    VS 插件VS.PHP 调试PHP的方法;不得不说vs强大啊,此断点调试功能在zend都做不到 如图: 设置成功之后,就可以像调试 .Net程序一样试调Php程序了! 调试的步骤: 1.在需要调试的地 ...

  9. Spring 如何保证后置处理器的执行顺序 - OrderComparator

    Spring 如何保证后置处理器的执行顺序 - OrderComparator Spring 系列目录(https://www.cnblogs.com/binarylei/p/10198698.htm ...

  10. 使用Simple MvvmToolkit开发Android和iOS程序

    详情见:Android and iOS Development with Simple MVVM Toolkit? Yes you can! :http://blog.tonysneed.com/20 ...