Tensflow预测股票实例
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf #——————————————————导入数据——————————————————————
f=open('./dataset/dataset_1.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
# plt.figure()
# plt.plot(data)
# plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度 #生成训练集
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist()) #——————————————————定义神经网络变量——————————————————
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
} #——————————————————定义神经网络变量——————————————————
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states #——————————————————训练模型——————————————————
def train_lstm():
global batch_size
pred,_=lstm(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'./module2/stock.model'))
step+=1 #————————————————预测模型————————————————————
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint('./module2/')
saver.restore(sess, module_file) #取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show() if __name__ == '__main__': # train_lstm()
prediction()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))
数据集格式:
时间 最高价
2015/12/11 3455.55
2015/12/10 3503.65
2015/12/9 3495.7
2015/12/8 3518.65
2015/12/7 3543.95
2015/12/4 3568.97
2015/12/3 3591.73
2015/12/2 3538.85
2015/12/1 3483.41
2015/11/30 3470.37
2015/11/27 3621.9
2015/11/26 3668.38
2015/11/25 3648.37
2015/11/24 3616.48
2015/11/23 3654.75
2015/11/20 3640.53
2015/11/19 3618.21
2015/11/18 3617.07
2015/11/17 3678.27
Tensflow预测股票实例的更多相关文章
- 基于Spark Streaming预测股票走势的例子(一)
最近学习Spark Streaming,不知道是不是我搜索的姿势不对,总找不到具体的.完整的例子,一怒之下就决定自己写一个出来.下面以预测股票走势为例,总结了用Spark Streaming开发的具体 ...
- 通过机器学习的线性回归算法预测股票走势(用Python实现)
在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...
- Tensorflow实例:利用LSTM预测股票每日最高价(一)
RNN与LSTM 这一部分主要涉及循环神经网络的理论,讲的可能会比较简略. 什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神 ...
- 20岁少年小伙利用Python_SVM预测股票趋势月入十万!
在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...
- 基于Spark Streaming预测股票走势的例子(二)
上一篇博客中,已经对股票预测的例子做了简单的讲解,下面对其中的几个关键的技术点再作一些总结. 1.updateStateByKey 由于在1.6版本中有一个替代函数,据说效率比较高,所以作者就顺便研究 ...
- AI金融:LSTM预测股票
第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...
- AI金融:利用LSTM预测股票每日最高价
第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...
- 如何预测股票分析--先知(Prophet)
在上一篇中,我们探讨了自动ARIMA,但是好像表现的还是不够完善,接下来看看先知的力量! 先知(Prophet) 有许多时间序列技术可以用在股票预测数据集上,但是大多数技术在拟合模型之前需要大量的数据 ...
- 《BI那点儿事》Microsoft 逻辑回归算法——预测股票的涨跌
数据准备:一组股票历史成交数据(股票代码:601106 中国一重),起止日期:2011-01-04至今,其中变量有“开盘”.“最高”.“最低”.“收盘”.“总手”.“金额”.“涨跌”等 UPDATE ...
随机推荐
- swift - 接入听云监测 - 问题
1. 正常下载 探针SDK:https://report.tingyun.com/mobile-web/#/onlyHeader/sdkDownload 2.按步骤接入 ,添加库啊,什么的URLSc ...
- java图片操作--生成与原图对称的图片
java图片操作--生成与原图对称的图片 package com.pay.common.util; import java.awt.image.BufferedImage; import java.i ...
- Java04-Java语法基础(三)流程控制
Java04-Java语法基础(三)流程控制 一.数据类型的转换 1.自动转换:在赋值运算中,占字节数大的类型会自动向字节小的类型转换 double d1 = 3.14; int t1 = d1; 2 ...
- Python 安装pyautogui
在Python中使用PyAutoGui模拟键盘和鼠标操作 一.系统环境 操作系统:win10 64位 Python版本:Python 3.7.0 二.安装参考 1.使用pip进行安装,pip inst ...
- 解决jenkins的内存溢出问题
在jenkins的控制台会看到如下信息: FATAL: Remote call on ime_checkcode failed java.io.IOException: Remote call on ...
- mysql的事务,隔离级别和锁
事务就是一组一起成功或一起失败的sql语句.事务还应该具备,原子性,一致性,隔离性和持久性. 一.事务的基本要素 (ACID) 1.原子性:事务开始后,所有的操作,要么全部成功,要么全部失败,不可能处 ...
- js封装插件
js方式: (function(){ var demo = function(options){ this.options = $.extend({ "x" : "1&q ...
- MFC 一个无参线程的CreateThread 使用
最近想把c#的一个工作中用到的软件用MFC 实现出来, 刚下手 要了解的东西挺多,不但要对c++的语法,大体看一遍. 还要看MFC 内一些窗体,之类的相关的定义,比如cpp ,.h 内的类的定义方式等 ...
- stl之容器、迭代器、算法几者之间的关系
转自:https://blog.csdn.net/bobodem/article/details/49386131 stl包括容器.迭代器和算法: 容器 用于管理一些相关的数据类型.每种容器都有它的优 ...
- JS 实现 jQuery的$(function(){});
1.浏览器渲染引擎的HTML解析流程 何谓“渲染”,其实就是浏览器把请求到的HTML内容显示出来的过程.渲染引擎首先通过网络获得所请求文档的内容,通常以8K分块的方式完成.下面是渲染引擎在取得内容之后 ...