参考:

https://arxiv.org/pdf/1603.05691.pdf

本文要讨论的是如何构建“集成神经网络”(“Ensemble neural network”),其实“集成神经网络”其本质就是一种“集成学习”的模型,只不过个体学习器不再是传统的机器学习模型而是神经网络模型,不过一般情况下“集成神经网络”的具体形式并不能简单等价于传统的“集成学习”模型。

“集成神经网络”“Ensemble neural network”构建的关键:

The ensemble averages the logits predicted by each model before the softmax layers.

The logits (the scores just prior to the final softmax layer) from each of the CNN in the ensemble model are averaged for each class.

个体神经网络(CNN模型):

CNN层 + 全连接层(输出的是logits) + softmax层 + 交叉熵损失函数

集成神经网络:

训练N个CNN模型,这些CNN模型结构可以是不同的,并且每个CNN模型个体都是使用不同的超参数,不同的训练数据集,这里N一般是一个比较大的数值,如N=200. 需要注意的是集成神经网络往往是模型参数不太大的模型,并且每个模型的各自的训练数据集不是特别大的那种情况,因为一般集成神经网络需要单独训练出比较多的个体,这里假设是200个,如果每个个体模型的训练都很费时的话那么就需要一个极大的训练时长的,也正因此集成神经网络在现在的大数据集和大模型参数的背景下基本是不会采用的。

这里不同的神经网络个体采用不同的训练数据集,可以通过数据增强的方式对某个原有的单一数据集进行扩增,具体构建时每个模型在读入随机抽取的原训练集中数据后进行实时的数据增强,以此实现不同神经网络个体的训练数据的差异性。

假设训练200个模型个体完成后需要选出一个最优的组合,即选出M个个体神经网络,使者M个进行组合后性能最优。在具体选择个体模型的时候,需要使用验证数据集而不是训练数据集。

这个选择过程并没有找到具体的算法描述,因此这里给出一个个人给出的步骤:

对200个模型按照验证数据集的准确度进行从高到低的排序,需要选择出的模型集合为S={},默认将0号模型加入到S集合中(因为0号个体模型的验证数据集准确度最高),从1号模型开始到199号模型,判断这些模型加入到S集合后是否可以提升集成模型的性能,如果能提高则加入,如果不能则跳过,这样我们只需要用验证数据集验证199个集成模型即可。但是该种选择方法其实并不难保证最优的组合,比如该种方法选择出的集合包含1号个体模型,但是我们如果跳过1号模型(将1号模型从集合S中取出),然后从2号模型开始进行选择,那么选出的新的集合 S' 是有可能比S集合性能更优,但是考虑到这样会急剧增加运算时长,因此也不建议采用该种选择方法。

需要注意的是,从200个个体模型中最后选择出的集合往往也是有着较多个体的,比如最终选择出的M值,即S集合的大小为16或19这样的数值,因此最终构建的“集成神经网络”往往在进行预测推理时也是很费时的。

构建“集成神经网络”和传统的“集成学习”之间一个明显的不同是:

  • The ensemble averages the logits predicted by each model before the softmax layers.

  • The logits (the scores just prior to the final softmax layer) from each of the CNN in the ensemble model are averaged for each class.

传统的“集成学习”是将个体模型的预测输出进行融合(比如进行平均),而“集成神经网络”则是将个体模型输出的logits进行平均,而不是预测值p,这个区别是极为重要的。

紧邻softmax层的全连接层输出的是logits,而softmax层输出的是预测值概率P。

个体神经网络(CNN模型):

CNN层 + 全连接层(输出的是logits) + softmax层(输出的是预测值概率P) + 交叉熵损失函数

PS. 虽然集成神经网络基本被淘汰出了历史舞台,但是经典的神经网络理论模型的实证研究都是喜欢用这个“集成神经网络”的,因为虽然“集成神经网络”构建十分费时,但是其往往可以获得比单独个体的神经网络模型更好的performance(一般在一个百分点到1.5个百分点的准确度),这个“集成神经网络”作为知识盲区也是有必要补上的。

补充:

其实“集成网络”是在全连接层(输出的是logits)后进行集成(对logits进行平均),还是在softmax层(输出的是预测值概率P)后进行集成(对p进行平均),其实二者都是有的,至于哪种情况更好这或许是一个实践问题,也就是说要具体的问题具体分析,这个分类问题或许这个集成方式好,而那个数据集或许是另种集成方法好。Maybe it is an empirical question.

如何构建“集成神经网络”“Ensemble neural network”的更多相关文章

  1. 递归神经网络(Recursive Neural Network, RNN)

    信息往往还存在着诸如树结构.图结构等更复杂的结构.这就需要用到递归神经网络 (Recursive Neural Network, RNN),巧合的是递归神经网络的缩写和循环神经网络一样,也是RNN,递 ...

  2. 卷积神经网络(Convolutional Neural Network, CNN)简析

    目录 1 神经网络 2 卷积神经网络 2.1 局部感知 2.2 参数共享 2.3 多卷积核 2.4 Down-pooling 2.5 多层卷积 3 ImageNet-2010网络结构 4 DeepID ...

  3. 深度学习FPGA实现基础知识10(Deep Learning(深度学习)卷积神经网络(Convolutional Neural Network,CNN))

    需求说明:深度学习FPGA实现知识储备 来自:http://blog.csdn.net/stdcoutzyx/article/details/41596663 说明:图文并茂,言简意赅. 自今年七月份 ...

  4. 人工神经网络 Artificial Neural Network

    2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...

  5. [C4] 前馈神经网络(Feedforward Neural Network)

    前馈神经网络(Feedforward Neural Network - BP) 常见的前馈神经网络 感知器网络 感知器(又叫感知机)是最简单的前馈网络,它主要用于模式分类,也可用在基于模式分类的学习控 ...

  6. 详解循环神经网络(Recurrent Neural Network)

    本文结构: 模型 训练算法 基于 RNN 的语言模型例子 代码实现 1. 模型 和全连接网络的区别 更细致到向量级的连接图 为什么循环神经网络可以往前看任意多个输入值 循环神经网络种类繁多,今天只看最 ...

  7. 【原创】深度神经网络(Deep Neural Network, DNN)

    线性模型通过特征间的现行组合来表达“结果-特征集合”之间的对应关系.由于线性模型的表达能力有限,在实践中,只能通过增加“特征计算”的复杂度来优化模型.比如,在广告CTR预估应用中,除了“标题长度.描述 ...

  8. 脉冲神经网络Spiking neural network

    (原文地址:维基百科) 简单介绍: 脉冲神经网络Spiking neuralnetworks (SNNs)是第三代神经网络模型,其模拟神经元更加接近实际,除此之外,把时间信息的影响也考虑当中.思路是这 ...

  9. 吴恩达深度学习第1课第4周-任意层人工神经网络(Artificial Neural Network,即ANN)(向量化)手写推导过程(我觉得已经很详细了)

    学习了吴恩达老师深度学习工程师第一门课,受益匪浅,尤其是吴老师所用的符号系统,准确且易区分. 遵循吴老师的符号系统,我对任意层神经网络模型进行了详细的推导,形成笔记. 有人说推导任意层MLP很容易,我 ...

  10. Bootstrap aggregating Bagging 合奏 Ensemble Neural Network

    zh.wikipedia.org/wiki/Bagging算法 Bagging算法 (英语:Bootstrap aggregating,引导聚集算法),又称装袋算法,是机器学习领域的一种团体学习算法. ...

随机推荐

  1. xtrabackup备份工具

    为什么要学这个工具 背景 一个合格的运维工程师或者dba工程师,如果有从事数据库方面的话,首先需要做的就是备份,如果没有备份,出现问题的话,你的业务就会出问题,你的工作甚至会... 所以备份是重要的, ...

  2. Scrapy框架(三)--全站数据爬取

    scrapy基于Spider类的全站数据爬取 大部分的网站展示的数据都进行了分页操作,那么将所有页码对应的页面数据进行爬取就是爬虫中的全站数据爬取.基于scrapy如何进行全站数据爬取呢?1.将每一个 ...

  3. OAuth + Security - 4 - 客户端信息存储数据库

    PS:此文章为系列文章,建议从第一篇开始阅读. 在之前的所有配置中,我们的客户端信息和授权码模式下的授权码任然还是存储在数据库中的,这样就不利于我们后期的扩展,所以在正式的生成环境中,我们一般将其存储 ...

  4. 硬件开发笔记(二十一):外部搜索不到的元器件封装可尝试使用AD21软件的“ManufacturerPart Search”功能

    前言   这是一个AD的一个强大的新功能,能招到元器件的原理图.3D模型还有价格厂家,但是不一定都有,有了也不一定有其3D模型. ManufacturerPart Search 在设计工具中选择即用型 ...

  5. 记一次反向代理过滤sql注入

    公司有一php系统,由于该系统是购买的,并且没人懂php,无法通过修改代码过滤sql注入问题 代码如下: public class Program { public static void Main( ...

  6. 【动手学深度学习】第五章笔记:层与块、参数管理、自定义层、读写文件、GPU

    为了更好的阅读体验,请点击这里 由于本章内容比较少且以后很显然会经常回来翻,因此会写得比较详细. 5.1 层和块 事实证明,研究讨论"比单个层大"但"比整个模型小&quo ...

  7. hive第一课:Hive3.1.2概述与基本操作

    Hive3.1.2概述与基本操作 1.Hive基本概念 1.1 Hive简介 Hive本质是将SQL转换为MapReduce的任务进行运算,底层由HDFS来提供数据存储,说白了hive可以理解为一个将 ...

  8. socket 端口复用 SO_REUSEPORT 与 SO_REUSEADDR

    背景 在学习 SO_REUSEADDR 地址复用的时候,看到有人提到了 SO_REUSEPORT .于是也了解了一下. SO_REUSEPORT 概述 SO_REUSEPOR这个socket选项可以让 ...

  9. B 站和小红书又又又崩了,罪魁祸首竟然又是他。。。

    大家好,我是凌晨. 今天上午10点左右,我打开B站发现无法刷新视频列表和评论区,收藏夹和弹幕也均不可用. 原以为是手机网络问题,换网络重启手机都还是不行,第一时间打开微博,果然,B站崩了的新闻荣登榜首 ...

  10. python实用总结

    Python3 常用工具 1. 命令行快速搭建本地http服务器 python3 -m http.server 8000 在命令行中输入此命令,就会在当前目录下搭建http服务器,可以通过访问http ...