[Pytorch框架] 3.3 通过Sin预测Cos
文章目录
3.3 通过Sin预测Cos
%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
torch.__version__
'1.3.0'
在介绍循环神经网络时候我们说过,循环神经网络由于其的特殊结构,十分十分擅长处理时间相关的数据,下面我们就来通过输入sin函数,输出cos函数来实际使用。
首先,我们还是定义一些超参数
TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS=300 # 总共训练次数
h_state = None # 隐藏层状态
由于是使用sin和cos函数,所以这里不需要dataloader,我们直接使用Numpy生成数据,Pytorch没有π这个常量,所以所有操作都是用Numpy完成
steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
生成完后,我们可视化一下数据
plt.figure(1)
plt.suptitle('Sin and Cos',fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

下面定义一下我们的网络结构
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
# 也可使用以下这样的返回值
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state
下面我们定义我们的网络
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差
由于没有测试集,所以我们训练和测试写在一起了
rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x=x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
# 这三行写在一起就可以
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1)%20==0: #每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step,loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
EPOCHS: 19,Loss:0.139491

EPOCHS: 39,Loss:0.007957

EPOCHS: 59,Loss:0.025667

EPOCHS: 79,Loss:0.004511

EPOCHS: 99,Loss:0.012425

EPOCHS: 119,Loss:0.006166

EPOCHS: 139,Loss:0.017573

EPOCHS: 159,Loss:0.005687

EPOCHS: 179,Loss:0.008566

EPOCHS: 199,Loss:0.000836

EPOCHS: 219,Loss:0.003727

EPOCHS: 239,Loss:0.005441

EPOCHS: 259,Loss:0.005437

EPOCHS: 279,Loss:0.004994

EPOCHS: 299,Loss:0.004386

蓝色是模型预测的结果,红色是函数的结果,通过300次的训练,已经基本拟合了
[Pytorch框架] 3.3 通过Sin预测Cos的更多相关文章
- [再寄小读者之数学篇](2014-11-19 $\sin(x+y)=\sin x\cos y+\cos x\sin y$)
$$\bex \sin(x+y)=\sin x\cos y+\cos x\sin y. \eex$$ Ref. [Proof Without Words: Sine Sum Identity, The ...
- [再寄小读者之数学篇](2014-04-08 from 1297503521@qq.com $\sin x-x\cos x=0$ 的根的估计)
(2014-04-08 from 1297503521@qq.com) 设方程 $\sin x-x\cos x=0$ 在 $(0,+\infty)$ 中的第 $n$ 个解为 $x_n$. 证明: $$ ...
- 单变量微积分笔记20——三角替换1(sin和cos)
sin和cos的常用公式 基本公式: 半角公式: 微分公式: 积分公式: 三角替换 示例1 根据微分公式,cosxdx = dsinx 示例2 示例3 半角公式 示例1 示例2 解法1: 解法2: 综 ...
- 数学中的Sin和Cos是什么意思?(转)
数学中的Sin和Cos是什么意思? 作者:admin 分类:生活随笔 发表于 2012年03月21日 16:48 问:数学中的Sin和Cos是什么意思? 答:sin, cos, tan 都是三角函数, ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 带 sin, cos 的线段树 - 牛客
链接:https://www.nowcoder.com/acm/contest/160/D来源:牛客网 题目描述给出一个长度为n的整数序列a1,a2,...,an,进行m次操作,操作分为两类.操作1: ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
随机推荐
- 基于AI边缘智能网关的工业质检应用
成品质量检验是工业生产最后必不可少的环节,随着我国工业化的蓬勃发展,工业产品日益迈向高端化.精密化,对于工业产品的质量检验要求和投入成本也在不断提高,产品质检涉及到比以往更多维度.更多零部件.更高精度 ...
- Jackson工具类及其配置
1 package com.ruoyi.common.core.utils.json; 2 3 import com.fasterxml.jackson.annotation.JsonAutoDete ...
- Linux(CentOS8) 安装 Docker
查询当前系统的相关信息 cat /etc/os-release 输入内容如下 校验当前CentOS内核版本 说明:Docker 要求 CentOS 的内核版本,至少高于 3.10 .低于 3.10 的 ...
- 关于一维数组传入函数的使用 //西电oj214题字符统计
#include<stdio.h> void count(char str[],int num[]){//形参用[],传递数组首地址后可以直接正常用数组str[i] int i; for( ...
- 30天帮你一步步学会Python的开源项目
最近发现一个不错的免费开源学习项目:30天学会Python 如果您最近有学习Python的打算,不妨看看这个是否适合你? 项目地址:https://github.com/Asabeneh/30-Day ...
- Insecure Randomness 不安全的随机数
Insecure Randomness Abstract 标准的伪随机数生成器不能抵挡各种加密攻击. Explanation 在对安全性要求较高的环境中,使用一个能产生可预测数值的函数作为随机数据源, ...
- 【读书笔记】组合计数-Tilings-正文 学一半的笔记
Tilings-正文部分 目录 9.2 转移函数方法 例子 补充 9.3 其余的方法 9.3.1 the path method 9.3.2 The permanent-determinant and ...
- Apinto Dashboad V2.0 发布:可视化控制台让配置更轻松!
大家好, Eolink 旗下开源网关 Apinto 本次带来了 Apinto Dashboad V2.0 的版本发布. Dashboad 需要与 Apinto 主版本一起使用,目前 Dashboad ...
- ACM-NEFUOJ-汉诺塔问题
P200汉诺塔 #include<bits/stdc++.h> using namespace std; int main() { int n,i; long long s[40]; s[ ...
- Linux服务器MySQL操作总结
目录 1. Navicat连接服务器MySQL 2. 如何查看MySQL用户名和密码 3. 修改MySQL的登录密码 4. 安装MySQL开发包(Centos7版) 错误:error 1045 (28 ...