如需了解完整代码请跳转到:

https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

缘由

自己之前写过一个Pytorch学习率更新,其中感觉依据是否loss升高或降低的次数来动态更新学习率,感觉是个挺好玩的东西,自己弄了好久都设置错误,今天算是搞出来了!

解析

说明

  • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:

参数 含义
mode 'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大;
factor 触发条件后lr*=factor;
patience 不再减小(或增大)的累计次数;
verbose 触发条件后print;
threshold 只关注超过阈值的显著变化;
threshold_mode 有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著;
cooldown 触发一次条件后,等待一定epoch再进行检测,避免lr下降过速;
min_lr 最小的允许lr;
eps 如果新旧lr之间的差异小与1e-8,则忽略此次更新。
  • 例子,如图所示的y轴为lr,x为调整的次序,初始的学习率为0.0009575

    则学习率的方程为:lr = 0.0009575 * (0.35)^x

import math
import matplotlib.pyplot as plt
#%matplotlib inline x = 0
o = []
p = []
o.append(0)
p.append(0.0009575)
while(x < 8):
x += 1
y = 0.0009575 * math.pow(0.35,x)
o.append(x)
p.append(y)
print('%d: %.50f' %(x,y)) plt.plot(o,p,c='red',label='test') #分别为x,y轴对应数据,c:color,label
plt.legend(loc='best') # 显示label,loc为显示位置(best为系统认为最好的位置)
plt.show()

难点

我感觉这里面最难的时这几个参数的选择,第一个是初始的学习率(我目前接触的miniest和下面的图像分类貌似都是0.001,我这里训练调整时才发现自己设置的为0.0009575,这个值是上一个实验忘更改了,但发现结果不错,第一次运行该代码接近到0.001这么小的损失值),这里面的乘积系数以及判断说多少次没有减少(增加)后决定变换学习率都是难以估计的。我自己的最好方法是先按默认不变的0.001来训练一下(结合tensoarboard )观察从哪里开始出现问题就可以从这里来确定次数,而乘积系数,个人感觉还是用上面的代码来获取一个较为平滑且变化极小的数字来作为选择。建议在做这种测试时可以把模型先备份一下以免浪费过多的时间!

例子

  • 该例子初始学习率为0.0009575,乘积项系数为:0.35,在我的例子中x变化的条件是:累计125次没有减小则x加1;自己训练在第一次lr变化后(从0.0009575变化到0.00011729)损失值慢慢取向于0.001(如第一张图所示),准确率达到69%;

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import * PATH = './cifar_net_tensorboard_net_width_200_and_chang_lr_by_decrease_0_35^x.pth' # 保存模型地址 transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Assuming that we are on a CUDA machine, this should print a CUDA device: print(device) print("获取一些随机训练数据")
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next() # functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show() # show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
print("**********************") # 设置一个tensorborad
# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.cpu().numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 设置tensorBoard
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/train') # get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next() # create grid of images
img_grid = torchvision.utils.make_grid(images) # show images
# matplotlib_imshow(img_grid, one_channel=True)
imshow(img_grid) # write to tensorboard
# writer.add_image('imag_classify', img_grid) # Tracking model training with TensorBoard
# helper functions def images_to_probs(net, images):
'''
Generates predictions and corresponding probabilities from a trained
network and a list of images
'''
output = net(images)
# convert output probabilities to predicted class
_, preds_tensor = torch.max(output, 1)
# preds = np.squeeze(preds_tensor.numpy())
preds = np.squeeze(preds_tensor.cpu().numpy())
return preds, [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)] def plot_classes_preds(net, images, labels):
preds, probs = images_to_probs(net, images)
# plot the images in the batch, along with predicted and true labels
fig = plt.figure(figsize=(12, 48))
for idx in np.arange(4):
ax = fig.add_subplot(1, 4, idx+1, xticks=[], yticks=[])
matplotlib_imshow(images[idx], one_channel=True)
ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
classes[preds[idx]],
probs[idx] * 100.0,
classes[labels[idx]]),
color=("green" if preds[idx]==labels[idx].item() else "red"))
return fig # class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 200, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(200, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
# # 把net结构可视化出来
writer.add_graph(net, images)
net.to(device) ·······
·······
·······

如需了解完整代码请跳转到:

https://www.emperinter.info/2020/08/05/change-leaning-rate-by-reducelronplateau-in-pytorch/

Pytorch使用ReduceLROnPlateau来更新学习率的更多相关文章

  1. 【转载】 Pytorch(0)降低学习率torch.optim.lr_scheduler.ReduceLROnPlateau类

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84311430 ------------------------------- ...

  2. Pytorch系列:(八)学习率调整方法

    学习率的调整会对网络模型的训练造成巨大的影响,本文总结了pytorch自带的学习率调整函数,以及其使用方法. 设置网络固定学习率 设置固定学习率的方法有两种,第一种是直接设置一些学习率,网络从头到尾都 ...

  3. pytorch更新

    Pytorch如何更新版本与卸载,使用pip,conda更新卸载Pytorch 2018年05月22日 07:33:52 醉雨轩Y 阅读数 19047   今天我们主要汇总如何使用使用ubuntu,C ...

  4. 『PyTorch』屌丝的PyTorch玩法

    1. prefetch_generator 使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取 ...

  5. 探索学习率设置技巧以提高Keras中模型性能 | 炼丹技巧

      学习率是一个控制每次更新模型权重时响应估计误差而调整模型程度的超参数.学习率选取是一项具有挑战性的工作,学习率设置的非常小可能导致训练过程过长甚至训练进程被卡住,而设置的非常大可能会导致过快学习到 ...

  6. [PyTorch 学习笔记] 2.1 DataLoader 与 DataSet

    thumbnail: https://image.zhangxiann.com/jeison-higuita-W19AQY42rUk-unsplash.jpg toc: true date: 2020 ...

  7. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  8. caffe中的学习率的衰减机制

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/Julialove102123/article/details/79200158 根据  caffe/ ...

  9. 总结笔记 | 深度学习之Pytorch入门教程

    笔记作者:王博Kings 目录 一.整体学习的建议 1.1 如何成为Pytorch大神? 1.2 如何读Github代码? 1.3 代码能力太弱怎么办? 二.Pytorch与TensorFlow概述 ...

  10. [Pytorch框架] PyTorch 中文手册

    PyTorch 中文手册 书籍介绍 这是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门. 由于本人水平有限,在写此教程的时候参考了一些网上的资料,在这里对他 ...

随机推荐

  1. postman Could not get any response 无法请求

    外网访问接口地址,刚开始考虑到是阿里云服务器上面的ECS网络安全策略拦截,添加了白名单, 首先在浏览器中回车访问,页面有反应. 但是在postman中请求,仍然返回 Could not get any ...

  2. 【FAQ】HarmonyOS SDK 闭源开放能力 —Ads Kit

    1.问题描述: 开屏广告效果最好的实现方式? 解决方法: 1.动画效果和开发者的实现方式有关,和开屏广告页面本身没什么关系的: 2.示例代码中使用Router跳转的方式展示广告,主要是用于演示广告接口 ...

  3. 我的 ZYNQ 系列总结

    我的 ZYNQ 系列总结 背景 ZYNQ平台是我接触认识比较久的平台,还算不错,是工控.音视频各行业都可以使用中高端平台. 本文以ZYNQ-7000为例,其实更高级的MPSOC也是一样的. 先看看我自 ...

  4. bs4解析-湖南农场品价格行情

    import requests from bs4 import BeautifulSoup import csv url = 'https://price.21food.cn/market/174-p ...

  5. WEB入门 - 文件上传

    WEB入门 - 文件上传 参考文章 https://fushuling.com/index.php/2023/08/20/ctfshow刷题记录持续更新中/ https://www.cnblogs.c ...

  6. SpringBoot配置Mysql连接池

    一.HikariCP连接池 SpringBoot默认使用连接池HikariCP,不需要依赖. spring: datasource: driver-class-name: com.mysql.cj.j ...

  7. Vue 的父组件和子组件生命周期钩子函数执行顺序?

    https://www.cnblogs.com/thinheader/p/9462125.html 参考连接 Vue 的父组件和子组件生命周期钩子函数执行顺序可以归类为以下 4 部分: 加载渲染过程 ...

  8. 新版SpringBoot-Spring-Mybatis事务控制

    快速创建SpringBoot+Spring+Mybatis项目 https://start.spring.io 删除pom中mysql依赖的runtime pom.xml中添加druid依赖 < ...

  9. 前端:如何让background背景图片进行CSS自适应

    在设置login背景时,找到了一张这样的图片: 但是设置成login背景时,如果没有做一些css适应设置,图片就变样了,变成了这样: 严重变形了,这就造成了一种理想与现实的差距. 若想解决这个自适应问 ...

  10. 解决方案 | Adobe Acrobat XI Pro 右键菜单“在Acrobat中合并文件”丢失的最佳修复方法

    1.问题 Adobe Acrobat XI Pro右键菜单"转换为Adobe PDF"与"在Acrobat中合并文件" 不见了. 2.解决方案 桌面左下角搜索& ...