Accelerating Deep Learning by Focusing on the Biggest Losers

思想很简单, 在训练网络的时候, 每个样本都会产生一个损失\(\mathcal{L}(f(x_i),y_i)\), 训练的模式往往是批训练, 将一个批次\(\sum_i \mathcal{L}(f(x_i),y_i)\)所产生的损失的梯度都传回去, 然后更新参数. 本文认为, 有些样本\((x_i,y_i)\)由于重复度高, 网络很高能够识别, 使得对应的\(\mathcal{L}(f(x_i),y_i)\)相对较小, 所以设计了一种机制, 使得损失较大的样本有大概率被选中, 而不重要的样本不被选中, 以此来降低计算时间. 实验证明, 这种方法能够在保持准确率不变的前提下降低训练时间.

相关工作

作者说这个算法首先是由Are Loss Functions All the Same?提出的, 但是这篇文章只是讲了hinge loss的优势和对其它损失函数的分析.

作者说最相关的文章是Not All Samples Are Created Equal: Deep Learning with Importance Sampling, 这篇文章是从预处理(虽然也是要算loss的)的角度出发的, 理论部分较本文多一些.

主要内容



算法1的思路是很清晰的, 主要困扰在算法2概率的计算上. 假设我们以及计算了\(n\)个样本的损失, 我们将其存储起来, 假设下一个样本的损失是\(\mathcal{L}_c\), 如果这\(n\)个样本中有\(k\)个样本的损失均小于\(\mathcal{L}_c\), 则改样本被选中的概率是:

\[\max \{(k/n)^\beta, s\}
\]

其中\(s\in[0,1]\)是人为设置的, 保证每个样本都有被选中的可能.

我们还可以设置一个最大的长度\(r\), 将以往的损失存储在一个双栈中, 当\(n=r\)的时候,存储下一个损失的同时会抛弃第一个损失, 这么做能在一定程度上减少计算量.

graph LR
A[样本x] --> C(网络f)
C --> D[损失l]
D--更新-->E[损失库]
D-->F[计算概率]
F-->G(形成batch)
G--反向传递-->C
E-->F

从最开始的图中, 第二列就是表示这个算法, 第三列是在此基础上对前向传递进行一些处理. 直接的是, 每隔\(n\)次epoches更新一次损失, 然后中间的n-1次不更新损失, 直接用旧的损失对样本选择(应该是直接在传入网络就将样本选择好否则就不能降低时间了).

在随机算法中, 有单通道选择样本的一个算法, 但是这个算法只用于选择一个. 所以如果选择很多这个算法就没用了, 感觉一次性选择很多个不好弄.

代码

因为条件限制, 代码并没有测试过, 论文也给出了很棒的代码.

"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
""" class Unit: def __init__(self, command, type=str,
default=None):
if default is None:
default = type()
self.command = command
self.type = type
self.default = default class Opi:
"""
>>> parser = Opi()
>>> parser.add_opt(command="lr", type=float)
>>> parser.add_opt(command="epochs", type=int)
"""
def __init__(self):
self.store = []
self.infos = {} def add_opt(self, **kwargs):
self.store.append(
Unit(**kwargs)
) def acquire(self):
s = "Acquire args {0.command} [" \
"type:{0.type.__name__} " \
"default:{0.default}] : "
for unit in self.store:
while True:
inp = input(s.format(
unit
))
try:
if inp: #若有输入
inp = unit.type(inp)
else:
inp = unit.default
self.infos.update(
{unit.command:inp}
)
self.__setattr__(unit.command, inp)
break
except:
print("Type {0} should be given".format(
unit.type.__name__
)) if __name__ == "__main__":
parser = Opi()
parser.add_opt(command = "x", type=int)
parser.add_opt(command="y", type=str)
parser.acquire()
print(parser.infos)
print(parser.x)
'''
calcprob.py
计算概率
''' import collections class Calcprob:
def __init__(self, beta, sample_min, max_len=3000):
assert 0. <= sample_min <= 1., "Invalid sample_min"
assert beta > 0, "Invalid beta"
self.beta = beta
self.sample_min = sample_min
self.max_len = max_len
self.history = collections.deque(maxlen=max_len)
self.num_slot = 1000
self.hist = [0] * self.num_slot
self.count = 0 def update_history(self, losses):
"""
BoundedHistogram
:param losses:
:return:
"""
for loss in losses:
assert loss > 0
if self.count is self.max_len:
loss_old = self.history.popleft()
slot_old = int(loss_old * self.num_slot) % self.num_slot
self.hist[slot_old] -= 1
else:
self.count += 1
self.history.append(loss)
slot = int(loss * self.num_slot) % self.num_slot
self.hist[slot] += 1 def get_probability(self, loss):
assert loss > 0
slot = int(loss * self.num_slot) % self.num_slot
prob = sum(self.hist[:slot]) / self.count
assert isinstance(prob, float), "int division error..."
return prob ** self.beta def calc_probability(self, losses):
if isinstance(losses, float):
losses = (losses, )
self.update_history(losses)
probs = (
max(
self.get_probability(loss),
self.sample_min
)
for loss in losses
)
return probs def __call__(self, losses):
return self.calc_probability(losses) if __name__ == "__main__":
pass
'''
selector.py
''' import calcprob
import numpy as np class Selector: def __init__(self, batch_size,
beta, sample_min, max_len=3000):
self.batch_size = batch_size
self.calcprob = calcprob.Calcprob(beta,
sample_min,
max_len)
self.reset() def backward(self):
loss = sum(self.batch)
loss.backward()
self.reset() def reset(self):
self.batch = []
self.length = 0. def select(self, losses):
probs = self.calcprob(losses)
for i, prob in enumerate(probs):
if np.random.rand() < prob:
self.batch.append(losses[i])
self.length += 1
if self.length >= self.batch_size:
self.backward() def __call__(self, losses):
self.select(losses)
'''
main.py
''' import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os import selector class Train: def __init__(self, model, lossfunc,
bpsize, beta, sample_min, max_len=3000,
lr=0.01, momentum=0.9, weight_decay=0.0001):
self.net = self.choose_net(model)
self.criterion = self.choose_lossfunc(lossfunc)
self.opti = torch.optim.SGD(self.net.parameters(),
lr=lr, momentum=momentum,
weight_decay=weight_decay)
self.selector = selector.Selector(bpsize, beta,
sample_min, max_len)
self.gpu()
self.generate_path()
self.acc_rates = []
self.errors = [] def choose_net(self, model):
net = getattr(
torchvision.models,
model,
None
)
if net is None:
raise ValueError("no such model")
return net() def choose_lossfunc(self, lossfunc):
lossfunc = getattr(
nn,
lossfunc,
None
)
if lossfunc is None:
raise ValueError("no such lossfunc")
return lossfunc def gpu(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Let'us use %d GPUs" % torch.cuda.device_count())
self.net = nn.DataParallel(self.net)
self.net = self.net.to(self.device) def generate_path(self):
"""
生成保存数据的路径
:return:
"""
try:
os.makedirs('./paras')
os.makedirs('./logs')
os.makedirs('./infos')
except FileExistsError as e:
pass
name = self.net.__class__.__name__
paras = os.listdir('./paras')
logs = os.listdir('./logs')
infos = os.listdir('./infos')
number = max((len(paras), len(logs), len(infos)))
self.para_path = "./paras/{0}{1}.pt".format(
name,
number
) self.log_path = "./logs/{0}{1}.txt".format(
name,
number
)
self.info_path = "./infos/{0}{1}.npy".format(
name,
number
) def log(self, strings):
"""
运行日志
:param strings:
:return:
"""
# a 往后添加内容
with open(self.log_path, 'a', encoding='utf8') as f:
f.write(strings) def save(self):
"""
保存网络参数
:return:
"""
torch.save(self.net.state_dict(), self.para_path) def derease_lr(self, multi=0.96):
"""
降低学习率
:param multi:
:return:
"""
self.opti.param_groups[0]['lr'] *= multi def train(self, trainloder, epochs=50):
data_size = len(trainloder) * trainloder.batch_size
part = int(trainloder.batch_size / 2)
for epoch in range(epochs):
running_loss = 0.
total_loss = 0.
acc_count = 0.
if (epoch + 1) % 8 is 0:
self.derease_lr()
self.log(#日志记录
"learning rate change!!!\n"
)
for i, data in enumerate(trainloder):
imgs, labels = data
imgs = imgs.to(self.device)
labels = labels.to(self.device)
out = self.net(imgs)
_, pre = torch.max(out, 1) #判断是否判断正确
acc_count += (pre == labels).sum().item() #加总对的个数 losses = (
self.criterion(out[i], labels[i])
for i in range(len(labels))
) self.opti.zero_grad()
self.selector(losses) #选择
self.opti.step() running_loss += sum(losses).item() if (i+1) % part is 0:
strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
epoch, i, running_loss / part
)
self.log(strings)#日志记录
total_loss += running_loss
running_loss = 0.
self.acc_rates.append(acc_count / data_size)
self.errors.append(total_loss / data_size)
self.log( #日志记录
"Accuracy of the network on %d train images: %d %%\n" %(
data_size, acc_count / data_size * 100
)
)
self.save() #保存网络参数
#保存一些信息画图用
np.save(self.info_path, {
'acc_rates': np.array(self.acc_rates),
'errors': np.array(self.errors)
}) if __name__ == "__main__": import OptInput
args = OptInput.Opi()
args.add_opt(command="model", default="resnet34")
args.add_opt(command="lossfunc", default="CrossEntropyLoss")
args.add_opt(command="bpsize", default=32)
args.add_opt(command="beta", default=0.9)
args.add_opt(command="sample_min", default=0.3)
args.add_opt(command="max_len", default=3000)
args.add_opt(command="lr", default=0.001)
args.add_opt(command="momentum", default=0.9)
args.add_opt(command="weight_decay", default=0.0001) args.acquire() root = "C:/Users/pkavs/1jupiterdata/data" trainset = torchvision.datasets.CIFAR10(root=root, train=True,
download=False,
transform=transforms.Compose(
[transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)) train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=8,
pin_memory=True) dog = Train(**args.infos)
dog.train(train_loader, epochs=1000)

Accelerating Deep Learning by Focusing on the Biggest Losers的更多相关文章

  1. Deep Learning 27:Batch normalization理解——读论文“Batch normalization: Accelerating deep network training by reducing internal covariate shift ”——ICML 2015

    这篇经典论文,甚至可以说是2015年最牛的一篇论文,早就有很多人解读,不需要自己着摸,但是看了论文原文Batch normalization: Accelerating deep network tr ...

  2. Applied Deep Learning Resources

    Applied Deep Learning Resources A collection of research articles, blog posts, slides and code snipp ...

  3. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  4. 【深度学习Deep Learning】资料大全

    最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books  by Yoshua Bengio, Ian Goodfellow and Aaron C ...

  5. (转) Awesome - Most Cited Deep Learning Papers

    转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...

  6. deep learning 的综述

    从13年11月初开始接触DL,奈何boss忙or 各种问题,对DL理解没有CSDN大神 比如 zouxy09等 深刻,主要是自己觉得没啥进展,感觉荒废时日(丢脸啊,这么久....)开始开文,即为记录自 ...

  7. (转)分布式深度学习系统构建 简介 Distributed Deep Learning

    HOME ABOUT CONTACT SUBSCRIBE VIA RSS   DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...

  8. Machine and Deep Learning with Python

    Machine and Deep Learning with Python Education Tutorials and courses Supervised learning superstiti ...

  9. The Brain vs Deep Learning Part I: Computational Complexity — Or Why the Singularity Is Nowhere Near

    The Brain vs Deep Learning Part I: Computational Complexity — Or Why the Singularity Is Nowhere Near ...

随机推荐

  1. C/C++ Qt 数据库与SqlTableModel组件应用

    SqlTableModel 组件可以将数据库中的特定字段动态显示在TableView表格组件中,通常设置QSqlTableModel类的变量作为数据模型后就可以显示数据表内容,界面组件中则通过QDat ...

  2. CSS系列,清除浮动方法总结

    在非IE浏览器(如Firefox)下,当容器的高度为auto,且容器的内容中有浮动(float为left或right)的元素.在这种情况下,容器的高度不能自动伸长以适应内容的高度,使得内容溢出到容器外 ...

  3. HDFS初探之旅(二)

    6.HDFS API详解 Hadoop中关于文件操作类疾病上全部在"org.apache.hadoop.fs"包中,这些API能够支持的操作包含:打开文件.读写文件.删除文件等. ...

  4. 【编程思想】【设计模式】【结构模式Structural】front_controller

    Python版 https://github.com/faif/python-patterns/blob/master/structural/front_controller.py #!/usr/bi ...

  5. java异常处理中throws和throw的使用

    异常介绍: 运行时异常.非运行时异常 在编写可能会抛出异常的方法时,它们都必须声明为有异常. 一.throws关键字 1.声明方法可能抛出的异常: 2.写在方法名后面: 3.可声明抛出多个异常,异常名 ...

  6. minikube metrics-server HPA 自动扩缩容错误

    minikube metrics-server pod 错误 启动 minikube addons enable metrics-server 之后查看 metrics-server pod 会有如下 ...

  7. 安霸pipeline简述之rgb域的处理

    RGB域处理模块的详细介绍: RGB域的处理主要是demosaic,color_correction,tone_curve(类似于gamma曲线).   Demosaic:此模块将bayer Patt ...

  8. minkube在deban10上的安装步骤

    环境准备: 所用机器为4c 16g  i3 4170   1t机械硬盘 系统 debian 10 安装docker 如果已经安装并配置好可直接跳过 安装ssl sudo apt-get install ...

  9. $(document).ready()与window.onload的区别,站在三个维度回答问题

    1.执行时机 window.onload必须等到页面内包括图片的所有元素加载完毕后才能执行.         $(document).ready()是DOM结构绘制完毕后就执行,不必等到加载完毕. 2 ...

  10. Nginx模块之stub_status

    目录 一.介绍 二.使用 三.参数 一.介绍 Nginx中的stub_status模块主要用于查看Nginx的一些状态信息. 当前默认在nginx的源码文件中,不需要单独下载 二.使用 本模块默认是不 ...