Pytorch使用ReduceLROnPlateau来更新学习率
如需了解完整代码请跳转到:
https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

缘由
自己之前写过一个Pytorch学习率更新,其中感觉依据是否loss升高或降低的次数来动态更新学习率,感觉是个挺好玩的东西,自己弄了好久都设置错误,今天算是搞出来了!
解析
说明
- torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:
| 参数 | 含义 |
|---|---|
| mode | 'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大; |
| factor | 触发条件后lr*=factor; |
| patience | 不再减小(或增大)的累计次数; |
| verbose | 触发条件后print; |
| threshold | 只关注超过阈值的显著变化; |
| threshold_mode | 有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著; |
| cooldown | 触发一次条件后,等待一定epoch再进行检测,避免lr下降过速; |
| min_lr | 最小的允许lr; |
| eps | 如果新旧lr之间的差异小与1e-8,则忽略此次更新。 |
- 例子,如图所示的y轴为lr,x为调整的次序,初始的学习率为0.0009575
则学习率的方程为:lr = 0.0009575 * (0.35)^x

import math
import matplotlib.pyplot as plt
#%matplotlib inline
x = 0
o = []
p = []
o.append(0)
p.append(0.0009575)
while(x < 8):
x += 1
y = 0.0009575 * math.pow(0.35,x)
o.append(x)
p.append(y)
print('%d: %.50f' %(x,y))
plt.plot(o,p,c='red',label='test') #分别为x,y轴对应数据,c:color,label
plt.legend(loc='best') # 显示label,loc为显示位置(best为系统认为最好的位置)
plt.show()
难点
我感觉这里面最难的时这几个参数的选择,第一个是初始的学习率(我目前接触的miniest和下面的图像分类貌似都是0.001,我这里训练调整时才发现自己设置的为0.0009575,这个值是上一个实验忘更改了,但发现结果不错,第一次运行该代码接近到0.001这么小的损失值),这里面的乘积系数以及判断说多少次没有减少(增加)后决定变换学习率都是难以估计的。我自己的最好方法是先按默认不变的0.001来训练一下(结合tensoarboard )观察从哪里开始出现问题就可以从这里来确定次数,而乘积系数,个人感觉还是用上面的代码来获取一个较为平滑且变化极小的数字来作为选择。建议在做这种测试时可以把模型先备份一下以免浪费过多的时间!
例子
- 该例子初始学习率为0.0009575,乘积项系数为:0.35,在我的例子中x变化的条件是:累计125次没有减小则x加1;自己训练在第一次lr变化后(从0.0009575变化到0.00011729)损失值慢慢取向于0.001(如第一张图所示),准确率达到69%;


import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import *
PATH = './cifar_net_tensorboard_net_width_200_and_chang_lr_by_decrease_0_35^x.pth' # 保存模型地址
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
print("获取一些随机训练数据")
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
print("**********************")
# 设置一个tensorborad
# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.cpu().numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 设置tensorBoard
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/train')
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# create grid of images
img_grid = torchvision.utils.make_grid(images)
# show images
# matplotlib_imshow(img_grid, one_channel=True)
imshow(img_grid)
# write to tensorboard
# writer.add_image('imag_classify', img_grid)
# Tracking model training with TensorBoard
# helper functions
def images_to_probs(net, images):
'''
Generates predictions and corresponding probabilities from a trained
network and a list of images
'''
output = net(images)
# convert output probabilities to predicted class
_, preds_tensor = torch.max(output, 1)
# preds = np.squeeze(preds_tensor.numpy())
preds = np.squeeze(preds_tensor.cpu().numpy())
return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)]
def plot_classes_preds(net, images, labels):
preds, probs = images_to_probs(net, images)
# plot the images in the batch, along with predicted and true labels
fig = plt.figure(figsize=(12, 48))
for idx in np.arange(4):
ax = fig.add_subplot(1, 4, idx+1, xticks=[], yticks=[])
matplotlib_imshow(images[idx], one_channel=True)
ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
classes[preds[idx]],
probs[idx] * 100.0,
classes[labels[idx]]),
color=("green" if preds[idx]==labels[idx].item() else "red"))
return fig
#
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 200, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(200, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
# # 把net结构可视化出来
writer.add_graph(net, images)
net.to(device)
·······
·······
·······
如需了解完整代码请跳转到:
https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/
Pytorch使用ReduceLROnPlateau来更新学习率的更多相关文章
- 【转载】 Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84311430 ------------------------------- ...
- Pytorch系列:(八)学习率调整方法
学习率的调整会对网络模型的训练造成巨大的影响,本文总结了pytorch自带的学习率调整函数,以及其使用方法. 设置网络固定学习率 设置固定学习率的方法有两种,第一种是直接设置一些学习率,网络从头到尾都 ...
- pytorch更新
Pytorch如何更新版本与卸载,使用pip,conda更新卸载Pytorch 2018年05月22日 07:33:52 醉雨轩Y 阅读数 19047 今天我们主要汇总如何使用使用ubuntu,C ...
- 『PyTorch』屌丝的PyTorch玩法
1. prefetch_generator 使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取 ...
- 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧
学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数.学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住,而设置的非常大可能会导致过快学习到 ...
- [PyTorch 学习笔记] 2.1 DataLoader 与 DataSet
thumbnail: https://image.zhangxiann.com/jeison-higuita-W19AQY42rUk-unsplash.jpg toc: true date: 2020 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- caffe中的学习率的衰减机制
版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/Julialove102123/article/details/79200158 根据 caffe/ ...
- 总结笔记 | 深度学习之Pytorch入门教程
笔记作者:王博Kings 目录 一.整体学习的建议 1.1 如何成为Pytorch大神? 1.2 如何读Github代码? 1.3 代码能力太弱怎么办? 二.Pytorch与TensorFlow概述 ...
- [Pytorch框架] PyTorch 中文手册
PyTorch 中文手册 书籍介绍 这是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门. 由于本人水平有限,在写此教程的时候参考了一些网上的资料,在这里对他 ...
随机推荐
- QQ、支付宝、微信收款码三合一
Tips:当你看到这个提示的时候,说明当前的文章是由原emlog博客系统搬迁至此的,文章发布时间已过于久远,编排和内容不一定完整,还请谅解` QQ.支付宝.微信收款码三合一 日期:2018-8-24 ...
- 解决Vue中使用history路由模式出现404的问题
背景 vue中默认的路由模式是hash,会出现烦人的符号#,如http://127.0.0.1/#/. 改为history模式可以解决这个问题,但是有一个坑是:强刷新.回退等操作会出现404. Vue ...
- 开发板测试手册——USB 4G 模块、GPS 定位功能操作步骤详解(3)
目录 4 USB 4G 模块测试 41 4.1 网络功能测试 42 4.2 短信功能测试 43 4.3 GPS 定位功能测试 44 4.4 通话功能测试 45 4.5 测试程序编译 46 5 USB ...
- NKCTF 2023 Misc
NKCTF 2023 Misc hard-misc base32 --> N0wayBack公众号回复:NKCTF2023我来了! 得到flag:NKCTF{wtk2023Oo0oImcoM1N ...
- 新知识get,vue3是如何实现在style中使用响应式变量?
前言 vue2的时候想必大家有遇到需要在style模块中访问script模块中的响应式变量,为此我们不得不使用css变量去实现.现在vue3已经内置了这个功能啦,可以在style中使用v-bind指令 ...
- c 语言学习第五天
break 语句 在循环体中使用 break,可以跳出循环 打印 10 以内的数. #include<stdio.h> int main(){ int i,j = 20; for(i = ...
- CF372C
思路 根据题意可以得到dp转移方程是 \(f_{i,j}=\max\{f_{i-1,k}+b_i-|a_i-j|\}\) 而且 \(j-(t_{i}-t_{i-1})\times d\le k\le ...
- ES6拼接数组与小程序本地存储
拼接数组 ES6扩展运算符[三个点(...)将一个数组转为用逗号分隔的参数序列] goodsList: [...goodsList, ...goods] 本地存储 // 把接口数据存入本地存储中 wx ...
- 简单万能队列状态机——WTYKAMC@2023
WTYKAMC@2023框架 [简介] 这是一个基于队列的灵活状态机,可以满足队列元素先进先出,先进后出,后进后出,后进先出,可以清除队列中未执行完的状态,且有一个默认超时执行状态:通过超时时间可以改 ...
- oeasy 教您玩转linux 010304 图形界面 xfce
我们来回顾一下 上一部分我们都讲了什么? 讲了文件管理器和命令行终端互相交互 用命令nautilus在文件管理器打开某路径 这次我们来看看 图形用户界面(GUI)的情况 图形界面和发行版的关系 一个发 ...