如需了解完整代码请跳转到:

https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

缘由

自己之前写过一个Pytorch学习率更新,其中感觉依据是否loss升高或降低的次数来动态更新学习率,感觉是个挺好玩的东西,自己弄了好久都设置错误,今天算是搞出来了!

解析

说明

  • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:

参数 含义
mode 'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大;
factor 触发条件后lr*=factor;
patience 不再减小(或增大)的累计次数;
verbose 触发条件后print;
threshold 只关注超过阈值的显著变化;
threshold_mode 有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著;
cooldown 触发一次条件后,等待一定epoch再进行检测,避免lr下降过速;
min_lr 最小的允许lr;
eps 如果新旧lr之间的差异小与1e-8,则忽略此次更新。
  • 例子,如图所示的y轴为lr,x为调整的次序,初始的学习率为0.0009575

    则学习率的方程为:lr = 0.0009575 * (0.35)^x

import math
import matplotlib.pyplot as plt
#%matplotlib inline x = 0
o = []
p = []
o.append(0)
p.append(0.0009575)
while(x < 8):
x += 1
y = 0.0009575 * math.pow(0.35,x)
o.append(x)
p.append(y)
print('%d: %.50f' %(x,y)) plt.plot(o,p,c='red',label='test') #分别为x,y轴对应数据,c:color,label
plt.legend(loc='best') # 显示label,loc为显示位置(best为系统认为最好的位置)
plt.show()

难点

我感觉这里面最难的时这几个参数的选择,第一个是初始的学习率(我目前接触的miniest和下面的图像分类貌似都是0.001,我这里训练调整时才发现自己设置的为0.0009575,这个值是上一个实验忘更改了,但发现结果不错,第一次运行该代码接近到0.001这么小的损失值),这里面的乘积系数以及判断说多少次没有减少(增加)后决定变换学习率都是难以估计的。我自己的最好方法是先按默认不变的0.001来训练一下(结合tensoarboard )观察从哪里开始出现问题就可以从这里来确定次数,而乘积系数,个人感觉还是用上面的代码来获取一个较为平滑且变化极小的数字来作为选择。建议在做这种测试时可以把模型先备份一下以免浪费过多的时间!

例子

  • 该例子初始学习率为0.0009575,乘积项系数为:0.35,在我的例子中x变化的条件是:累计125次没有减小则x加1;自己训练在第一次lr变化后(从0.0009575变化到0.00011729)损失值慢慢取向于0.001(如第一张图所示),准确率达到69%;

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import * PATH = './cifar_net_tensorboard_net_width_200_and_chang_lr_by_decrease_0_35^x.pth' # 保存模型地址 transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Assuming that we are on a CUDA machine, this should print a CUDA device: print(device) print("获取一些随机训练数据")
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next() # functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show() # show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
print("**********************") # 设置一个tensorborad
# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.cpu().numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 设置tensorBoard
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/train') # get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next() # create grid of images
img_grid = torchvision.utils.make_grid(images) # show images
# matplotlib_imshow(img_grid, one_channel=True)
imshow(img_grid) # write to tensorboard
# writer.add_image('imag_classify', img_grid) # Tracking model training with TensorBoard
# helper functions def images_to_probs(net, images):
'''
Generates predictions and corresponding probabilities from a trained
network and a list of images
'''
output = net(images)
# convert output probabilities to predicted class
_, preds_tensor = torch.max(output, 1)
# preds = np.squeeze(preds_tensor.numpy())
preds = np.squeeze(preds_tensor.cpu().numpy())
return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)] def plot_classes_preds(net, images, labels):
preds, probs = images_to_probs(net, images)
# plot the images in the batch, along with predicted and true labels
fig = plt.figure(figsize=(12, 48))
for idx in np.arange(4):
ax = fig.add_subplot(1, 4, idx+1, xticks=[], yticks=[])
matplotlib_imshow(images[idx], one_channel=True)
ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
classes[preds[idx]],
probs[idx] * 100.0,
classes[labels[idx]]),
color=("green" if preds[idx]==labels[idx].item() else "red"))
return fig # class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 200, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(200, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
# # 把net结构可视化出来
writer.add_graph(net, images)
net.to(device) ·······
·······
·······

如需了解完整代码请跳转到:

https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

Pytorch使用ReduceLROnPlateau来更新学习率的更多相关文章

  1. 【转载】 Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84311430 ------------------------------- ...

  2. Pytorch系列:(八)学习率调整方法

    学习率的调整会对网络模型的训练造成巨大的影响,本文总结了pytorch自带的学习率调整函数,以及其使用方法. 设置网络固定学习率 设置固定学习率的方法有两种,第一种是直接设置一些学习率,网络从头到尾都 ...

  3. pytorch更新

    Pytorch如何更新版本与卸载,使用pip,conda更新卸载Pytorch 2018年05月22日 07:33:52 醉雨轩Y 阅读数 19047   今天我们主要汇总如何使用使用ubuntu,C ...

  4. 『PyTorch』屌丝的PyTorch玩法

    1. prefetch_generator 使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取 ...

  5. 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧

      学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数.学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住,而设置的非常大可能会导致过快学习到 ...

  6. [PyTorch 学习笔记] 2.1 DataLoader 与 DataSet

    thumbnail: https://image.zhangxiann.com/jeison-higuita-W19AQY42rUk-unsplash.jpg toc: true date: 2020 ...

  7. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  8. caffe中的学习率的衰减机制

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/Julialove102123/article/details/79200158 根据  caffe/ ...

  9. 总结笔记 | 深度学习之Pytorch入门教程

    笔记作者:王博Kings 目录 一.整体学习的建议 1.1 如何成为Pytorch大神? 1.2 如何读Github代码? 1.3 代码能力太弱怎么办? 二.Pytorch与TensorFlow概述 ...

  10. [Pytorch框架] PyTorch 中文手册

    PyTorch 中文手册 书籍介绍 这是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门. 由于本人水平有限,在写此教程的时候参考了一些网上的资料,在这里对他 ...

随机推荐

  1. OAuth + Security - 2 - 资源服务器配置

    PS:此文章为系列文章,建议从第一篇开始阅读. 资源服务器配置 @EnableResourceServer 注解到一个@Configuration配置类上,并且必须使用ResourceServerCo ...

  2. rabbitMq实现系统内的短信发送设计&动态获取BEAN

    rabbitMq实现系统内的短信发送设计&动态获取BEAN 1.短信非系统的重要节点操作,可以在任务完成之后,比如下单成功,发送下单成功的mq消息,短信服务接收到mq消息,动态的判断该短信的c ...

  3. Pycharm创建的虚拟环境,使用命令行指定库的版本进行安装

    Pycharm创建的项目,使用了虚拟环境,对库的版本进行管理:有些项目的对第三方库的版本 要求不同,可使用虚拟环境进行管理 直接想通过pip命令安装,直接看第3点 操作步骤: 1.找到当前项目的虚拟环 ...

  4. VUE中watch的详细使用教程

      1.watch是什么? watch:是vue中常用的侦听器(监听器),用来监听数据的变化 2.watch的使用方式如下 watch: { 这里写你在data中定义的变量名或别处方法名: { han ...

  5. IgH EtherCAT主站开发案例分享——基于NXP i.MX 8M Mini

    前  言 本文档主要演示NXP i.MX 8M Mini工业开发板基于IgH EtherCAT控制伺服电机.   演示板卡是创龙科技的TLIMX8-EVM工业开发板,它是基于NXP i.MX 8M M ...

  6. 2.SpringBoot快速上手

    2.SpringBoot快速上手 SpringBoot介绍 javaEE的开发经常会涉及到3个框架Spring ,SpringMVC,MyBatis.但是这三个框架配置极其繁琐,有大量的xml文件,s ...

  7. 使用sqlcel导入数据时出现“a column named '***' already belongs to this datatable”问题的解决办法

    我修改编码为GBK之后,选择导入部分字段,如下: 这样就不会出现之前的问题了,完美 ----------------------------------------------- 但是出现一个问题,我 ...

  8. mobaXterm 查看密码

    参考:MobaXterm中密码的查看方式 检查是否把密码保存到了注册表中 然后从https://github.com/HyperSine/how-does-MobaXterm-encrypt-pass ...

  9. sqlite相关

    前言 本文记录一些sqlite相关笔记,随时更新. 正文 时间函数 datetime() -- 当前时间 2022-03-24 17:32:43 select datetime('now'); --2 ...

  10. VulnHub_DC-4渗透流程

    VulnHub_DC-4 DC-4 is another purposely built vulnerable lab with the intent of gaining experience in ...