Accelerating Deep Learning by Focusing on the Biggest Losers

思想很简单, 在训练网络的时候, 每个样本都会产生一个损失\(\mathcal{L}(f(x_i),y_i)\), 训练的模式往往是批训练, 将一个批次\(\sum_i \mathcal{L}(f(x_i),y_i)\)所产生的损失的梯度都传回去, 然后更新参数. 本文认为, 有些样本\((x_i,y_i)\)由于重复度高, 网络很高能够识别, 使得对应的\(\mathcal{L}(f(x_i),y_i)\)相对较小, 所以设计了一种机制, 使得损失较大的样本有大概率被选中, 而不重要的样本不被选中, 以此来降低计算时间. 实验证明, 这种方法能够在保持准确率不变的前提下降低训练时间.

相关工作

作者说这个算法首先是由Are Loss Functions All the Same?提出的, 但是这篇文章只是讲了hinge loss的优势和对其它损失函数的分析.

作者说最相关的文章是Not All Samples Are Created Equal: Deep Learning with Importance Sampling, 这篇文章是从预处理(虽然也是要算loss的)的角度出发的, 理论部分较本文多一些.

主要内容



算法1的思路是很清晰的, 主要困扰在算法2概率的计算上. 假设我们以及计算了\(n\)个样本的损失, 我们将其存储起来, 假设下一个样本的损失是\(\mathcal{L}_c\), 如果这\(n\)个样本中有\(k\)个样本的损失均小于\(\mathcal{L}_c\), 则改样本被选中的概率是:

\[\max \{(k/n)^\beta, s\}
\]

其中\(s\in[0,1]\)是人为设置的, 保证每个样本都有被选中的可能.

我们还可以设置一个最大的长度\(r\), 将以往的损失存储在一个双栈中, 当\(n=r\)的时候,存储下一个损失的同时会抛弃第一个损失, 这么做能在一定程度上减少计算量.

graph LR
A[样本x] --> C(网络f)
C --> D[损失l]
D--更新-->E[损失库]
D-->F[计算概率]
F-->G(形成batch)
G--反向传递-->C
E-->F

从最开始的图中, 第二列就是表示这个算法, 第三列是在此基础上对前向传递进行一些处理. 直接的是, 每隔\(n\)次epoches更新一次损失, 然后中间的n-1次不更新损失, 直接用旧的损失对样本选择(应该是直接在传入网络就将样本选择好否则就不能降低时间了).

在随机算法中, 有单通道选择样本的一个算法, 但是这个算法只用于选择一个. 所以如果选择很多这个算法就没用了, 感觉一次性选择很多个不好弄.

代码

因为条件限制, 代码并没有测试过, 论文也给出了很棒的代码.

"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
""" class Unit: def __init__(self, command, type=str,
default=None):
if default is None:
default = type()
self.command = command
self.type = type
self.default = default class Opi:
"""
>>> parser = Opi()
>>> parser.add_opt(command="lr", type=float)
>>> parser.add_opt(command="epochs", type=int)
"""
def __init__(self):
self.store = []
self.infos = {} def add_opt(self, **kwargs):
self.store.append(
Unit(**kwargs)
) def acquire(self):
s = "Acquire args {0.command} [" \
"type:{0.type.__name__} " \
"default:{0.default}] : "
for unit in self.store:
while True:
inp = input(s.format(
unit
))
try:
if inp: #若有输入
inp = unit.type(inp)
else:
inp = unit.default
self.infos.update(
{unit.command:inp}
)
self.__setattr__(unit.command, inp)
break
except:
print("Type {0} should be given".format(
unit.type.__name__
)) if __name__ == "__main__":
parser = Opi()
parser.add_opt(command = "x", type=int)
parser.add_opt(command="y", type=str)
parser.acquire()
print(parser.infos)
print(parser.x)
'''
calcprob.py
计算概率
''' import collections class Calcprob:
def __init__(self, beta, sample_min, max_len=3000):
assert 0. <= sample_min <= 1., "Invalid sample_min"
assert beta > 0, "Invalid beta"
self.beta = beta
self.sample_min = sample_min
self.max_len = max_len
self.history = collections.deque(maxlen=max_len)
self.num_slot = 1000
self.hist = [0] * self.num_slot
self.count = 0 def update_history(self, losses):
"""
BoundedHistogram
:param losses:
:return:
"""
for loss in losses:
assert loss > 0
if self.count is self.max_len:
loss_old = self.history.popleft()
slot_old = int(loss_old * self.num_slot) % self.num_slot
self.hist[slot_old] -= 1
else:
self.count += 1
self.history.append(loss)
slot = int(loss * self.num_slot) % self.num_slot
self.hist[slot] += 1 def get_probability(self, loss):
assert loss > 0
slot = int(loss * self.num_slot) % self.num_slot
prob = sum(self.hist[:slot]) / self.count
assert isinstance(prob, float), "int division error..."
return prob ** self.beta def calc_probability(self, losses):
if isinstance(losses, float):
losses = (losses, )
self.update_history(losses)
probs = (
max(
self.get_probability(loss),
self.sample_min
)
for loss in losses
)
return probs def __call__(self, losses):
return self.calc_probability(losses) if __name__ == "__main__":
pass
'''
selector.py
''' import calcprob
import numpy as np class Selector: def __init__(self, batch_size,
beta, sample_min, max_len=3000):
self.batch_size = batch_size
self.calcprob = calcprob.Calcprob(beta,
sample_min,
max_len)
self.reset() def backward(self):
loss = sum(self.batch)
loss.backward()
self.reset() def reset(self):
self.batch = []
self.length = 0. def select(self, losses):
probs = self.calcprob(losses)
for i, prob in enumerate(probs):
if np.random.rand() < prob:
self.batch.append(losses[i])
self.length += 1
if self.length >= self.batch_size:
self.backward() def __call__(self, losses):
self.select(losses)
'''
main.py
''' import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os import selector class Train: def __init__(self, model, lossfunc,
bpsize, beta, sample_min, max_len=3000,
lr=0.01, momentum=0.9, weight_decay=0.0001):
self.net = self.choose_net(model)
self.criterion = self.choose_lossfunc(lossfunc)
self.opti = torch.optim.SGD(self.net.parameters(),
lr=lr, momentum=momentum,
weight_decay=weight_decay)
self.selector = selector.Selector(bpsize, beta,
sample_min, max_len)
self.gpu()
self.generate_path()
self.acc_rates = []
self.errors = [] def choose_net(self, model):
net = getattr(
torchvision.models,
model,
None
)
if net is None:
raise ValueError("no such model")
return net() def choose_lossfunc(self, lossfunc):
lossfunc = getattr(
nn,
lossfunc,
None
)
if lossfunc is None:
raise ValueError("no such lossfunc")
return lossfunc def gpu(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Let'us use %d GPUs" % torch.cuda.device_count())
self.net = nn.DataParallel(self.net)
self.net = self.net.to(self.device) def generate_path(self):
"""
生成保存数据的路径
:return:
"""
try:
os.makedirs('./paras')
os.makedirs('./logs')
os.makedirs('./infos')
except FileExistsError as e:
pass
name = self.net.__class__.__name__
paras = os.listdir('./paras')
logs = os.listdir('./logs')
infos = os.listdir('./infos')
number = max((len(paras), len(logs), len(infos)))
self.para_path = "./paras/{0}{1}.pt".format(
name,
number
) self.log_path = "./logs/{0}{1}.txt".format(
name,
number
)
self.info_path = "./infos/{0}{1}.npy".format(
name,
number
) def log(self, strings):
"""
运行日志
:param strings:
:return:
"""
# a 往后添加内容
with open(self.log_path, 'a', encoding='utf8') as f:
f.write(strings) def save(self):
"""
保存网络参数
:return:
"""
torch.save(self.net.state_dict(), self.para_path) def derease_lr(self, multi=0.96):
"""
降低学习率
:param multi:
:return:
"""
self.opti.param_groups[0]['lr'] *= multi def train(self, trainloder, epochs=50):
data_size = len(trainloder) * trainloder.batch_size
part = int(trainloder.batch_size / 2)
for epoch in range(epochs):
running_loss = 0.
total_loss = 0.
acc_count = 0.
if (epoch + 1) % 8 is 0:
self.derease_lr()
self.log(#日志记录
"learning rate change!!!\n"
)
for i, data in enumerate(trainloder):
imgs, labels = data
imgs = imgs.to(self.device)
labels = labels.to(self.device)
out = self.net(imgs)
_, pre = torch.max(out, 1) #判断是否判断正确
acc_count += (pre == labels).sum().item() #加总对的个数 losses = (
self.criterion(out[i], labels[i])
for i in range(len(labels))
) self.opti.zero_grad()
self.selector(losses) #选择
self.opti.step() running_loss += sum(losses).item() if (i+1) % part is 0:
strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
epoch, i, running_loss / part
)
self.log(strings)#日志记录
total_loss += running_loss
running_loss = 0.
self.acc_rates.append(acc_count / data_size)
self.errors.append(total_loss / data_size)
self.log( #日志记录
"Accuracy of the network on %d train images: %d %%\n" %(
data_size, acc_count / data_size * 100
)
)
self.save() #保存网络参数
#保存一些信息画图用
np.save(self.info_path, {
'acc_rates': np.array(self.acc_rates),
'errors': np.array(self.errors)
}) if __name__ == "__main__": import OptInput
args = OptInput.Opi()
args.add_opt(command="model", default="resnet34")
args.add_opt(command="lossfunc", default="CrossEntropyLoss")
args.add_opt(command="bpsize", default=32)
args.add_opt(command="beta", default=0.9)
args.add_opt(command="sample_min", default=0.3)
args.add_opt(command="max_len", default=3000)
args.add_opt(command="lr", default=0.001)
args.add_opt(command="momentum", default=0.9)
args.add_opt(command="weight_decay", default=0.0001) args.acquire() root = "C:/Users/pkavs/1jupiterdata/data" trainset = torchvision.datasets.CIFAR10(root=root, train=True,
download=False,
transform=transforms.Compose(
[transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)) train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=8,
pin_memory=True) dog = Train(**args.infos)
dog.train(train_loader, epochs=1000)

Accelerating Deep Learning by Focusing on the Biggest Losers的更多相关文章

  1. Deep Learning 27:Batch normalization理解——读论文“Batch normalization: Accelerating deep network training by reducing internal covariate shift ”——ICML 2015

    这篇经典论文,甚至可以说是2015年最牛的一篇论文,早就有很多人解读,不需要自己着摸,但是看了论文原文Batch normalization: Accelerating deep network tr ...

  2. Applied Deep Learning Resources

    Applied Deep Learning Resources A collection of research articles, blog posts, slides and code snipp ...

  3. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  4. 【深度学习Deep Learning】资料大全

    最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books  by Yoshua Bengio, Ian Goodfellow and Aaron C ...

  5. (转) Awesome - Most Cited Deep Learning Papers

    转自:https://github.com/terryum/awesome-deep-learning-papers Awesome - Most Cited Deep Learning Papers ...

  6. deep learning 的综述

    从13年11月初开始接触DL,奈何boss忙or 各种问题,对DL理解没有CSDN大神 比如 zouxy09等 深刻,主要是自己觉得没啥进展,感觉荒废时日(丢脸啊,这么久....)开始开文,即为记录自 ...

  7. (转)分布式深度学习系统构建 简介 Distributed Deep Learning

    HOME ABOUT CONTACT SUBSCRIBE VIA RSS   DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...

  8. Machine and Deep Learning with Python

    Machine and Deep Learning with Python Education Tutorials and courses Supervised learning superstiti ...

  9. The Brain vs Deep Learning Part I: Computational Complexity — Or Why the Singularity Is Nowhere Near

    The Brain vs Deep Learning Part I: Computational Complexity — Or Why the Singularity Is Nowhere Near ...

随机推荐

  1. A Child's History of England.6

    It was a British Prince named Vortigern who took this resolution, and who made a treaty of friendshi ...

  2. SQLite is 35% Faster Than The Filesystem

    比方说你要在C++/PHP里实现一个函数Image get_image(string id),不同的图片有1万张(用户头像),你可以把它们存在一个目录/文件夹里,然后fopen()再fread. 你也 ...

  3. day08 索引的创建与慢查询优化

    day08 索引的创建与慢查询优化 昨日内容回顾 视图 视图:将SQL语句查询结果实体化保存起来,方便下次查询使用. 视图里面的数据来源于原表,视图只有表结构 # 创建视图 create view 视 ...

  4. mybatis缓存+aop出现的问题

    在对某些特殊数据进行转换时,getOne方法后执行fieleInfoHandle进行转换,如果直接使用fixedTableData进行操作,没有后续的二次调用这样是没问题的,但是在后面当执行完upda ...

  5. Linux:ps -ef命令

    ps命令将某个进程显示出来 grep命令是查找 中间的|是管道命令 是指ps命令与grep同时执行 PS是LINUX下最常用的也是非常强大的进程查看命令 检查java 进程是否存在:ps -ef |g ...

  6. 【Java 8】Stream中的Pipeline理解

    基于下面一段代码: public static void main(String[] args) { List<String> list = Arrays.asList("123 ...

  7. 【Java 基础】Java动态代理

    Java动态代理InvocationHandler和Proxy java动态代理机制中有两个重要的类和接口InvocationHandler(接口)和Proxy(类),这一个类Proxy和接口Invo ...

  8. 什么是maven(一)

    转自博主--一杯凉茶 我记得在搞懂maven之前看了几次重复的maven的教学视频.不知道是自己悟性太低还是怎么滴,就是搞不清楚,现在弄清楚了,基本上入门了.写该篇博文,就是为了帮助那些和我一样对于m ...

  9. RabbitMQ,RocketMQ,Kafka 几种消息队列的对比

    常用的几款消息队列的对比 前言 RabbitMQ 优点 缺点 RocketMQ 优点 缺点 Kafka 优点 缺点 如何选择合适的消息队列 参考 常用的几款消息队列的对比 前言 消息队列的作用: 1. ...

  10. 【HarmonyOS】【xml】初学XML布局作业

    首先要明确,有两种布局方式 线性布局:DirectionalLayout 依赖布局:DependentLayout 好,接下来看一看下面的例子 页面案例1 代码如下: <?xml version ...