莫烦pytorch学习笔记(七)——Optimizer优化器
 import torch
 import torch.utils.data as Data
 import torch.nn.functional as F
 from torch.autograd import Variable
 import matplotlib.pyplot as plt
 # 超参数
 LR = 0.01
 BATCH_SIZE = 
 EPOCH = 
 # 生成假数据
 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
 x = torch.unsqueeze(torch.linspace(-, , ), dim=)  # x data (tensor), shape(, )
 # 0.2 * torch.rand(x.size())增加噪点
 y = x.pow() + 0.1 * torch.normal(torch.zeros(*x.size()))
 # 输出数据图
 # plt.scatter(x.numpy(), y.numpy())
 # plt.show()
 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=)
 class Net(torch.nn.Module):
     # 初始化
     def __init__(self):
         super(Net, self).__init__()
         self.hidden = torch.nn.Linear(, )
         self.predict = torch.nn.Linear(, )
     # 前向传递
     def forward(self, x):
         x = F.relu(self.hidden(x))
         x = self.predict(x)
         return x
 net_SGD = Net()
 net_Momentum = Net()
 net_RMSProp = Net()
 net_Adam = Net()
 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]
 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)
 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]
 loss_func = torch.nn.MSELoss()
 loss_his = [[], [], [], []]  # 记录损失
 for epoch in range(EPOCH):
     print(epoch)
     for step, (batch_x, batch_y) in enumerate(loader):
         b_x = Variable(batch_x)
         b_y = Variable(batch_y)
         for net, opt,l_his in zip(nets, optimizers, loss_his):
             output = net(b_x)  # get output for every net
             loss = loss_func(output, b_y)  # compute loss for every net
             opt.zero_grad()  # clear gradients for next train
             loss.backward()  # backpropagation, compute gradients
             opt.step()  # apply gradients
             l_his.append(loss.data.numpy())  # loss recoder
 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
 for i, l_his in enumerate(loss_his):
     plt.plot(l_his, label=labels[i])
 plt.legend(loc='best')
 plt.xlabel('Steps')
 plt.ylabel('Loss')
 plt.ylim((, 0.2))
 plt.show()

莫烦pytorch学习笔记(七)——Optimizer优化器的更多相关文章
- 莫烦 - Pytorch学习笔记 [ 一 ]
		
1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...
 - 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
		
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
 - 莫烦PyTorch学习笔记(五)——模型的存取
		
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...
 - [PyTorch 学习笔记] 4.3 优化器
		
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...
 - 莫烦PyTorch学习笔记(五)——分类
		
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...
 - 莫烦PyTorch学习笔记(四)——回归
		
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
 - 莫烦PyTorch学习笔记(六)——批处理
		
1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...
 - 莫烦PyTorch学习笔记(三)——激励函数
		
1. sigmod函数 函数公式和图表如下图 在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...
 - 莫烦pytorch学习笔记(二)——variable
		
.简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...
 
随机推荐
- linux就该这么学--资料整理--持续更新
			
基础命令 服务管理 systemctl redhat7 systemctl start foo.service 启动服务 systemctl restart foo.service 重启服务 syst ...
 - spring MVC <mvc:annotation-driven>
			
研究SpringMvc 3.2的<mvc:annotation-driven>默认干了什么 如果不配置其他参数,大致相当于以下的配置文件(参考自org.springframework.we ...
 - django 项目分析
			
项目要点 一.功能制定 1.用户功能 #.登陆 #.权限组功能 2.数据展示功能 #.列表展示 #.详细信息展示 #.图标展示 3.资源管理功能 #远程管理 #对远程服务器上的进程具有 #开启 #关闭 ...
 - Algo: maxSubArray vs. maxProduct
			
这两个问题类似,都可利用动态规划思想求解. 一.最大连续子序列和 https://leetcode.com/problems/maximum-subarray/description/ https:/ ...
 - 从别人git下载项目下来然后运行
			
1点击clone or download 2.自由选择 3.拉到你想放的位置,我是放到桌面上的 4. cmd 打开,进入你 的下载到桌面的项目 5. # install dependencies n ...
 - 阿里云应用上边缘云解决方案助力互联网All in Cloud
			
九月末的杭州因为一场云栖大会变得格外火热. 9月25日,吸引全球目光的2019杭州云栖大会如期开幕.20000平米的展区集结数百家企业,为数万名开发者带来了一场前沿科技的饕餮盛宴. 如同往年一样,位于 ...
 - thinkphp 链接数据库
			
ThinkPHP内置了抽象数据库访问层,把不同的数据库操作封装起来,我们只需要使用公共的Db类进行操作,而无需针对不同的数据库写不同的代码和底层实现,Db类会自动调用相应的数据库驱动来处理.目前的数据 ...
 - (转)Android开发把项目打包成apk
			
转:http://blog.csdn.net/luoyin22/article/details/7862742 做完一个Android项目之后,如何才能把项目发布到Internet上供别人使用呢?我们 ...
 - 什么是存根类 Stub
			
转:http://www.cnblogs.com/cy163/archive/2009/08/04/1539077.html 存根类是一个类,它实现了一个接口,但是实现后的每个方法都是空的. ...
 - Vuex听说很难?
			
Vuex 是什么? Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式.它采用集中式存储管理应用的所有组件状态,并以相应的规则保证状态以一种可预测的方式发生变化. 什么鬼东西 看完这段 ...