『PyTorch』屌丝的PyTorch玩法
1. prefetch_generator
使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。使用prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。
安装
pip install prefetch_generator
使用
之前加载数据集的正确方式是使用torch.utils.data.DataLoader,现在我们只要利用这个库,新建个DataLoaderX类继承DataLoader并重写__iter__方法即可from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator class DataLoaderX(DataLoader): def __iter__(self):
return BackgroundGenerator(super().__iter__())之后这样用:
train_dataset = MyDataset(".........")
train_loader = DataLoaderX(dataset=train_dataset,
batch_size=batch_size, num_workers=4, shuffle=shuffle)
2. Apex
2.1 安装
- 克隆源代码
git clone https://github.com/NVIDIA/apex
可以先下载到码云,再下载到本地
- 安装apex
cd apex
python setup.py install
最好打开PyCharm的终端进行安装,这样实在Anaconda的环境里安装了
- 删除刚刚clone下来的apex文件夹,然后重启PyCharm
【注意】安装PyTorch和cuda时注意版本对应,要按照正确流程安装
- 测试安装成功
from apex import amp
如果导入不报错说明安装成功
2.2 使用
from apex import amp # 这个必须的,其他的导包省略了
train_dataset = MyDataset("......")
train_loader = DataLoader(dataset=train_dataset, batch_size=2, num_workers=4, shuffle=True)
model = MyNet().to(device) # 创建模型
criterion = nn.MSELoss() # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.00001) # 优化器
net, optimizer = amp.initialize(net, optimizer, opt_level="O1") # 这一步很重要
# 学习率衰减
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer, mode="min",factor=0.1, patience=3,
verbose=False,cooldown=0, min_lr=0.0, eps=1e-7)
for epoch in range(epochs):
net.train() # 训练模式
train_loss_epoch = [] # 记录一个epoch内的训练集每个batch的loss
test_loss_epoch = [] # 记录一个epoch内测试集的每个batch的loss
for i, data in enumerate(train_loader):
# forward
x, y = data
x = x.to(device)
y = y.to(device)
outputs = net(x)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
# 这一步也很重要
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
# 更新权重
optimizer.step()
scheduler.step(1) # 更新学习率。每1步更新一次
- 主要是添加了三行代码
- scaled_loss 是将原loss放大了,所以要保存loss应该保存之前的值,这种放大防止梯度消失
考察amp.initialize(net, optimizer, opt_level="O1")的opt_level参数
opt_level=O0(base)
表示的是当前执行FP32训练,即正常的训练opt_level=O1(推荐)
表示的是当前使用部分FP16混合训练opt_level=O2表示的是除了BN层的权重外,其他层的权重都使用FP16执行训练
opt_level=O3
表示的是默认所有的层都使用FP16执行计算,当keep_batch norm_fp32=True,则会使用cudnn执行BN层的计算,该优化等级能够获得最快的速度,但是精度可能会有一些较大的损失
一般我们用
O1级别就行,最多O2,注意,是欧不是零
『PyTorch』屌丝的PyTorch玩法的更多相关文章
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『TensorFlow』专题汇总
TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...
- 『Python』__getattr__()特殊方法
self的认识 & __getattr__()特殊方法 将字典调用方式改为通过属性查询的一个小class, class Dict(dict): def __init__(self, **kw) ...
- 『TensorFlow』流程控制
『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...
随机推荐
- Mysql 主从同步原理简析
在开始讲述原理的情况下,我们先来做个知识汇总,究竟什么是主从,为什么要搞主从,可以怎么实现主从,mysql主从同步的原理1.什么是主从其实主从这个概念非常简单主机就是我们平常主要用来读写的服务,我们称 ...
- Synchronized和ReentranLock的区别
1.底层实现上来说? Synchronized是JVM层面的锁,是Java关键字,通过monitor对象来完成. ReentranLock是API层面的锁底层使用AQS. 2.是否可手动释放锁? sy ...
- swiper在一个页面多个轮播图
<script> var swiper = new Swiper('.swiper-container1', { spaceBetween: 30, centeredSlides: tru ...
- ASP net core面试题汇总及答案
在dot net core中,我们不需要关心如何释放这些服务, 因为系统会帮我们释放掉.有三种服务的生命周期. 单实例服务, 通过add singleton方法来添加.在注册时即创建服务, 在随后的请 ...
- UWP AppConnection.
https://www.cnblogs.com/manupstairs/p/14582794.html
- jQuery中的筛选(六):first()、last()、has()、is()、find()、siblings()等
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"> <html> <hea ...
- 浅谈Java和Go的程序退出
前言 今天在开发中对Java程序的退出产生了困惑,因为题主之前写过一段时间Go,这两者的程序退出逻辑是不同的,下面首先给出结论,再通过简单的例子来介绍. 对于Java程序,Main线程退出,如果当前存 ...
- linux centos7 tail
2021-08-30 # 不指定行数,默认显示 10 行 # 显示 /var/log/crond 后100行 taile -100 /var/log/crond # 动态显示 /var/log/cro ...
- vue 引入 echarts 图表 并且展示柱状图
npm i echarts -S 下载 echarts 图表 mian.js 文件 引入图表并且全局挂载 //echarts 图表 import echarts from 'echarts' Vue. ...
- go语言学习代码
1.day01 package main //声明文件所在的包,每个go文件必须有归属包 import "fmt" //引入程序中需要用的包,为了使用包下的函数 比如函数:Prin ...