Accelerating Deep Learning by Focusing on the Biggest Losers
Overview
The idea is simple. During training, each sample produces a loss \(\mathcal{L}(f(x_i),y_i)\), and training usually proceeds in mini-batches: the gradient of the batch loss \(\sum_i \mathcal{L}(f(x_i),y_i)\) is backpropagated and the parameters are updated. This paper observes that some samples \((x_i,y_i)\) are so redundant that the network recognizes them easily, leaving the corresponding \(\mathcal{L}(f(x_i),y_i)\) relatively small. The authors therefore design a mechanism under which high-loss samples are selected with high probability while unimportant samples are skipped, reducing computation time. Experiments show that this method shortens training time while keeping accuracy essentially unchanged.

Related Work
The authors trace the idea back to Are Loss Functions All the Same?, but that paper mainly discusses the advantages of the hinge loss and analyzes other loss functions.
According to the authors, the most closely related work is Not All Samples Are Created Equal: Deep Learning with Importance Sampling, which approaches the problem from a preprocessing angle (although it, too, has to compute losses) and contains more theory than this paper.
Main Content


Algorithm 1 is straightforward; the main subtlety lies in how Algorithm 2 computes the selection probability. Suppose we have already computed and stored the losses of \(n\) samples, and the next sample's loss is \(\mathcal{L}_c\). If \(k\) of the \(n\) stored losses are smaller than \(\mathcal{L}_c\), then this sample is selected with probability:
\[
P = \max\Big\{\Big(\frac{k}{n}\Big)^{\beta},\; s\Big\},
\]
where \(\beta>0\) controls how aggressively low-loss samples are suppressed, and \(s\in[0,1]\) is set by hand to guarantee that every sample retains some chance of being selected.
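As a toy illustration of this selection rule (the function name and the numbers below are made up for the example), the probability can be computed directly from a list of stored losses:

```python
def select_prob(history, loss_new, beta=2.0, s=0.1):
    """P = max((k/n)**beta, s), where k = #{stored losses < loss_new}."""
    n = len(history)
    k = sum(1 for past in history if past < loss_new)
    return max((k / n) ** beta, s)

history = [0.2, 0.4, 0.6, 0.8]
print(select_prob(history, 0.9))  # larger than all stored losses -> probability 1.0
print(select_prob(history, 0.1))  # smaller than all stored losses -> floored at s = 0.1
```

Raising \(\beta\) makes the rule more selective: a sample at the median (\(k/n=0.5\)) is kept with probability 0.25 when \(\beta=2\) but only 0.125 when \(\beta=3\).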
We can also cap the history at a maximum length \(r\) by keeping past losses in a double-ended queue (deque): once \(n=r\), storing the next loss evicts the oldest one, which keeps the bookkeeping cost bounded.
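In Python, `collections.deque` with `maxlen` gives exactly this behaviour: once the queue holds \(r\) items, each new append silently evicts the oldest entry.

```python
from collections import deque

history = deque(maxlen=3)  # r = 3
for loss in [0.5, 0.9, 0.3, 0.7]:
    history.append(loss)   # the fourth append evicts 0.5
print(list(history))       # [0.9, 0.3, 0.7]
```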
A[sample x] --> C(network f)
C --> D[loss l]
D--update-->E[loss history]
D-->F[compute probability]
F-->G(form batch)
G--backward pass-->C
E-->F
In the figure at the very beginning, the second column depicts this algorithm, and the third column adds some extra handling of the forward pass. Concretely, the losses are refreshed once every \(n\) epochs; during the intervening \(n-1\) epochs the losses are not updated, and samples are selected using the stale losses (presumably the selection happens before the samples are fed into the network, otherwise no time would be saved).
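A minimal sketch of this stale-selection idea, assuming cached per-sample losses are refreshed every `REFRESH` epochs and reused in between (all names here are illustrative, not taken from the paper's code):

```python
import random

REFRESH = 3  # refresh cached losses every 3 epochs (illustrative choice)

def stale_select(epoch, indices, cached_loss, prob_fn, compute_loss):
    """Pick sample indices using possibly stale cached losses."""
    if epoch % REFRESH == 0:          # refresh epoch: recompute every loss
        for i in indices:
            cached_loss[i] = compute_loss(i)
    # in-between epochs reuse cached_loss as-is, so selection happens
    # before any forward pass and actually saves computation
    return [i for i in indices
            if random.random() < prob_fn(cached_loss[i])]
```

One caveat: between refreshes the probabilities come from old losses, so a sample the network has since learned may keep being picked until the next refresh.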
Among randomized algorithms there is a single-pass method for selecting a sample, but it only selects one item, so it is of no direct use when many samples must be chosen; selecting many at once in a single pass seems harder to arrange.
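The single-pass method alluded to here is reservoir sampling; for reference, the standard sketch (Algorithm R) in fact keeps \(k\) items uniformly from one pass over a stream, though, as noted, it draws uniformly rather than according to the loss-based probabilities above:

```python
import random

def reservoir_sample(stream, k):
    """Keep k items uniformly at random in a single pass over `stream`."""
    reservoir = []
    for t, item in enumerate(stream):
        if t < k:
            reservoir.append(item)       # fill the reservoir first
        else:
            j = random.randrange(t + 1)  # item survives with probability k/(t+1)
            if j < k:
                reservoir[j] = item
    return reservoir
```

Weighted variants (e.g. the Efraimidis-Spirakis scheme, which keys each item by \(u^{1/w}\) for uniform \(u\) and weight \(w\) and keeps the \(k\) largest keys) would let loss-based probabilities play the role of weights.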
Code
Due to hardware constraints, this code has not been tested; the paper's authors also released very good code of their own.
"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
"""
class Unit:
    def __init__(self, command, type=str,
                    default=None):
        if default is None:
            default = type()
        self.command = command
        self.type = type
        self.default = default
class Opi:
    """
    >>> parser = Opi()
    >>> parser.add_opt(command="lr", type=float)
    >>> parser.add_opt(command="epochs", type=int)
    """
    def __init__(self):
        self.store = []
        self.infos = {}
    def add_opt(self, **kwargs):
        self.store.append(
            Unit(**kwargs)
        )
    def acquire(self):
        s = "Acquire args {0.command} [" \
            "type:{0.type.__name__} " \
            "default:{0.default}] : "
        for unit in self.store:
            while True:
                inp = input(s.format(
                    unit
                ))
                try:
                    if inp: #若有输入
                        inp = unit.type(inp)
                    else:
                        inp = unit.default
                    self.infos.update(
                        {unit.command:inp}
                    )
                    self.__setattr__(unit.command, inp)
                    break
                except (TypeError, ValueError):
                    print("Type {0} should be given".format(
                        unit.type.__name__
                    ))
if __name__ == "__main__":
    parser = Opi()
    parser.add_opt(command = "x", type=int)
    parser.add_opt(command="y", type=str)
    parser.acquire()
    print(parser.infos)
    print(parser.x)
'''
calcprob.py
Computes the selection probabilities.
'''
import collections
class Calcprob:
    def __init__(self, beta, sample_min, max_len=3000):
        assert 0. <= sample_min <= 1., "Invalid sample_min"
        assert beta > 0, "Invalid beta"
        self.beta = beta
        self.sample_min = sample_min
        self.max_len = max_len
        self.history = collections.deque(maxlen=max_len)
        self.num_slot = 1000
        self.hist = [0] * self.num_slot
        self.count = 0
    def update_history(self, losses):
        """
        BoundedHistogram: a histogram over the most recent max_len losses.
        :param losses: iterable of positive per-sample losses
        :return:
        """
        for loss in losses:
            assert loss > 0
            if self.count == self.max_len:  # at capacity: evict the oldest loss
                loss_old = self.history.popleft()
                self.hist[self._slot(loss_old)] -= 1
            else:
                self.count += 1
            self.history.append(loss)  # the new loss is always recorded
            self.hist[self._slot(loss)] += 1
    def _slot(self, loss):
        # clamp rather than wrap, so losses >= 1 land in the top slot
        return min(int(loss * self.num_slot), self.num_slot - 1)
    def get_probability(self, loss):
        assert loss > 0
        slot = self._slot(loss)
        # fraction of stored losses that fall below this one
        prob = sum(self.hist[:slot]) / self.count
        return prob ** self.beta
    def calc_probability(self, losses):
        if isinstance(losses, float):
            losses = (losses,)
        self.update_history(losses)
        probs = [
            max(
                self.get_probability(loss),
                self.sample_min
            )
            for loss in losses
        ]
        return probs  # a list, so callers may iterate it more than once
    def __call__(self, losses):
        return self.calc_probability(losses)
if __name__ == "__main__":
    pass
'''
selector.py
'''
import calcprob
import numpy as np
class Selector:
    def __init__(self, batch_size,
                 beta, sample_min, max_len=3000):
        self.batch_size = batch_size
        self.calcprob = calcprob.Calcprob(beta,
                                          sample_min,
                                          max_len)
        self.reset()
    def backward(self):
        loss = sum(self.batch)
        loss.backward()
        self.reset()
    def reset(self):
        self.batch = []
        self.length = 0
    def select(self, losses):
        losses = list(losses)  # materialize so the losses can be indexed below
        probs = self.calcprob(losses)
        for i, prob in enumerate(probs):
            if np.random.rand() < prob:
                self.batch.append(losses[i])
                self.length += 1
                if self.length >= self.batch_size:
                    self.backward()
    def __call__(self, losses):
        self.select(losses)
'''
main.py
'''
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import selector
class Train:
    def __init__(self, model, lossfunc,
                 bpsize, beta, sample_min, max_len=3000,
                 lr=0.01, momentum=0.9, weight_decay=0.0001):
        self.net = self.choose_net(model)
        self.criterion = self.choose_lossfunc(lossfunc)
        self.opti = torch.optim.SGD(self.net.parameters(),
                                    lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
        self.selector = selector.Selector(bpsize, beta,
                                          sample_min, max_len)
        self.gpu()
        self.generate_path()
        self.acc_rates = []
        self.errors = []
    def choose_net(self, model):
        net = getattr(
            torchvision.models,
            model,
            None
        )
        if net is None:
            raise ValueError("no such model")
        return net()
    def choose_lossfunc(self, lossfunc):
        lossfunc = getattr(
            nn,
            lossfunc,
            None
        )
        if lossfunc is None:
            raise ValueError("no such lossfunc")
        return lossfunc()  # instantiate the loss module
    def gpu(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() > 1:
            print("Let's use %d GPUs" % torch.cuda.device_count())
            self.net = nn.DataParallel(self.net)
        self.net = self.net.to(self.device)
    def generate_path(self):
        """
        Generate the paths used for saving data.
        :return:
        """
        try:
            os.makedirs('./paras')
            os.makedirs('./logs')
            os.makedirs('./infos')
        except FileExistsError as e:
            pass
        name = self.net.__class__.__name__
        paras = os.listdir('./paras')
        logs = os.listdir('./logs')
        infos = os.listdir('./infos')
        number = max((len(paras), len(logs), len(infos)))
        self.para_path = "./paras/{0}{1}.pt".format(
            name,
            number
        )
        self.log_path = "./logs/{0}{1}.txt".format(
            name,
            number
        )
        self.info_path = "./infos/{0}{1}.npy".format(
            name,
            number
        )
    def log(self, strings):
        """
        Append to the run log.
        :param strings:
        :return:
        """
        # 'a': append to the end of the file
        with open(self.log_path, 'a', encoding='utf8') as f:
            f.write(strings)
    def save(self):
        """
        Save the network parameters.
        :return:
        """
        torch.save(self.net.state_dict(), self.para_path)
    def decrease_lr(self, multi=0.96):
        """
        Decay the learning rate.
        :param multi:
        :return:
        """
        self.opti.param_groups[0]['lr'] *= multi
    def train(self, trainloader, epochs=50):
        data_size = len(trainloader) * trainloader.batch_size
        part = int(trainloader.batch_size / 2)
        for epoch in range(epochs):
            running_loss = 0.
            total_loss = 0.
            acc_count = 0.
            if (epoch + 1) % 8 == 0:
                self.decrease_lr()
                self.log(  # logging
                    "learning rate change!!!\n"
                )
            for i, data in enumerate(trainloader):
                imgs, labels = data
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                out = self.net(imgs)
                _, pre = torch.max(out, 1)  # predicted classes
                acc_count += (pre == labels).sum().item()  # running count of correct predictions
                losses = [  # a list, not a generator: it is consumed twice below
                    self.criterion(out[j].unsqueeze(0), labels[j].unsqueeze(0))
                    for j in range(len(labels))
                ]
                self.opti.zero_grad()
                self.selector(losses)  # sample selection (may trigger a backward pass)
                self.opti.step()
                running_loss += sum(losses).item()
                if (i + 1) % part == 0:
                    strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                        epoch, i, running_loss / part
                    )
                    self.log(strings)  # logging
                    total_loss += running_loss
                    running_loss = 0.
            self.acc_rates.append(acc_count / data_size)
            self.errors.append(total_loss / data_size)
            self.log(  # logging
                "Accuracy of the network on %d train images: %d %%\n" % (
                    data_size, acc_count / data_size * 100
                )
            )
            self.save()  # save the network parameters
        # save some information for plotting later
        np.save(self.info_path, {
            'acc_rates': np.array(self.acc_rates),
            'errors': np.array(self.errors)
        })
if __name__ == "__main__":
    import OptInput
    args = OptInput.Opi()
    args.add_opt(command="model", default="resnet34")
    args.add_opt(command="lossfunc", default="CrossEntropyLoss")
    args.add_opt(command="bpsize", default=32)
    args.add_opt(command="beta", default=0.9)
    args.add_opt(command="sample_min", default=0.3)
    args.add_opt(command="max_len", default=3000)
    args.add_opt(command="lr", default=0.001)
    args.add_opt(command="momentum", default=0.9)
    args.add_opt(command="weight_decay", default=0.0001)
    args.acquire()
    root = "C:/Users/pkavs/1jupiterdata/data"
    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                          download=False,
                                          transform=transforms.Compose(
                                              [transforms.Resize(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
                                          ))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=8,
                                               pin_memory=True)
    dog = Train(**args.infos)
    dog.train(train_loader, epochs=1000)