[Pytorch框架] 3.3 通过Sin预测Cos
文章目录
3.3 通过Sin预测Cos
%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
torch.__version__
'1.3.0'
在介绍循环神经网络时候我们说过,循环神经网络由于其的特殊结构,十分十分擅长处理时间相关的数据,下面我们就来通过输入sin函数,输出cos函数来实际使用。
首先,我们还是定义一些超参数
TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS=300 # 总共训练次数
h_state = None # 隐藏层状态
由于是使用sin和cos函数,所以这里不需要dataloader,我们直接使用Numpy生成数据,Pytorch没有π这个常量,所以所有操作都是用Numpy完成
steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
生成完后,我们可视化一下数据
plt.figure(1)
plt.suptitle('Sin and Cos',fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

下面定义一下我们的网络结构
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
# 也可使用以下这样的返回值
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state
下面我们定义我们的网络
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差
由于没有测试集,所以我们训练和测试写在一起了
rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x=x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
# 这三行写在一起就可以
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1)%20==0: #每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step,loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
EPOCHS: 19,Loss:0.139491

EPOCHS: 39,Loss:0.007957

EPOCHS: 59,Loss:0.025667

EPOCHS: 79,Loss:0.004511

EPOCHS: 99,Loss:0.012425

EPOCHS: 119,Loss:0.006166

EPOCHS: 139,Loss:0.017573

EPOCHS: 159,Loss:0.005687

EPOCHS: 179,Loss:0.008566

EPOCHS: 199,Loss:0.000836

EPOCHS: 219,Loss:0.003727

EPOCHS: 239,Loss:0.005441

EPOCHS: 259,Loss:0.005437

EPOCHS: 279,Loss:0.004994

EPOCHS: 299,Loss:0.004386

蓝色是模型预测的结果,红色是函数的结果,通过300次的训练,已经基本拟合了
[Pytorch框架] 3.3 通过Sin预测Cos的更多相关文章
- [再寄小读者之数学篇](2014-11-19 $\sin(x+y)=\sin x\cos y+\cos x\sin y$)
$$\bex \sin(x+y)=\sin x\cos y+\cos x\sin y. \eex$$ Ref. [Proof Without Words: Sine Sum Identity, The ...
- [再寄小读者之数学篇](2014-04-08 from 1297503521@qq.com $\sin x-x\cos x=0$ 的根的估计)
(2014-04-08 from 1297503521@qq.com) 设方程 $\sin x-x\cos x=0$ 在 $(0,+\infty)$ 中的第 $n$ 个解为 $x_n$. 证明: $$ ...
- 单变量微积分笔记20——三角替换1(sin和cos)
sin和cos的常用公式 基本公式: 半角公式: 微分公式: 积分公式: 三角替换 示例1 根据微分公式,cosxdx = dsinx 示例2 示例3 半角公式 示例1 示例2 解法1: 解法2: 综 ...
- 数学中的Sin和Cos是什么意思?(转)
数学中的Sin和Cos是什么意思? 作者:admin 分类:生活随笔 发表于 2012年03月21日 16:48 问:数学中的Sin和Cos是什么意思? 答:sin, cos, tan 都是三角函数, ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 带 sin, cos 的线段树 - 牛客
链接:https://www.nowcoder.com/acm/contest/160/D来源:牛客网 题目描述给出一个长度为n的整数序列a1,a2,...,an,进行m次操作,操作分为两类.操作1: ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
随机推荐
- LINUX下的VSCODE-C/C++配置
LINUX下的VSCODE-C/C++配置 1.生成默认的任务文件 2.lunch.json,调整"configurations"里的成员,如下 ①添加 "preLaun ...
- python调用java&反编译地址
反编译工具地址: https://github.com/java-decompiler/jd-gui/releases 你想知道的JPype全在这里∞ 先总结自己趟的坑 1. python进程是6 ...
- C/C++ 数据结构循环队列的实现
#include <iostream> #include <Windows.h> using namespace std; #define MAXSIZE 6 typedef ...
- 部门mysql操作
use test_db; -- 删除表 drop table if exists t1_profit; drop table if exists t1_salgrade; drop table i ...
- 博弈论练习8 Northcott Game(取石子问题)
题目链接在这里:I-Northcott Game_牛客竞赛博弈专题班组合游戏基本概念.对抗搜索.Bash游戏.Nim游戏习题 (nowcoder.com) 这题是一个伪装的很好的取石子问题,可以发现, ...
- IDEA如何使用Maven不通过模板创建javaWeb项目
IDEA如何使用Maven不通过模板创建javaWeb项目 1.创建项目 进入IDEA,点击"项目">"新建项目",填写项目信息,最后点击"创建 ...
- Java-01enum常量特定方法
OnJava8-Enum-常量特定方法 用枚举实现责任链模式 责任链(Chain Of Responsibility)设计模式先创建了一批用于解决目标问题的不同方法,然后将它们连成一条"链& ...
- Mathematica制作和使用程序包
步骤 这里拿你制作并且使用一个程序包lost为例子 新建一个空白.wl文档,输入代码如下 BeginPackage[ "MyPkg`"] MainFunction::usage = ...
- 全网最详细中英文ChatGPT-GPT-4示例文档-从0到1快速入门计算时间复杂度应用——官网推荐的48种最佳应用场景(附python/node.js/curl命令源代码,小白也能学)
目录 Introduce 简介 setting 设置 Prompt 提示 Sample response 回复样本 API request 接口请求 python接口请求示例 node.js接口请求示 ...
- Django笔记十一之外键查询优化select_related和prefetch_related
本篇笔记目录如下: select_related prefetch_related 在介绍 select_related 和 prefetch_related 这两个函数前,我们先来看一个例子. 对于 ...