莫烦pytorch学习笔记(七)——Optimizer优化器
 import torch
 import torch.utils.data as Data
 import torch.nn.functional as F
 from torch.autograd import Variable
 import matplotlib.pyplot as plt
 # 超参数
 LR = 0.01
 BATCH_SIZE = 
 EPOCH = 
 # 生成假数据
 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )
 # 0.2 * torch.rand(x.size())增加噪点
 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))
 # 输出数据图
 # plt.scatter(x.numpy(), y.numpy())
 # plt.show()
 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)
 class Net(torch.nn.Module):
     # 初始化
     def __init__(self):
         super(Net, self).__init__()
         self.hidden = torch.nn.Linear(, )
         self.predict = torch.nn.Linear(, )
     # 前向传递
     def forward(self, x):
         x = F.relu(self.hidden(x))
         x = self.predict(x)
         return x
 net_SGD = Net()
 net_Momentum = Net()
 net_RMSProp = Net()
 net_Adam = Net()
 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]
 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)
 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]
 loss_func = torch.nn.MSELoss()
 loss_his = [[], [], [], []]  # 记录损失
 for epoch in range(EPOCH):
     print(epoch)
     for step, (batch_x, batch_y) in enumerate(loader):
         b_x = Variable(batch_x)
         b_y = Variable(batch_y)
         for net, opt,l_his in zip(nets, optimizers, loss_his):
             output = net(b_x)  # get output for every net
             loss = loss_func(output, b_y)  # compute loss for every net
             opt.zero_grad()  # clear gradients for next train
             loss.backward()  # backpropagation, compute gradients
             opt.step()  # apply gradients
             l_his.append(loss.data.numpy())  # loss recoder
 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
 for i, l_his in enumerate(loss_his):
     plt.plot(l_his, label=labels[i])
 plt.legend(loc='best')
 plt.xlabel('Steps')
 plt.ylabel('Loss')
 plt.ylim((, 0.2))
 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章
- 莫烦 - Pytorch学习笔记 [ 一 ]
		1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ... 
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
		莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ... 
- 莫烦PyTorch学习笔记(五)——模型的存取
		import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ... 
- [PyTorch 学习笔记] 4.3 优化器
		本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ... 
- 莫烦PyTorch学习笔记(五)——分类
		import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ... 
- 莫烦PyTorch学习笔记(四)——回归
		下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ... 
- 莫烦PyTorch学习笔记(六)——批处理
		1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ... 
- 莫烦PyTorch学习笔记(三)——激励函数
		1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ... 
- 莫烦pytorch学习笔记(二)——variable
		.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ... 
随机推荐
- 转: div:给div加滚动条 div的滚动条设置
			div 的滚动条问题: 两种方法: 一. <div style=" overflow:scroll; width:400px; height:400px;”></div&g ... 
- 绿色版mysql 免安装使用(转载)
			MySQL绿色版的安装(mysql-5.6.22-win32.zip) Posted on 2015-01-31 23:21 卒子 阅读(10739) 评论(2) 编辑 收藏 由于工作需要最近要开始研 ... 
- Git的故事
			目录 Git Git的概念 Git的安装 Git的配置 Git的指令 Git Git的概念 首先我们要知道git是什么,最根本的概念是版本控制,顾名思义,就是git可以帮助我们控制自己写的代码或者文档 ... 
- 面试39  MySQL读写分离
			(1)如何实现mysql的读写分离? 其实很简单,就是基于主从复制架构,简单来说,就搞一个主库,挂多个从库,然后我们就单单只是写主库,然后主库会自动把数据给同步到从库上去. (2)MySQL主从复制原 ... 
- heartbeat 高可用
			转载来自 http://www.cnblogs.com/liwei0526vip/p/6391833.html 使用HeartBeat实现高可用HA的配置过程详解 一.写在前面 HA即(high av ... 
- Android开发 View_自定义圆环进度条View
			前言 一个实现,空心圆环的自定义View,已经封装完好,可以直接使用. 效果图 代码 import android.content.Context; import android.graphics.C ... 
- Android开发 View的UI刷新Invalidate和postInvalidate
			Invalidate 正常刷新 /** * 使整个视图无效.如果视图可见, * {@link #onDraw(android.graphics.Canvas)} 调用此方法后将在后续的UI刷新里调用o ... 
- @Value的使用
			<Spring源码解析>笔记 使用@Value赋值:1.基本数值2.可以写SpEL: #{}3.可以写${}:取出配置文件[properties]中的值(在运行环境变量里面的值) 1.创建 ... 
- List、Map、Set 三个接口,存取元素时,各有什么特点
			List与Set都是单列元素的集合,它们有一个功共同的父接口Collection. Set里面不允许有重复的元素, 存元素:add方法有一个boolean的返回值,当集合中没有某个元素,此时add方法 ... 
- Windows exit
			退出 CMD.EXE 程序(命令解释器)或当前批处理脚本. EXIT [/B] [exitCode] /B 指定要退出当前批处理脚本而不是 CMD.EXE.如果从一个 ... 
