如需了解示例完整代码及其后续内容请访问: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/

缘由

自己在尝试了官方的代码后就想提高训练的精度就想到了调整学习率,但固定的学习率肯定不适合训练就尝试了几个更改学习率的方法,但没想到居然更差!可能有几个学习率没怎么尝试吧!

更新方法

直接修改optimizer中的lr参数;

  • 定义一个简单的神经网络模型:y=Wx+b
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from torch.optim import *
import torch.nn as nn class net(nn.Module):
def __init__(self):
super(net,self).__init__()
self.fc = nn.Linear(1,10)
def forward(self,x):
return self.fc(x)
  • 直接更改lr的值
model = net()
LR = 0.01
optimizer = Adam(model.parameters(),lr = LR)
lr_list = []
for epoch in range(100):
if epoch % 5 == 0:
for p in optimizer.param_groups:
p['lr'] *= 0.9
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')

关键是如下两行能达到手动阶梯式更改,自己也可按需求来更改变换函数

for p in optimizer.param_groups:
p['lr'] *= 0.9

利用lr_scheduler()提供的几种衰减函数

  • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
参数 含义
lr_lambda 会接收到一个int参数:epoch,然后根据epoch计算出对应的lr。如果设置多个lambda函数的话,会分别作用于Optimizer中的不同的params_group
import numpy as np
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(),lr = LR)
lambda1 = lambda epoch:np.sin(epoch) / epoch
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)
for epoch in range(100):
scheduler.step()
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')

  • torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
参数 含义
T_max 对应1/2个cos周期所对应的epoch数值
eta_min 最小的lr值,默认为0
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(),lr = LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max = 20)
for epoch in range(100):
scheduler.step()
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')

  • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:

参数 含义
mode 'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大;
factor 触发条件后lr*=factor;
patience 不再减小(或增大)的累计次数;
verbose 触发条件后print;
threshold 只关注超过阈值的显著变化;
threshold_mode 有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著;
cooldown 触发一次条件后,等待一定epoch再进行检测,避免lr下降过速;
min_lr 最小的允许lr;
eps 如果新旧lr之间的差异小与1e-8,则忽略此次更新。

如需了解其它学习率更新方法请访问: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/

示例

使用的更新方法

  • 代码中可选的选项有:余弦方式(默认方式,其他两种注释了)、e^-x的方式以及按loss是否不在降低来判断的三种方式,其他就自己测试吧!

  • 训练截图(第一个图为trainingg_loss,第二个为学习率变化曲线)

完整代码

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torch.optim import * transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=0) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#如需了解示例完整代码及其后续内容请访问: [https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/](https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/)

如需了解示例完整代码及其后续内容请访问: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/

Pytorch学习率更新的更多相关文章

  1. PyTorch大更新!谷歌出手帮助开发,正式支持TensorBoard | 附5大开源项目

    大家又少了一个用TensorFlow的理由. 在一年一度的开发者大会F8上,Facebook放出PyTorch的1.1版本,直指TensorFlow"腹地". 不仅宣布支持Tens ...

  2. tensorflow中常用学习率更新策略

    神经网络训练过程中,根据每batch训练数据前向传播的结果,计算损失函数,再由损失函数根据梯度下降法更新每一个网络参数,在参数更新过程中使用到一个学习率(learning rate),用来定义每次参数 ...

  3. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

  4. PyTorch学习之六个学习率调整策略

    PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现.PyTorch提供的学习率调整策略分为三大类,分别是 有序调整:等间隔调整(Step),按需调整学习率(Mul ...

  5. Pytorch系列:(八)学习率调整方法

    学习率的调整会对网络模型的训练造成巨大的影响,本文总结了pytorch自带的学习率调整函数,以及其使用方法. 设置网络固定学习率 设置固定学习率的方法有两种,第一种是直接设置一些学习率,网络从头到尾都 ...

  6. pytorch更新

    Pytorch如何更新版本与卸载,使用pip,conda更新卸载Pytorch 2018年05月22日 07:33:52 醉雨轩Y 阅读数 19047   今天我们主要汇总如何使用使用ubuntu,C ...

  7. 基于卷积神经网络的面部表情识别(Pytorch实现)----台大李宏毅机器学习作业3(HW3)

    一.项目说明 给定数据集train.csv,要求使用卷积神经网络CNN,根据每个样本的面部图片判断出其表情.在本项目中,表情共分7类,分别为:(0)生气,(1)厌恶,(2)恐惧,(3)高兴,(4)难过 ...

  8. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

  9. Pytorch源码与运行原理浅析--网络篇(一)

    前言 申请的专栏开通了,刚好最近闲下来了,就打算开这个坑了hhhhh 第一篇就先讲一讲pytorch的运行机制好了... 记得当时刚刚接触的时候一直搞不明白,为什么自己只是定义了几个网络,就可以完整的 ...

  10. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

随机推荐

  1. spring boot jpa 进行通用多条件动态查询和更新 接口

    原因: jpa 没有类似于mybatis的那种 拼接sql的方式 想动态更新 需要使用 CriteriaUpdate的方式 去一直拼接,其实大多数场景只要传入一个非空实体类,去动态拼接sql 1.定义 ...

  2. Masonry的进阶使用技巧

    Masonry是iOS开发中常见的视图约束框架,但是有人对他的使用还是浅尝辄止,接下来会提出几点比较少见但是又十分便捷的使用技巧. mas_greaterThanOrEqualTo mas_great ...

  3. linux下后台运行程序

    文章目录 背景 nohup命令 setsid命令 pm2 背景 后台运行程序的时候,如果退出当前的终端(session),你运行的所有程序(包括后台程序),都将被关闭. 原因是:你运行的程序都是你的终 ...

  4. mysql多表删除指定记录

    在Mysql4.0之后,mysql开始支持跨表delete. Mysql可以在一个sql语句中同时删除多表记录,也可以根据多个表之间的关系来删除某一个表中的记录. 假定我们有两张表:Product表和 ...

  5. PAT-甲级-1007

    一.看题,https://www.patest.cn/contests/pat-a-practise/1007 其实,也是一顿暴力,但是最后一个测试点会运行超时,最开始,计算一段区间的值的总和的时候, ...

  6. 常见 i2c设备地址

    背景 朋友分享的一份i2c器件地址清单,我觉得还不错. reference:https://learn.adafruit.com/i2c-addresses/the-list Special case ...

  7. ETL服务器连接GaussDB(DWS)集群客户端配置

    问题描述:给ETL的服务器上安装gsql的工具,用来连接GaussDB(DWS)集群,做数据抽取用 DWS:GaussDB(DWS) 8.2.1-ESL 1.获取软件包 登录FusionInsight ...

  8. selenium的各种操作

    import time from selenium.webdriver import Edge from selenium.webdriver.common.by import By from sel ...

  9. KES数据库实践指南:探索KES数据库的事务隔离级别

    引言 前两篇文章我们详细讲解了如何安装KES金仓数据库,并提供了快速查询和搭建基于coze平台的智能体的解决方案.今天,我们的焦点将放在并发控制机制和事务隔离级别上. 本文将通过一系列实验操作,深入探 ...

  10. TIOBE 7月编程排行榜出炉!Python再次出圈

    又到了周三,本周有过半了,大家好呀 ~~ 每月的TIOBE编程排行榜都是技术社区关注的焦点,作为编程语言流行度的晴雨表,它反映了行业趋势和 技术走向.2024年7月的榜单揭晓了一个重要变化:Pytho ...