Accelerating Deep Learning by Focusing on the Biggest Losers

The idea is simple. When training a network, each sample produces a loss \(\mathcal{L}(f(x_i),y_i)\), and training usually proceeds in batches: the gradient of the batch loss \(\sum_i \mathcal{L}(f(x_i),y_i)\) is backpropagated and the parameters are updated. This paper argues that some samples \((x_i,y_i)\) are so redundant that the network recognizes them easily, making the corresponding \(\mathcal{L}(f(x_i),y_i)\) relatively small. It therefore designs a mechanism under which high-loss samples are selected with high probability while unimportant samples are skipped, reducing computation time. Experiments show that this method shortens training time while keeping accuracy unchanged.

Related Work

The authors say this line of work was first proposed in "Are Loss Functions All the Same?", but that paper only discusses the advantages of the hinge loss and analyzes other loss functions.

The authors say the most closely related work is "Not All Samples Are Created Equal: Deep Learning with Importance Sampling", which approaches the problem from a preprocessing angle (although it, too, has to compute losses) and contains more theory than this paper.

Main Content



The idea of Algorithm 1 is clear; the main difficulty lies in how Algorithm 2 computes the selection probability. Suppose we have already computed and stored the losses of \(n\) samples, and the loss of the next sample is \(\mathcal{L}_c\). If \(k\) of the \(n\) stored losses are smaller than \(\mathcal{L}_c\), then the probability that this sample is selected is:

\[
\max \{(k/n)^{\beta},\ s\},
\]

where \(s\in[0,1]\) is set by hand and guarantees that every sample retains some chance of being selected.

We can also cap the history at a maximum length \(r\) by storing past losses in a double-ended queue (deque): once \(n=r\), storing the next loss evicts the oldest one, which keeps the bookkeeping cheap.
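A minimal sketch of this selection rule (the values of \(\beta\) and \(s\) below are illustrative, not taken from the paper):

import collections

def selection_prob(history, loss_c, beta=0.9, s=0.1):
    # k = number of stored losses smaller than the current loss
    k = sum(1 for l in history if l < loss_c)
    return max((k / len(history)) ** beta, s)

history = collections.deque([0.2, 0.5, 0.9, 1.3], maxlen=4)  # r = 4
print(selection_prob(history, 1.0))  # k=3, n=4: 0.75**0.9 ≈ 0.772
history.append(1.0)  # the deque is full, so the oldest loss 0.2 is evicted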

graph LR
A[sample x] --> C(network f)
C --> D[loss l]
D--update-->E[loss history]
D-->F[compute probability]
F-->G(form batch)
G--backward pass-->C
E-->F

In the figure at the very beginning of the paper, the second column depicts this algorithm, and the third column builds on it by also manipulating the forward pass: the losses are refreshed only once every \(n\) epochs, and during the intermediate \(n-1\) epochs samples are selected using the stale losses (presumably the selection happens before the samples are fed into the network, otherwise no time would be saved); see the sketch below.
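A hypothetical sketch of this stale selection (the cached-probability array and its refresh schedule are my assumptions, not the paper's notation):

import numpy as np

rng = np.random.default_rng(0)
num_samples = 50000
cached_probs = np.ones(num_samples)  # would be refreshed every n epochs

def select_before_forward(batch_indices):
    # Filter the batch with cached (possibly stale) probabilities,
    # so the skipped samples never enter the forward pass.
    keep = rng.random(len(batch_indices)) < cached_probs[batch_indices]
    return batch_indices[keep]

batch = rng.integers(0, num_samples, size=128)
print(len(select_before_forward(batch)))  # all 128 kept while probs are 1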

Among randomized algorithms there is a single-pass algorithm for selecting a sample from a stream, but it selects only one item, so it is of no use when many samples must be selected; picking many at once in a single pass seems hard to arrange.
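For reference, the single-pass algorithm alluded to is presumably reservoir sampling with a reservoir of size one; a minimal sketch:

import random

def reservoir_pick(stream):
    # Keep exactly one item from a single pass, each item with equal probability.
    chosen = None
    for n, item in enumerate(stream, start=1):
        if random.randrange(n) == 0:  # replace the current choice with probability 1/n
            chosen = item
    return chosen

print(reservoir_pick(range(100)))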

Code

Owing to hardware constraints, this code has not been tested; the paper also provides very good code.

"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
""" class Unit: def __init__(self, command, type=str,
default=None):
if default is None:
default = type()
self.command = command
self.type = type
self.default = default class Opi:
"""
>>> parser = Opi()
>>> parser.add_opt(command="lr", type=float)
>>> parser.add_opt(command="epochs", type=int)
"""
def __init__(self):
self.store = []
self.infos = {} def add_opt(self, **kwargs):
self.store.append(
Unit(**kwargs)
) def acquire(self):
s = "Acquire args {0.command} [" \
"type:{0.type.__name__} " \
"default:{0.default}] : "
for unit in self.store:
while True:
inp = input(s.format(
unit
))
try:
if inp: #若有输入
inp = unit.type(inp)
else:
inp = unit.default
self.infos.update(
{unit.command:inp}
)
self.__setattr__(unit.command, inp)
break
except:
print("Type {0} should be given".format(
unit.type.__name__
)) if __name__ == "__main__":
parser = Opi()
parser.add_opt(command = "x", type=int)
parser.add_opt(command="y", type=str)
parser.acquire()
print(parser.infos)
print(parser.x)
'''
calcprob.py
Compute the selection probabilities.
'''
import collections


class Calcprob:

    def __init__(self, beta, sample_min, max_len=3000):
        assert 0. <= sample_min <= 1., "Invalid sample_min"
        assert beta > 0, "Invalid beta"
        self.beta = beta
        self.sample_min = sample_min
        self.max_len = max_len
        self.history = collections.deque(maxlen=max_len)
        self.num_slot = 1000
        self.hist = [0] * self.num_slot
        self.count = 0

    def update_history(self, losses):
        """
        BoundedHistogram: bucket each loss, keeping at most max_len of them.
        Losses at or above 1 are clamped into the top bucket.
        """
        for loss in losses:
            assert loss > 0
            if self.count == self.max_len:  # history full: evict the oldest loss
                loss_old = self.history.popleft()
                slot_old = min(int(loss_old * self.num_slot), self.num_slot - 1)
                self.hist[slot_old] -= 1
            else:
                self.count += 1
            self.history.append(loss)
            slot = min(int(loss * self.num_slot), self.num_slot - 1)
            self.hist[slot] += 1

    def get_probability(self, loss):
        # fraction of stored losses in strictly lower buckets, i.e. k/n
        assert loss > 0
        slot = min(int(loss * self.num_slot), self.num_slot - 1)
        prob = sum(self.hist[:slot]) / self.count
        return prob ** self.beta

    def calc_probability(self, losses):
        if isinstance(losses, float):
            losses = (losses,)
        losses = list(losses)  # materialize: losses may be a generator
        self.update_history(losses)
        return [
            max(self.get_probability(loss), self.sample_min)
            for loss in losses
        ]

    def __call__(self, losses):
        return self.calc_probability(losses)


if __name__ == "__main__":
    pass
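A quick usage example of Calcprob (the loss values are arbitrary; each returned probability is at least sample_min):

from calcprob import Calcprob

calc = Calcprob(beta=0.9, sample_min=0.1, max_len=3000)
print(calc([0.3, 0.8, 0.95]))  # roughly [0.1, 0.37, 0.69] for these inputs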
'''
selector.py
'''
import numpy as np

import calcprob


class Selector:

    def __init__(self, batch_size, beta, sample_min, max_len=3000):
        self.batch_size = batch_size
        self.calcprob = calcprob.Calcprob(beta, sample_min, max_len)
        self.reset()

    def backward(self):
        loss = sum(self.batch)
        loss.backward()
        self.reset()

    def reset(self):
        self.batch = []
        self.length = 0

    def select(self, losses):
        losses = list(losses)  # materialize: losses may be a generator
        probs = self.calcprob(losses)
        for i, prob in enumerate(probs):
            if np.random.rand() < prob:  # keep the sample with its probability
                self.batch.append(losses[i])
                self.length += 1
            if self.length >= self.batch_size:
                self.backward()

    def __call__(self, losses):
        self.select(losses)
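A toy use of Selector (PyTorch assumed; whether backward() actually fires depends on the random draws):

import torch
from selector import Selector

sel = Selector(batch_size=2, beta=0.9, sample_min=0.5)
x = torch.randn(8, requires_grad=True)
losses = [xi ** 2 + 0.01 for xi in x]  # toy per-sample losses, kept positive
sel(losses)  # backward() fires once batch_size losses have been kept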
'''
main.py
'''
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import selector


class Train:

    def __init__(self, model, lossfunc, bpsize, beta, sample_min,
                 max_len=3000, lr=0.01, momentum=0.9, weight_decay=0.0001):
        self.net = self.choose_net(model)
        self.criterion = self.choose_lossfunc(lossfunc)
        self.opti = torch.optim.SGD(self.net.parameters(),
                                    lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
        self.selector = selector.Selector(bpsize, beta,
                                          sample_min, max_len)
        self.gpu()
        self.generate_path()
        self.acc_rates = []
        self.errors = []

    def choose_net(self, model):
        net = getattr(torchvision.models, model, None)
        if net is None:
            raise ValueError("no such model")
        return net()

    def choose_lossfunc(self, lossfunc):
        lossfunc = getattr(nn, lossfunc, None)
        if lossfunc is None:
            raise ValueError("no such lossfunc")
        return lossfunc()  # instantiate the loss module

    def gpu(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                   else "cpu")
        if torch.cuda.device_count() > 1:
            print("Let's use %d GPUs" % torch.cuda.device_count())
            self.net = nn.DataParallel(self.net)
        self.net = self.net.to(self.device)

    def generate_path(self):
        """
        Create the paths used to save data.
        """
        try:
            os.makedirs('./paras')
            os.makedirs('./logs')
            os.makedirs('./infos')
        except FileExistsError:
            pass
        name = self.net.__class__.__name__
        paras = os.listdir('./paras')
        logs = os.listdir('./logs')
        infos = os.listdir('./infos')
        number = max((len(paras), len(logs), len(infos)))
        self.para_path = "./paras/{0}{1}.pt".format(name, number)
        self.log_path = "./logs/{0}{1}.txt".format(name, number)
        self.info_path = "./infos/{0}{1}.npy".format(name, number)

    def log(self, strings):
        """
        Append to the run log.
        """
        # 'a': append to the end of the file
        with open(self.log_path, 'a', encoding='utf8') as f:
            f.write(strings)

    def save(self):
        """
        Save the network parameters.
        """
        torch.save(self.net.state_dict(), self.para_path)

    def decrease_lr(self, multi=0.96):
        """
        Decay the learning rate.
        """
        self.opti.param_groups[0]['lr'] *= multi

    def train(self, trainloader, epochs=50):
        data_size = len(trainloader) * trainloader.batch_size
        part = int(trainloader.batch_size / 2)
        for epoch in range(epochs):
            running_loss = 0.
            total_loss = 0.
            acc_count = 0.
            if (epoch + 1) % 8 == 0:
                self.decrease_lr()
                self.log("learning rate change!!!\n")  # log entry
            for i, data in enumerate(trainloader):
                imgs, labels = data
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                out = self.net(imgs)
                _, pre = torch.max(out, 1)  # predicted classes
                acc_count += (pre == labels).sum().item()  # count correct predictions
                losses = [  # per-sample losses (a list, so it can be indexed)
                    self.criterion(out[j], labels[j])
                    for j in range(len(labels))
                ]
                self.opti.zero_grad()
                self.selector(losses)  # select samples and possibly backprop
                self.opti.step()
                running_loss += sum(losses).item()
                if (i + 1) % part == 0:
                    strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                        epoch, i, running_loss / part
                    )
                    self.log(strings)  # log entry
                    total_loss += running_loss
                    running_loss = 0.
            self.acc_rates.append(acc_count / data_size)
            self.errors.append(total_loss / data_size)
            self.log(  # log entry
                "Accuracy of the network on %d train images: %d %%\n" % (
                    data_size, acc_count / data_size * 100
                )
            )
            self.save()  # save the network parameters
        # save some information for plotting
        np.save(self.info_path, {
            'acc_rates': np.array(self.acc_rates),
            'errors': np.array(self.errors)
        })


if __name__ == "__main__":
    import OptInput

    args = OptInput.Opi()
    args.add_opt(command="model", default="resnet34")
    args.add_opt(command="lossfunc", default="CrossEntropyLoss")
    args.add_opt(command="bpsize", type=int, default=32)
    args.add_opt(command="beta", type=float, default=0.9)
    args.add_opt(command="sample_min", type=float, default=0.3)
    args.add_opt(command="max_len", type=int, default=3000)
    args.add_opt(command="lr", type=float, default=0.001)
    args.add_opt(command="momentum", type=float, default=0.9)
    args.add_opt(command="weight_decay", type=float, default=0.0001)
    args.acquire()

    root = "C:/Users/pkavs/1jupiterdata/data"
    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=False,
        transform=transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                               shuffle=True, num_workers=8,
                                               pin_memory=True)

    dog = Train(**args.infos)
    dog.train(train_loader, epochs=1000)
