[Pytorch框架] 3.3 通过Sin预测Cos
文章目录
3.3 通过Sin预测Cos
%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
torch.__version__
'1.3.0'
在介绍循环神经网络时候我们说过,循环神经网络由于其的特殊结构,十分十分擅长处理时间相关的数据,下面我们就来通过输入sin函数,输出cos函数来实际使用。
首先,我们还是定义一些超参数
TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS=300 # 总共训练次数
h_state = None # 隐藏层状态
由于是使用sin和cos函数,所以这里不需要dataloader,我们直接使用Numpy生成数据,Pytorch没有π这个常量,所以所有操作都是用Numpy完成
steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
生成完后,我们可视化一下数据
plt.figure(1)
plt.suptitle('Sin and Cos',fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

下面定义一下我们的网络结构
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
# 也可使用以下这样的返回值
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state
下面我们定义我们的网络
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差
由于没有测试集,所以我们训练和测试写在一起了
rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x=x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
# 这三行写在一起就可以
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1)%20==0: #每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step,loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
EPOCHS: 19,Loss:0.139491

EPOCHS: 39,Loss:0.007957

EPOCHS: 59,Loss:0.025667

EPOCHS: 79,Loss:0.004511

EPOCHS: 99,Loss:0.012425

EPOCHS: 119,Loss:0.006166

EPOCHS: 139,Loss:0.017573

EPOCHS: 159,Loss:0.005687

EPOCHS: 179,Loss:0.008566

EPOCHS: 199,Loss:0.000836

EPOCHS: 219,Loss:0.003727

EPOCHS: 239,Loss:0.005441

EPOCHS: 259,Loss:0.005437

EPOCHS: 279,Loss:0.004994

EPOCHS: 299,Loss:0.004386

蓝色是模型预测的结果,红色是函数的结果,通过300次的训练,已经基本拟合了
[Pytorch框架] 3.3 通过Sin预测Cos的更多相关文章
- [再寄小读者之数学篇](2014-11-19 $\sin(x+y)=\sin x\cos y+\cos x\sin y$)
$$\bex \sin(x+y)=\sin x\cos y+\cos x\sin y. \eex$$ Ref. [Proof Without Words: Sine Sum Identity, The ...
- [再寄小读者之数学篇](2014-04-08 from 1297503521@qq.com $\sin x-x\cos x=0$ 的根的估计)
(2014-04-08 from 1297503521@qq.com) 设方程 $\sin x-x\cos x=0$ 在 $(0,+\infty)$ 中的第 $n$ 个解为 $x_n$. 证明: $$ ...
- 单变量微积分笔记20——三角替换1(sin和cos)
sin和cos的常用公式 基本公式: 半角公式: 微分公式: 积分公式: 三角替换 示例1 根据微分公式,cosxdx = dsinx 示例2 示例3 半角公式 示例1 示例2 解法1: 解法2: 综 ...
- 数学中的Sin和Cos是什么意思?(转)
数学中的Sin和Cos是什么意思? 作者:admin 分类:生活随笔 发表于 2012年03月21日 16:48 问:数学中的Sin和Cos是什么意思? 答:sin, cos, tan 都是三角函数, ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 带 sin, cos 的线段树 - 牛客
链接:https://www.nowcoder.com/acm/contest/160/D来源:牛客网 题目描述给出一个长度为n的整数序列a1,a2,...,an,进行m次操作,操作分为两类.操作1: ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
随机推荐
- Cplex解决JSP问题
我的上一篇博客Cplex解决FSP问题 - 加油,陌生人! - 博客园 (cnblogs.com)用Cplex完成了FSP的建模,这篇文章主要是解决JSP问题(车间调度问题). JSP问题:n个工件, ...
- .Net5.0 上传图片、文件到服务器
今天来看看.net上传图片到服务器的方式 public class ControlPresetUploadInput { /// <summary> /// 通道编号 /// </s ...
- 微信小程序中如何识别银行卡和身份证
识别银行卡云函数card2/index.js: const cloud = require('wx-server-sdk') cloud.init({ env: cloud.DYNAMIC_CURRE ...
- [BUUCTF]HCTF 2018WarmUp1 write up
ctrl+U查看源代码, 如下: 访问提示中的source.php文件 发现显示了源码,且存在另一个PHP文件hint.php(提示.php),先查看文件内是否有信息 用file来传参,并且要绕过wh ...
- C#开发微信
C#开发微信门户及应用教程 C#开发微信门户及应用(1)--开始使用微信接口... 6 1.微信账号... 6 2.微信菜单定义... 7 3.接入微信的链接处理... 8 4.使用开发方式创建菜 ...
- 看了还不懂b+tree的本质就来打我
看了还不懂b+tree的本质就来打我 大家好,我是蓝胖子. 今天我们来看看b+tree这种数据结构,我们知道数据库的索引就是由b+tree实现,那么这种结构究竟为什么适合磁盘呢,它又有哪些缺点呢? 我 ...
- PicList 现已上架Mac App Store 分享下整个上架过程和遇到的问题
PicList 是一款云存储/图床平台管理和文件上传工具,基于 PicGo 进行了深度二次开发,保留了 PicGo 的所有功能的同时,为相册添加了同步云端删除功能,同时增加了完整的云存储管理功能,包括 ...
- 第一章C语言概述
1.1程序实例 //first.程序 #include <stdio.h> int main() { int num; num = 1; printf("I am a simpl ...
- Windows 系统下怎么获取 UDP 本机地址
Windows 系统下怎么获取 UDP 本机地址 我们知道 UDP 获取远端地址非常简单,通常接口 recvfrom 就可以直接获取到远端的地址和端口:如果获取 UDP 的本机地址就需要点特殊处理了, ...
- AIGC时代:未来已来
摘要:人工智能的快速发展使得我们进入了AIGC时代.AIGC时代的到来,将会带来巨大的机遇和挑战. 本文分享自华为云社区<GPT-4发布,AIGC时代的多模态还能走多远?系列之一: AIGC时代 ...