最近在网上找到了一个使用LSTM 网络解决  世界银行中各国 GDP预测的一个问题,感觉比较实用,毕竟这是找到的唯一一个可以正确运行的程序。

#encoding:UTF-8

import pandas as pd
from pandas_datareader import wb import torch
import torch.nn
import torch.optim #读取数据
countries = ['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL', 'JP', 'SA', 'GB', 'US',]
dat = wb.download(indicator='NY.GDP.PCAP.KD',
country=countries, start=1970, end=2016) df = dat.unstack().T
df.index = df.index.droplevel(0).astype(int)
#print(df) #搭建神经网络
class Net(torch.nn.Module): def __init__(self, input_size, hidden_size):
super(Net, self).__init__()
self.rnn = torch.nn.LSTM(input_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, 1) def forward(
self, x):
x = x[:, :, None]
x, _ = self.rnn(x)
x = self.fc(x)
x = x[:, :, 0]
return x net = Net(input_size=1, hidden_size=5)
#print(net) #训练神经网络
# 数据归一化
df_scaled = df / df.loc[2000] # 确定训练集和测试集
years = df.index
train_seq_len = sum((years >= 1971) & (years <= 2000))
test_seq_len = sum(years > 2000) print ('训练集长度 = {}, 测试集长度 = {}'.format(
train_seq_len, test_seq_len)) # 确定训练使用的特征和标签
inputs = torch.tensor(df_scaled.iloc[:-1].values, dtype=torch.float32)
labels = torch.tensor(df_scaled.iloc[1:].values, dtype=torch.float32) # 训练网络
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters())
for step in range(10001):
if step:
optimizer.zero_grad()
train_loss.backward()
optimizer.step() preds = net(inputs)
train_preds = preds[:train_seq_len]
train_labels = labels[:train_seq_len]
train_loss = criterion(train_preds, train_labels) test_preds = preds[-test_seq_len]
test_labels = labels[-test_seq_len]
test_loss = criterion(test_preds, test_labels) if step % 500 == 0:
print ('第{}次迭代: loss (训练集) = {}, loss (测试集) = {}'.format(
step, train_loss, test_loss)) preds = net(inputs)
df_pred_scaled = pd.DataFrame(preds.detach().numpy(),
index=years[1:], columns=df.columns)
df_pred = df_pred_scaled * df.loc[2000]
df_pred.loc[2001:]

深度学习 循环神经网络 LSTM 示例的更多相关文章

  1. 时间序列深度学习:状态 LSTM 模型预测太阳黑子

    目录 时间序列深度学习:状态 LSTM 模型预测太阳黑子 教程概览 商业应用 长短期记忆(LSTM)模型 太阳黑子数据集 构建 LSTM 模型预测太阳黑子 1 若干相关包 2 数据 3 探索性数据分析 ...

  2. 十 | 门控循环神经网络LSTM与GRU(附python演练)

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 门控循环神经网络简介 长短期记忆网络(LSTM) 门控制循环单元(GRU) ...

  3. 循环神经网络LSTM RNN回归:sin曲线预测

    摘要:本篇文章将分享循环神经网络LSTM RNN如何实现回归预测. 本文分享自华为云社区<[Python人工智能] 十四.循环神经网络LSTM RNN回归案例之sin曲线预测 丨[百变AI秀]& ...

  4. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  5. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  6. 针对深度学习(神经网络)的AI框架调研

    针对深度学习(神经网络)的AI框架调研 在我们的AI安全引擎中未来会使用深度学习(神经网络),后续将引入AI芯片,因此重点看了下业界AI芯片厂商和对应芯片的AI框架,包括Intel(MKL CPU). ...

  7. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  8. 深度学习--RNN,LSTM

    一.RNN 1.定义 递归神经网络(RNN)是两种人工神经网络的总称.一种是时间递归神经网络(recurrent neural network),另一种是结构递归神经网络(recursive neur ...

  9. 【深度学习与神经网络】深度学习的下一个热点——GANs将改变世界

    本文作者 Nikolai Yakovenko 毕业于哥伦比亚大学,目前是 Google 的工程师,致力于构建人工智能系统,专注于语言处理.文本分类.解析与生成. 生成式对抗网络-简称GANs-将成为深 ...

随机推荐

  1. BackgroundWorker+ProgressBar+委托 实现多线程、进度条

    上文在<C# 使用BackgroundWorker实现WinForm异步>介绍了如何通过BackgroundWorker实现winForm异步通信,下面介绍如何通过BackgroundWo ...

  2. (转)SSIS_数据流转换(Union All&合并联接&合并)

    Union All : 与sql语言 Union All 一样,不用排序,上下合并多个表.Union All转换替代合并转换:输入输出无需排序,合并超过两个表 合并联接 : 有左连接.内连接.完全连接 ...

  3. java之简单工厂

    1.使用步骤 创建抽象/接口产品类,定义具体产品的公共接口方法:(产品接口类) 创建具体产品类,是继承抽象产品类的:(产品接口实现类) 创建工厂类,通过创建静态方法根据传入不同参数从而创建不同具体产品 ...

  4. 给iphone配置qq邮箱

    在手机上使用qq邮箱发送和接受邮件,但是又不用qq邮箱,我用的是“网易邮箱大师” ,那么就需要配置服务. 1.在qq邮箱中设置邮箱,开启相关的服务,然后用手机发送短信来生成授权码.最后在手机上设置的密 ...

  5. 生成对抗网络(Generative Adversarial Network)阅读笔记

    笔记持续更新中,请大家耐心等待 首先需要大概了解什么是生成对抗网络,参考维基百科给出的定义(https://zh.wikipedia.org/wiki/生成对抗网络): 生成对抗网络(英语:Gener ...

  6. 什么是ASCII码文本文件

    标准ASCII码方式(也称文本方式)存储的文件,更确切地说,英文.数字等字符存储的是ASCII码.文本文件中除了存储文件有效字符信息(包括能用ASCII码字符表示的回车.换行等信息)外,不能存储其他任 ...

  7. Java Character 类

    Character 类用于对单个字符进行操作. Character 类在对象中包装一个基本类型 char 的值 实例 char ch = 'a'; // Unicode 字符表示形式 char uni ...

  8. showDoc的基本使用方法

    ShowDoc介绍 ShowDoc就是一个非常适合IT团队的在线文档分享工具,它可以加快团队之间沟通的效率. API文档( 查看Demo) 随着移动互联网的发展,BaaS(后端即服务)越来越流行.服务 ...

  9. SQL学习笔记三之MySQL表操作

    阅读目录 一 存储引擎介绍 二 表介绍 三 创建表 四 查看表结构 五 数据类型 六 表完整性约束 七 修改表ALTER TABLE 八 复制表 九 删除表 一 存储引擎介绍 存储引擎即表类型,mys ...

  10. JAVA volatile 解析

    volatile这个关键字可能很多朋友都听说过,或许也都用过.在Java 5之前,它是一个备受争议的关键字,因为在程序中使用它往往会导致出人意料的结果.在Java 5之后,volatile关键字才得以 ...