文章目录

3.3 通过Sin预测Cos

%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
torch.__version__
'1.3.0'

在介绍循环神经网络时候我们说过,循环神经网络由于其的特殊结构,十分十分擅长处理时间相关的数据,下面我们就来通过输入sin函数,输出cos函数来实际使用。
首先,我们还是定义一些超参数

TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS=300 # 总共训练次数
h_state = None # 隐藏层状态

由于是使用sin和cos函数,所以这里不需要dataloader,我们直接使用Numpy生成数据,Pytorch没有π这个常量,所以所有操作都是用Numpy完成

steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)

生成完后,我们可视化一下数据

plt.figure(1)
plt.suptitle('Sin and Cos',fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

下面定义一下我们的网络结构

class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
# 也可使用以下这样的返回值
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state

下面我们定义我们的网络

rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差

由于没有测试集,所以我们训练和测试写在一起了

rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x=x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
# 这三行写在一起就可以
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1)%20==0: #每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step,loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
EPOCHS: 19,Loss:0.139491

EPOCHS: 39,Loss:0.007957

EPOCHS: 59,Loss:0.025667

EPOCHS: 79,Loss:0.004511

EPOCHS: 99,Loss:0.012425

EPOCHS: 119,Loss:0.006166

EPOCHS: 139,Loss:0.017573

EPOCHS: 159,Loss:0.005687

EPOCHS: 179,Loss:0.008566

EPOCHS: 199,Loss:0.000836

EPOCHS: 219,Loss:0.003727

EPOCHS: 239,Loss:0.005441

EPOCHS: 259,Loss:0.005437

EPOCHS: 279,Loss:0.004994

EPOCHS: 299,Loss:0.004386

蓝色是模型预测的结果,红色是函数的结果,通过300次的训练,已经基本拟合了

[Pytorch框架] 3.3 通过Sin预测Cos的更多相关文章

  1. [再寄小读者之数学篇](2014-11-19 $\sin(x+y)=\sin x\cos y+\cos x\sin y$)

    $$\bex \sin(x+y)=\sin x\cos y+\cos x\sin y. \eex$$ Ref. [Proof Without Words: Sine Sum Identity, The ...

  2. [再寄小读者之数学篇](2014-04-08 from 1297503521@qq.com $\sin x-x\cos x=0$ 的根的估计)

    (2014-04-08 from 1297503521@qq.com) 设方程 $\sin x-x\cos x=0$ 在 $(0,+\infty)$ 中的第 $n$ 个解为 $x_n$. 证明: $$ ...

  3. 单变量微积分笔记20——三角替换1(sin和cos)

    sin和cos的常用公式 基本公式: 半角公式: 微分公式: 积分公式: 三角替换 示例1 根据微分公式,cosxdx = dsinx 示例2 示例3 半角公式 示例1 示例2 解法1: 解法2: 综 ...

  4. 数学中的Sin和Cos是什么意思?(转)

    数学中的Sin和Cos是什么意思? 作者:admin 分类:生活随笔 发表于 2012年03月21日 16:48 问:数学中的Sin和Cos是什么意思? 答:sin, cos, tan 都是三角函数, ...

  5. PyTorch框架+Python 3面向对象编程学习笔记

    一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...

  6. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

  7. 带 sin, cos 的线段树 - 牛客

    链接:https://www.nowcoder.com/acm/contest/160/D来源:牛客网 题目描述给出一个长度为n的整数序列a1,a2,...,an,进行m次操作,操作分为两类.操作1: ...

  8. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  9. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  10. 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)

    本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...

随机推荐

  1. navicat for mysql (本人亲测,真实有效)

    参考: https://www.cnblogs.com/myprogramer/p/10534481.html 第一步:下载软件 先从Xclient上下载下来Navicat Premium 12.0. ...

  2. C++ condition_variable

    一.使用场景 在主线程中创建一个子线程去计数,计数累计100次后认为成功,并告诉主线程:主线程收到计数100次完成的信息后继续往下执行 二.条件变量的成员函数 wait:当前线程调用 wait() 后 ...

  3. Serverless 架构演进与实践

    Serverless 架构演进与实践 1. 介绍 Serverless 并不仅仅是一个概念,很多地方都已经有了它的影子和思想,本文将给大家介绍最近比较火的 Serverless. 首先放出官方对 Se ...

  4. vue 数组修改 页面无法刷新

    saveData: { current: 1, records:[] , total:0}, countSaveMoney:{ bidSuccessMoney:0, saveMoney:0},页面上有 ...

  5. 内网jenkins跨版本升级

    概要: 原来使用的jenkins版本为1.6,现在需要升级为最新版2.3.6 由于在内网,不能使用jenkins自带的在线升级工具 升级思路: 由于版本跨度太大,直接copy jenkins目录,启动 ...

  6. 使用vite创建vue3 遇到 process is not defined

    今天新建项目遇到报错,查资料得出,需要在vite.config.js中添加代码如下 import { defineConfig } from 'vite' import vue from '@vite ...

  7. vulnhub靶场之HARRYPOTTER: ARAGOG (1.0.2)

    准备: 攻击机:虚拟机kali.本机win10. 靶机:HarryPotter: Aragog (1.0.2),下载地址:https://download.vulnhub.com/harrypotte ...

  8. AI智能问答助手 AI智能批量文章生成器 网站优化SEO批量内容生成工具 原创文章生成软件

    <AI智能问答助手>   软件基于当下热门的OpenAI的ChatGPT技术,导入问题列表就可以批量生成对应的内容,内容质量高.原创度高.适合对内容生成需求量大的场景,如网站优化.广告文案 ...

  9. 做bad apple第二步: python如何将视频变成一帧帧的图片,如何将一帧帧的图片转为视频

    直接上代码 """视频转图片""" port cv2def getphoto(video_in, video_save): cap = cv ...

  10. 认识流媒体协议,从 RTSP 协议解析开始!

    RTSP 是 Internet 协议规范,是 TCP/IP 协议体系中的一个应用层协议级网络通信系统.专为娱乐(如音频和视频)和通信系统的使用,以控制流媒体服务器.该协议用于在端点之间建立和控制媒体会 ...