参考:

https://arxiv.org/pdf/1603.05691.pdf

本文要讨论的是如何构建“集成神经网络”(“Ensemble neural network”),其实“集成神经网络”其本质就是一种“集成学习”的模型,只不过个体学习器不再是传统的机器学习模型而是神经网络模型,不过一般情况下“集成神经网络”的具体形式并不能简单等价于传统的“集成学习”模型。

“集成神经网络”“Ensemble neural network”构建的关键:

The ensemble averages the logits predicted by each model before the softmax layers.

The logits (the scores just prior to the final softmax layer) from each of the CNN in the ensemble model are averaged for each class.

个体神经网络(CNN模型):

CNN层 + 全连接层(输出的是logits) + softmax层 + 交叉熵损失函数

集成神经网络:

训练N个CNN模型,这些CNN模型结构可以是不同的,并且每个CNN模型个体都是使用不同的超参数,不同的训练数据集,这里N一般是一个比较大的数值,如N=200. 需要注意的是集成神经网络往往是模型参数不太大的模型,并且每个模型的各自的训练数据集不是特别大的那种情况,因为一般集成神经网络需要单独训练出比较多的个体,这里假设是200个,如果每个个体模型的训练都很费时的话那么就需要一个极大的训练时长的,也正因此集成神经网络在现在的大数据集和大模型参数的背景下基本是不会采用的。

这里不同的神经网络个体采用不同的训练数据集,可以通过数据增强的方式对某个原有的单一数据集进行扩增,具体构建时每个模型在读入随机抽取的原训练集中数据后进行实时的数据增强,以此实现不同神经网络个体的训练数据的差异性。

假设训练200个模型个体完成后需要选出一个最优的组合,即选出M个个体神经网络,使者M个进行组合后性能最优。在具体选择个体模型的时候,需要使用验证数据集而不是训练数据集。

这个选择过程并没有找到具体的算法描述,因此这里给出一个个人给出的步骤:

对200个模型按照验证数据集的准确度进行从高到低的排序,需要选择出的模型集合为S={},默认将0号模型加入到S集合中(因为0号个体模型的验证数据集准确度最高),从1号模型开始到199号模型,判断这些模型加入到S集合后是否可以提升集成模型的性能,如果能提高则加入,如果不能则跳过,这样我们只需要用验证数据集验证199个集成模型即可。但是该种选择方法其实并不难保证最优的组合,比如该种方法选择出的集合包含1号个体模型,但是我们如果跳过1号模型(将1号模型从集合S中取出),然后从2号模型开始进行选择,那么选出的新的集合 S' 是有可能比S集合性能更优,但是考虑到这样会急剧增加运算时长,因此也不建议采用该种选择方法。

需要注意的是,从200个个体模型中最后选择出的集合往往也是有着较多个体的,比如最终选择出的M值,即S集合的大小为16或19这样的数值,因此最终构建的“集成神经网络”往往在进行预测推理时也是很费时的。

构建“集成神经网络”和传统的“集成学习”之间一个明显的不同是:

  • The ensemble averages the logits predicted by each model before the softmax layers.

  • The logits (the scores just prior to the final softmax layer) from each of the CNN in the ensemble model are averaged for each class.

传统的“集成学习”是将个体模型的预测输出进行融合(比如进行平均),而“集成神经网络”则是将个体模型输出的logits进行平均,而不是预测值p,这个区别是极为重要的。

紧邻softmax层的全连接层输出的是logits,而softmax层输出的是预测值概率P。

个体神经网络(CNN模型):

CNN层 + 全连接层(输出的是logits) + softmax层(输出的是预测值概率P) + 交叉熵损失函数

PS. 虽然集成神经网络基本被淘汰出了历史舞台,但是经典的神经网络理论模型的实证研究都是喜欢用这个“集成神经网络”的,因为虽然“集成神经网络”构建十分费时,但是其往往可以获得比单独个体的神经网络模型更好的performance(一般在一个百分点到1.5个百分点的准确度),这个“集成神经网络”作为知识盲区也是有必要补上的。

补充:

其实“集成网络”是在全连接层(输出的是logits)后进行集成(对logits进行平均),还是在softmax层(输出的是预测值概率P)后进行集成(对p进行平均),其实二者都是有的,至于哪种情况更好这或许是一个实践问题,也就是说要具体的问题具体分析,这个分类问题或许这个集成方式好,而那个数据集或许是另种集成方法好。Maybe it is an empirical question.

如何构建“集成神经网络”“Ensemble neural network”的更多相关文章

  1. 递归神经网络(Recursive Neural Network, RNN)

    信息往往还存在着诸如树结构.图结构等更复杂的结构.这就需要用到递归神经网络 (Recursive Neural Network, RNN),巧合的是递归神经网络的缩写和循环神经网络一样,也是RNN,递 ...

  2. 卷积神经网络(Convolutional Neural Network, CNN)简析

    目录 1 神经网络 2 卷积神经网络 2.1 局部感知 2.2 参数共享 2.3 多卷积核 2.4 Down-pooling 2.5 多层卷积 3 ImageNet-2010网络结构 4 DeepID ...

  3. 深度学习FPGA实现基础知识10(Deep Learning(深度学习)卷积神经网络(Convolutional Neural Network,CNN))

    需求说明:深度学习FPGA实现知识储备 来自:http://blog.csdn.net/stdcoutzyx/article/details/41596663 说明:图文并茂,言简意赅. 自今年七月份 ...

  4. 人工神经网络 Artificial Neural Network

    2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...

  5. [C4] 前馈神经网络(Feedforward Neural Network)

    前馈神经网络(Feedforward Neural Network - BP) 常见的前馈神经网络 感知器网络 感知器(又叫感知机)是最简单的前馈网络,它主要用于模式分类,也可用在基于模式分类的学习控 ...

  6. 详解循环神经网络(Recurrent Neural Network)

    本文结构: 模型 训练算法 基于 RNN 的语言模型例子 代码实现 1. 模型 和全连接网络的区别 更细致到向量级的连接图 为什么循环神经网络可以往前看任意多个输入值 循环神经网络种类繁多,今天只看最 ...

  7. 【原创】深度神经网络(Deep Neural Network, DNN)

    线性模型通过特征间的现行组合来表达“结果-特征集合”之间的对应关系.由于线性模型的表达能力有限,在实践中,只能通过增加“特征计算”的复杂度来优化模型.比如,在广告CTR预估应用中,除了“标题长度.描述 ...

  8. 脉冲神经网络Spiking neural network

    (原文地址:维基百科) 简单介绍: 脉冲神经网络Spiking neuralnetworks (SNNs)是第三代神经网络模型,其模拟神经元更加接近实际,除此之外,把时间信息的影响也考虑当中.思路是这 ...

  9. 吴恩达深度学习第1课第4周-任意层人工神经网络(Artificial Neural Network,即ANN)(向量化)手写推导过程(我觉得已经很详细了)

    学习了吴恩达老师深度学习工程师第一门课,受益匪浅,尤其是吴老师所用的符号系统,准确且易区分. 遵循吴老师的符号系统,我对任意层神经网络模型进行了详细的推导,形成笔记. 有人说推导任意层MLP很容易,我 ...

  10. Bootstrap aggregating Bagging 合奏 Ensemble Neural Network

    zh.wikipedia.org/wiki/Bagging算法 Bagging算法 (英语:Bootstrap aggregating,引导聚集算法),又称装袋算法,是机器学习领域的一种团体学习算法. ...

随机推荐

  1. 使用Jsoup和htmlunit爬取动态网页

    在对http://zkgg.tjtalents.com.cn/newzxxx.jsp这个网页爬取内容时,如果只使用Jsoup进行解析的话,起内部的a href标签内容无法获取到. 但是实际上通过 Do ...

  2. 算法金 | AI 基石,无处不在的朴素贝叶斯算法

    大侠幸会,在下全网同名「算法金」 0 基础转 AI 上岸,多个算法赛 Top 「日更万日,让更多人享受智能乐趣」 历史上,许多杰出人才在他们有生之年默默无闻, 却在逝世后被人们广泛追忆和崇拜. 18世 ...

  3. Selenium模块的使用(一)

    简介 selenium最初是一个自动化测试工具,而爬虫中使用它主要是为了解决requests无法直接执行JavaScript代码的问题 selenium本质是通过驱动浏览器, 完全模拟浏览器的操作,比 ...

  4. java多线程编程:你真的了解线程中断吗?

    java.lang.Thread类有一个 interrupt 方法,该方法直接对线程调用.当被interrupt的线程正在sleep或wait时,会抛出 InterruptedException 异常 ...

  5. idea编译报错 Lombok运行测试类报错 jar依赖冲突解决

    idea编译报错 Lombok运行测试类报错 jar依赖冲突解决 1.现象是idea编译,运行项目的时候是没有问题,可以正常跑起来.2.运行junit测试类的时候,报错提示 lombok找不到类,解决 ...

  6. Linux实时查看Java接口数据

    1.Linux实时查看Java接口数据的方法 在Linux系统中实时查看Java接口数据通常涉及几个步骤: (1)编写Java应用程序:首先,我们需要有一个Java应用程序,它暴露了一个或多个HTTP ...

  7. JSP四个作用域和九个对象

    一.四个作用域 (1)Requset 请求作用域,就是客户端的一次请求 (2)Session 会话作用域,当用户首次访问时,产生一个新的会话,以后服务器就可以记住这个会话状态.生命周期:会话超时,或者 ...

  8. Spark Structured Streaming(二)实战

    5. 实战Structured Streaming 5.1. Static版本 先读一份static 数据: val static = spark.read.json("s3://xxx/d ...

  9. Linux 内核:设备树中的特殊节点

    Linux 内核:设备树中的特殊节点 背景 在解析设备树dtb格式的时候,发现了这个,学习一下. 参考: https://blog.csdn.net/weixin_45309916/article/d ...

  10. Linux内核中的static-key机制

    # Linux内核中的static-key机制 背景 在移植某个TP时,发现频繁操作屏幕会导致i2c总线死掉.在跟踪代码的时候,我发现了这个static-key. 因此,学习一下这块的知识. refe ...