import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf #——————————————————导入数据——————————————————————
f=open('./dataset/dataset_1.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
# plt.figure()
# plt.plot(data)
# plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度 #生成训练集
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist()) #——————————————————定义神经网络变量——————————————————
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
} #——————————————————定义神经网络变量——————————————————
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states #——————————————————训练模型——————————————————
def train_lstm():
global batch_size
pred,_=lstm(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'./module2/stock.model'))
step+=1 #————————————————预测模型————————————————————
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint('./module2/')
saver.restore(sess, module_file) #取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show() if __name__ == '__main__': # train_lstm()
prediction()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

数据集格式:

时间           最高价
2015/12/11 3455.55
2015/12/10 3503.65
2015/12/9 3495.7
2015/12/8 3518.65
2015/12/7 3543.95
2015/12/4 3568.97
2015/12/3 3591.73
2015/12/2 3538.85
2015/12/1 3483.41
2015/11/30 3470.37
2015/11/27 3621.9
2015/11/26 3668.38
2015/11/25 3648.37
2015/11/24 3616.48
2015/11/23 3654.75
2015/11/20 3640.53
2015/11/19 3618.21
2015/11/18 3617.07
2015/11/17 3678.27

Tensflow预测股票实例的更多相关文章

  1. 基于Spark Streaming预测股票走势的例子(一)

    最近学习Spark Streaming,不知道是不是我搜索的姿势不对,总找不到具体的.完整的例子,一怒之下就决定自己写一个出来.下面以预测股票走势为例,总结了用Spark Streaming开发的具体 ...

  2. 通过机器学习的线性回归算法预测股票走势(用Python实现)

    在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...

  3. Tensorflow实例:利用LSTM预测股票每日最高价(一)

    RNN与LSTM 这一部分主要涉及循环神经网络的理论,讲的可能会比较简略. 什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神 ...

  4. 20岁少年小伙利用Python_SVM预测股票趋势月入十万!

      在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...

  5. 基于Spark Streaming预测股票走势的例子(二)

    上一篇博客中,已经对股票预测的例子做了简单的讲解,下面对其中的几个关键的技术点再作一些总结. 1.updateStateByKey 由于在1.6版本中有一个替代函数,据说效率比较高,所以作者就顺便研究 ...

  6. AI金融:LSTM预测股票

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  7. AI金融:利用LSTM预测股票每日最高价

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  8. 如何预测股票分析--先知(Prophet)

    在上一篇中,我们探讨了自动ARIMA,但是好像表现的还是不够完善,接下来看看先知的力量! 先知(Prophet) 有许多时间序列技术可以用在股票预测数据集上,但是大多数技术在拟合模型之前需要大量的数据 ...

  9. 《BI那点儿事》Microsoft 逻辑回归算法——预测股票的涨跌

    数据准备:一组股票历史成交数据(股票代码:601106 中国一重),起止日期:2011-01-04至今,其中变量有“开盘”.“最高”.“最低”.“收盘”.“总手”.“金额”.“涨跌”等 UPDATE ...

随机推荐

  1. swift 数组 的一些快速方法

    1. filter (过滤器):返回符合条件的一个数组 let arr = [1,5,6,7,10,0] //写法1 let arr1 = arr.filter { (item) -> Bool ...

  2. Java_4.1 猜数字游戏

    猜数字游戏: 系统随机产生一个1-100之间的数字,用户输入一个数字,如果用户输入的数字和随机数一样,输出“恭喜,猜对了”:如果猜大了,输出“猜大了”:如果猜小了,就输出“猜小了”. 循环直到用户输入 ...

  3. thinkphp 查找表并返回结果

    使用 M('devices')->where($data)-> execute($resql); 返回的只是int值,表示成功与否,并没有返回search 到的 record . $res ...

  4. android轮播图的实现原理

    1.轮播图的点:RadioGroup,根据网络请求的数据,解析得到的图片的个数,设置RadioGroup的RadioButton的个数. 2.轮播图的核心技术:用Gallery来存放图片,设置适配器. ...

  5. 6-完美解决Error:SSL peer shut down incorrectly

    转载自: 完美解决Error:SSL peer shut down incorrectly 打开gradle文件夹下的gradle-wrapper文件 修改其中的配置文件将红色区域修改为http:// ...

  6. 亚像素Sub Pixel

    亚像素Sub Pixel 评估图像处理算法时,通常会考虑是否具有亚像素精度. 亚像素概念的引出: 图像处理过程中,提高检测方法的精度一般有两种方式:一种是提高图像系统的光学放大倍数和CCD相机的分辨率 ...

  7. C# Convert.ToInt32和int.Parse转换null和空字符串时的不同表现

    Convert.ToInt32最终调用的函数见下图: int.Parse调用的函数见下图: 具体的见https://www.cnblogs.com/leolis/p/3968943.html的博客,说 ...

  8. 学习knockoutjs轻量级的MVVM框架

    教程:knockoutjs介绍 http://www.w3cfuns.com/forum.php?mod=viewthread&tid=5598714 MVVM架构~knockoutjs实现简 ...

  9. samtools

    samtools 用法 samtools <command> [options] command 见以下列表, 每个 command 的 options 也不同 dict faidx in ...

  10. 含有选择器的 bootstrap菜单

    var menu = new BootstrapMenu('#jsmind_container jmnode:not(.root)', { actions: [{ name: '展开节点', onCl ...