Pytorch使用ReduceLROnPlateau来更新学习率
如需了解完整代码请跳转到:
https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

缘由
自己之前写过一个Pytorch学习率更新,其中感觉依据是否loss升高或降低的次数来动态更新学习率,感觉是个挺好玩的东西,自己弄了好久都设置错误,今天算是搞出来了!
解析
说明
- torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:
| 参数 | 含义 |
|---|---|
| mode | 'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大; |
| factor | 触发条件后lr*=factor; |
| patience | 不再减小(或增大)的累计次数; |
| verbose | 触发条件后print; |
| threshold | 只关注超过阈值的显著变化; |
| threshold_mode | 有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著; |
| cooldown | 触发一次条件后,等待一定epoch再进行检测,避免lr下降过速; |
| min_lr | 最小的允许lr; |
| eps | 如果新旧lr之间的差异小与1e-8,则忽略此次更新。 |
- 例子,如图所示的y轴为lr,x为调整的次序,初始的学习率为0.0009575
则学习率的方程为:lr = 0.0009575 * (0.35)^x

import math
import matplotlib.pyplot as plt
#%matplotlib inline
x = 0
o = []
p = []
o.append(0)
p.append(0.0009575)
while(x < 8):
x += 1
y = 0.0009575 * math.pow(0.35,x)
o.append(x)
p.append(y)
print('%d: %.50f' %(x,y))
plt.plot(o,p,c='red',label='test') #分别为x,y轴对应数据,c:color,label
plt.legend(loc='best') # 显示label,loc为显示位置(best为系统认为最好的位置)
plt.show()
难点
我感觉这里面最难的时这几个参数的选择,第一个是初始的学习率(我目前接触的miniest和下面的图像分类貌似都是0.001,我这里训练调整时才发现自己设置的为0.0009575,这个值是上一个实验忘更改了,但发现结果不错,第一次运行该代码接近到0.001这么小的损失值),这里面的乘积系数以及判断说多少次没有减少(增加)后决定变换学习率都是难以估计的。我自己的最好方法是先按默认不变的0.001来训练一下(结合tensoarboard )观察从哪里开始出现问题就可以从这里来确定次数,而乘积系数,个人感觉还是用上面的代码来获取一个较为平滑且变化极小的数字来作为选择。建议在做这种测试时可以把模型先备份一下以免浪费过多的时间!
例子
- 该例子初始学习率为0.0009575,乘积项系数为:0.35,在我的例子中x变化的条件是:累计125次没有减小则x加1;自己训练在第一次lr变化后(从0.0009575变化到0.00011729)损失值慢慢取向于0.001(如第一张图所示),准确率达到69%;


import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import *
PATH = './cifar_net_tensorboard_net_width_200_and_chang_lr_by_decrease_0_35^x.pth' # 保存模型地址
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
print("获取一些随机训练数据")
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
print("**********************")
# 设置一个tensorborad
# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.cpu().numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 设置tensorBoard
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/train')
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# create grid of images
img_grid = torchvision.utils.make_grid(images)
# show images
# matplotlib_imshow(img_grid, one_channel=True)
imshow(img_grid)
# write to tensorboard
# writer.add_image('imag_classify', img_grid)
# Tracking model training with TensorBoard
# helper functions
def images_to_probs(net, images):
'''
Generates predictions and corresponding probabilities from a trained
network and a list of images
'''
output = net(images)
# convert output probabilities to predicted class
_, preds_tensor = torch.max(output, 1)
# preds = np.squeeze(preds_tensor.numpy())
preds = np.squeeze(preds_tensor.cpu().numpy())
return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)]
def plot_classes_preds(net, images, labels):
preds, probs = images_to_probs(net, images)
# plot the images in the batch, along with predicted and true labels
fig = plt.figure(figsize=(12, 48))
for idx in np.arange(4):
ax = fig.add_subplot(1, 4, idx+1, xticks=[], yticks=[])
matplotlib_imshow(images[idx], one_channel=True)
ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
classes[preds[idx]],
probs[idx] * 100.0,
classes[labels[idx]]),
color=("green" if preds[idx]==labels[idx].item() else "red"))
return fig
#
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 200, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(200, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
# # 把net结构可视化出来
writer.add_graph(net, images)
net.to(device)
·······
·······
·······
如需了解完整代码请跳转到:
https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/
Pytorch使用ReduceLROnPlateau来更新学习率的更多相关文章
- 【转载】 Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类
原文地址: https://blog.csdn.net/weixin_40100431/article/details/84311430 ------------------------------- ...
- Pytorch系列:(八)学习率调整方法
学习率的调整会对网络模型的训练造成巨大的影响,本文总结了pytorch自带的学习率调整函数,以及其使用方法. 设置网络固定学习率 设置固定学习率的方法有两种,第一种是直接设置一些学习率,网络从头到尾都 ...
- pytorch更新
Pytorch如何更新版本与卸载,使用pip,conda更新卸载Pytorch 2018年05月22日 07:33:52 醉雨轩Y 阅读数 19047 今天我们主要汇总如何使用使用ubuntu,C ...
- 『PyTorch』屌丝的PyTorch玩法
1. prefetch_generator 使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取 ...
- 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧
学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数.学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住,而设置的非常大可能会导致过快学习到 ...
- [PyTorch 学习笔记] 2.1 DataLoader 与 DataSet
thumbnail: https://image.zhangxiann.com/jeison-higuita-W19AQY42rUk-unsplash.jpg toc: true date: 2020 ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- caffe中的学习率的衰减机制
版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/Julialove102123/article/details/79200158 根据 caffe/ ...
- 总结笔记 | 深度学习之Pytorch入门教程
笔记作者:王博Kings 目录 一.整体学习的建议 1.1 如何成为Pytorch大神? 1.2 如何读Github代码? 1.3 代码能力太弱怎么办? 二.Pytorch与TensorFlow概述 ...
- [Pytorch框架] PyTorch 中文手册
PyTorch 中文手册 书籍介绍 这是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门. 由于本人水平有限,在写此教程的时候参考了一些网上的资料,在这里对他 ...
随机推荐
- golang reflect 反射机制的使用场景
Go语言中的 reflect 包提供了运行时反射机制,允许程序在运行时检查和操作任意对象的数据类型和值. 以下是 reflect 包的一些典型使用场景: 1. 动态类型判断与转换:当需要处理多种类型的 ...
- Java正则表达式语法及简单示例
import java.util.regex.Matcher; import java.util.regex.Pattern; public class TestMatcher { public st ...
- C#/.NET/.NET Core拾遗补漏合集(24年6月更新)
前言 在这个快速发展的技术世界中,时常会有一些重要的知识点.信息或细节被忽略或遗漏.<C#/.NET/.NET Core拾遗补漏>专栏我们将探讨一些可能被忽略或遗漏的重要知识点.信息或细节 ...
- C++11智能指针 unique_ptr、shared_ptr、weak_ptr与定制删除器
目录 智能指针 场景引入 - 为什么需要智能指针? 内存泄漏 什么是内存泄漏 内存泄漏的危害 内存泄漏分类 如何避免内存泄漏 智能指针的使用及原理 RAII 简易例程 智能指针的原理 智能指针的拷贝问 ...
- maven项目创建默认目录结构
maven项目创建默认目录结构命令 项目文件夹未创建情况下 mvn \ archetype:generate \ -DgroupId=com.lits.parent \ -DartifactId=my ...
- VIP视频解析
效果图 新建窗口 import tkinter as tk# 创建一个窗口 root = tk.Tk() # 设置窗口大小 root.geometry('700x250+200+200') # 设置标 ...
- 【iOS】自定义CALayer可能会出现没有动画过程但有动画结果的解析
按照计划是要做成这样的动画 可是结果变成了这样 (有时候最重要的不是结果而是过程,日常鸡汤) 结果没有问题说明delegate中- (void)animationDidStop:(CAAnimatio ...
- 【BUG记录】Cause: java.sql.SQLException: Incorrect string value: '\xF0\x9F\x90\xA6' for column 'name' at row 1
大家好呀,我是summo,这次的文章标题是一个Mysql数据库的SQL错误,遇到的同学自然懂,没遇到的同学希望你永远也不要遇到. 一.错误说明 Cause: java.sql.SQLException ...
- 图片接口JWT鉴权实现
图片接口JWT鉴权实现 前言 之前做了个返回图片链接的接口,然后没做授权,然后今天键盘到了,也是用JWT来做接口的权限控制. 然后JTW网上已经有很多文章来说怎么用了,这里就不做多的解释了,如果不懂的 ...
- 算法金 | 没有思考过 Embedding,不足以谈 AI
大侠幸会,在下全网同名「算法金」 0 基础转 AI 上岸,多个算法赛 Top 「日更万日,让更多人享受智能乐趣」 抱个拳,送个礼 在当今的人工智能(AI)领域,Embedding 是一个不可或缺的概念 ...