[Pytorch框架] 3.3 通过Sin预测Cos
文章目录
3.3 通过Sin预测Cos
%matplotlib inline
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
torch.__version__
'1.3.0'
在介绍循环神经网络时候我们说过,循环神经网络由于其的特殊结构,十分十分擅长处理时间相关的数据,下面我们就来通过输入sin函数,输出cos函数来实际使用。
首先,我们还是定义一些超参数
TIME_STEP = 10 # rnn 时序步长数
INPUT_SIZE = 1 # rnn 的输入维度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
H_SIZE = 64 # of rnn 隐藏单元个数
EPOCHS=300 # 总共训练次数
h_state = None # 隐藏层状态
由于是使用sin和cos函数,所以这里不需要dataloader,我们直接使用Numpy生成数据,Pytorch没有π这个常量,所以所有操作都是用Numpy完成
steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
生成完后,我们可视化一下数据
plt.figure(1)
plt.suptitle('Sin and Cos',fontsize='18')
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

下面定义一下我们的网络结构
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=H_SIZE,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(H_SIZE, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state)
outs = [] # 保存所有的预测值
for time_step in range(r_out.size(1)): # 计算每一步长的预测值
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
# 也可使用以下这样的返回值
# r_out = r_out.view(-1, 32)
# outs = self.out(r_out)
# return outs, h_state
下面我们定义我们的网络
rnn = RNN().to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters()) # Adam优化,几乎不用调参
criterion = nn.MSELoss() # 因为最终的结果是一个数值,所以损失函数用均方误差
由于没有测试集,所以我们训练和测试写在一起了
rnn.train()
plt.figure(2)
for step in range(EPOCHS):
start, end = step * np.pi, (step+1)*np.pi # 一个时间周期
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis]) # shape (batch, time_step, input_size)
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
x=x.to(DEVICE)
prediction, h_state = rnn(x, h_state) # rnn output
# 这一步非常重要
h_state = h_state.data # 重置隐藏层的状态, 切断和前一次迭代的链接
loss = criterion(prediction.cpu(), y)
# 这三行写在一起就可以
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1)%20==0: #每训练20个批次可视化一下效果,并打印一下loss
print("EPOCHS: {},Loss:{:4f}".format(step,loss))
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.01)
EPOCHS: 19,Loss:0.139491

EPOCHS: 39,Loss:0.007957

EPOCHS: 59,Loss:0.025667

EPOCHS: 79,Loss:0.004511

EPOCHS: 99,Loss:0.012425

EPOCHS: 119,Loss:0.006166

EPOCHS: 139,Loss:0.017573

EPOCHS: 159,Loss:0.005687

EPOCHS: 179,Loss:0.008566

EPOCHS: 199,Loss:0.000836

EPOCHS: 219,Loss:0.003727

EPOCHS: 239,Loss:0.005441

EPOCHS: 259,Loss:0.005437

EPOCHS: 279,Loss:0.004994

EPOCHS: 299,Loss:0.004386

蓝色是模型预测的结果,红色是函数的结果,通过300次的训练,已经基本拟合了
[Pytorch框架] 3.3 通过Sin预测Cos的更多相关文章
- [再寄小读者之数学篇](2014-11-19 $\sin(x+y)=\sin x\cos y+\cos x\sin y$)
$$\bex \sin(x+y)=\sin x\cos y+\cos x\sin y. \eex$$ Ref. [Proof Without Words: Sine Sum Identity, The ...
- [再寄小读者之数学篇](2014-04-08 from 1297503521@qq.com $\sin x-x\cos x=0$ 的根的估计)
(2014-04-08 from 1297503521@qq.com) 设方程 $\sin x-x\cos x=0$ 在 $(0,+\infty)$ 中的第 $n$ 个解为 $x_n$. 证明: $$ ...
- 单变量微积分笔记20——三角替换1(sin和cos)
sin和cos的常用公式 基本公式: 半角公式: 微分公式: 积分公式: 三角替换 示例1 根据微分公式,cosxdx = dsinx 示例2 示例3 半角公式 示例1 示例2 解法1: 解法2: 综 ...
- 数学中的Sin和Cos是什么意思?(转)
数学中的Sin和Cos是什么意思? 作者:admin 分类:生活随笔 发表于 2012年03月21日 16:48 问:数学中的Sin和Cos是什么意思? 答:sin, cos, tan 都是三角函数, ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 带 sin, cos 的线段树 - 牛客
链接:https://www.nowcoder.com/acm/contest/160/D来源:牛客网 题目描述给出一个长度为n的整数序列a1,a2,...,an,进行m次操作,操作分为两类.操作1: ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
随机推荐
- unittest框架基本使用
1.简介 unittest是python内置的单元测试框架,具备编写用例.组织用例.执行用例.输出报告等自动化框架的条件.使用unittest前需要了解该框架的五个概念: 即test case,tes ...
- java 进程排查
[admin@New-OperSys-01 ~]$ jstack $pid | grep -A 50 55e7 "GC task thread#1 (ParallelGC)" os ...
- Arrays.asList()需要注意的点
千万不要这样使用Arrays.asList ! 测试的几种情况及原因: public static void main(String[] args) { //第一种基本类型数组 int[] arr = ...
- System.IO.IOException:“找不到资源“views.buttonstylepage.xaml”。”
初学作为记录(事发场景): WPFDemo的程序集中,定义了一个Views文件夹,该文件夹放一些页面Page.UI层面的东西.用Frame空间做导航的时候,始终报一个错误 // System. ...
- 1--我们写了一个java类,那么生成一个对象占用多大的内存?
public class Student { private long id; private long userId; private byte state; private long create ...
- Java 比较两个对象的不同之处(old, new) 包含 bean 对象下的 list, Map , bean 的细节
Java 比较两个对象的不同之处(old, new) 包含 bean 对象下的 list, Map , bean 的细节 package com.icil.pinpal.test1; impor ...
- getopts解析shell脚本命令行参数
getopts命令格式 getopts optstring name [arg] optstring为所有可匹配选项组成的字符串,每个字母代表一个选项.如果字母后有冒号:,表明该选项需要选择参数.比如 ...
- MS12-020 拒绝服务 蓝屏攻击
漏洞概要 MS12-020是一个3389远程桌面rdp协议的一个漏洞 攻击者通过特意构造的rdp数据包发送给靶机3389端口,造成系统崩溃,蓝屏重启 影响范围:windows xp .2003.200 ...
- windows2003 DHCP服务器配置
一.导入光驱 二.安装可选的windows组件 三.双击打开网路服务,安装DHCP/DNS服务器. 注:服务器地址要固定,因此安装时要规划好网络. 四.ip地址范围规划时要预留i出一些p地址.排除ip ...
- MySQL学习(七)varchar和char区别
varchar:用于存储可变长字符串,是最常见的字符串数据类型.比定长类型更节省空间,因为它仅使用必要的空间.varchar需要使用1或2个额外字节记录字符串的长度:如果列的最大长度小于或等于255字 ...