深度变分信息瓶颈——Deep Variational Information Bottleneck
Deep Variational Information Bottleneck (VIB) 变分信息瓶颈 论文阅读笔记。本文利用变分推断将信息瓶颈框架适应到深度学习模型中,可视为一种正则化方法。
变分信息瓶颈
假设数据输入输出对为$(X,Y)$,假设判别模型$f_\theta(\cdot)$有关于$X$的中间表示$Z$,本文旨在优化$\theta$以最小化互信息$I(Z;X)$ ,同时最大化互信息$I(Z;Y)$,即:
$\max\limits_{\theta}I(Z;Y|\theta)-\beta I(Z;X|\theta)$
其中$\beta>0$为平衡系数。那么如何构造这个模型以及相应的优化方案?下面推导上式的下界,使其下界变大,上式即可变大。为了简化,下面去掉$\theta$进行推导。
下界1
首先,$I(Z;Y)$展开为:
$\displaystyle I(Z; Y) = \int \int p(y, z) \log \frac{p(y|z)}{p(y)} \, dy \, dz$
其中$p(y)$是数据的标签分布,已知。未知而需要进行处理的是其中的$p(y,z)$和$p(y|z)$,也就是模型需要拟合的分布。对于$p(y|z)$,可以用一个解码器$q(y|z)$来拟合,即文中所谓的变分估计。利用KL散度的大于零性质,有以下不等式:
\begin{align*} &\text{KL}(p(Y|Z),q(Y|Z))\geq 0\\ \implies &\int \, p(y|z) \log \frac{p(y|z)}{q(y|z)} dy\geq 0\\ \implies &\int \, \frac{p(y,z)}{p(z)} \log \frac{p(y|z)}{q(y|z)} dy\geq 0\\ \implies &\int \, p(y,z) \log p(y|z) dy\geq \int \, p(y,z) \log q(y|z)dy\\ \end{align*}
则有
\begin{align*} \displaystyle I(Z; Y) &= \int \int p(y, z) \log p(y|z) - p(y, z) \log p(y) \, dy \, dz\\ &\geq \int \int p(y, z) \log q(y|z) - p(y, z) \log p(y) \, dy \, dz\\ &= \int \int \, p(y, z) \log q(y|z) dy \, dz - \int \, p(y) \log p(y) dy \\ &= \int \int \, p(y, z) \log q(y|z)dy \, dz + H(Y) \end{align*}
对于其中的$p(y,z)$,本文基于马尔科夫假设:$Y\leftrightarrow X\leftrightarrow Z$。这个假设表明,$Y$和$Z$在$X$的条件下独立(那在优化时呢?$Z$是关于$X$和$Y$的联合分布进行更新的)。有:
$\displaystyle p(y,z)=\int p(x,y,z)dx=\int p(x,y)p(z|x,y)dx= \int p(x,y)p(z|x)dx$
此外,由于$H(Y)$已知且固定,可忽略。则有
$\displaystyle I(Z; Y) \geq \int \int \int \, p(x,y)p(z|x) \log q(y|z)dx \, dy \, dz$
其中,$p(x,y)$是真实数据分布,$p(z|x)$是原始模型关于$x$对中间表示$z$的推理分布。
上界2
$I(Z;X)$展开为:
$\displaystyle I(Z;X)=\int \int p(x,z)\log \frac{p(z|x)}{p(z)}dx\,dz$
对于其中的$p(z)$,作者用另一个变分估计$r(z)$来拟合。由于有
\begin{align*} &\text{KL}(p(Z),r(Z))\geq 0\\ \implies&\int p(z)\log p(z)dz\geq\int p(z)\log r(z)dz\\ \implies&\int\int p(x,z)\log p(z)dx\,dz\geq\int \int p(x,z)\log r(z)dx\,dz \end{align*}
则有
\begin{align*} I(Z;X)\leq \int\int p(x)p(z|x)\log \frac{p(z|x)}{r(z)}dx\,dz \end{align*}
总体下界和优化
结合下界1和上界2,有:
\begin{align*} &I(Z; Y) - \beta I(Z; X) \\ \geq&\int \int \int \, p(x,y)p(z|x) \log q(y|z)dx \, dy \, dz - \beta \int\int p(x)p(z|x)\log \frac{p(z|x)}{r(z)}dx\,dz = L \end{align*}
针对上式,用经验分布来代替真实分布。即用$\frac{1}{N}\sum_{n=1}^N\delta_{x_n}(x)$代替$p(x)$,用$\frac{1}{N}\sum_{n=1}^N\delta_{y_n}(y)$代替$p(y)$,用$\frac{1}{N}\sum_{n=1}^N\delta_{(x_n,y_n)}(x,y)$代替$p(x,y)$。其中$\delta_{x_n}(x)$表示狄拉克函数,其空间内积分为1,且仅在$x_n$上非零。假设经验分布有$N$各样本$\{(x_n,y_n)\}_{n=1}^N$。文中额外引入所谓狄拉克函数让人看不懂,实际上直接把概率积分改成离散样本的求和取平均即可。则上式可被估计为:
$\displaystyle L\approx \frac{1}{N}\sum\limits_{n=1}^N\int p(z|x_n)\log q(y_n|z)-\beta\, p(z|x_n)\log\frac{p(z|x_n)}{r(z)}\, dz$
文中将$z$视为隐变量,利用VAE的重参数技巧将$p(z|x_n)$实现为一个关于$x_n$的正态分布$\mathcal{N}(f_e^\mu(x_n),f_e^{\Sigma}(x_n))$,其中$f_e^\mu(x_n),f_e^\Sigma(x_n)$分别为基于$x_n$生成的均值和协方差矩阵。将$z$抽样表示为$f(x_n,\epsilon)= f_e^{\Sigma}(x_n))\epsilon + f_e^\mu(x_n) $,其中$\epsilon\sim \mathcal{N}(0,1)$。则最大化$L$可表示为最小化:
$\displaystyle J_{IB} = \frac{1}{N}\sum\limits_{n=1}^N \mathbb{E}_{\epsilon\sim \mathcal{N}(0,1)}\left[-\log q(y_n|f(x_n,\epsilon))\right] + \beta\,\text{KL}\left[p(Z|x_n);r(z)\right]$
其中$r(z)$利用某一特定分布实现,文中使用标准正态分布实现。
直觉理解
直觉上理解:模型要把每个$x_n$分别映射到特定的分布,这些分布既不能偏离标准正态分布太远,又需要让模型后续能根据这些分布的抽样来预测$x_n$的标签。那么这种做法为什么能从$x_n$中抽取对预测$y_n$有效的关键信息而忽略无关信息呢(即信息瓶颈)?我的理解是,模型被惩罚以使不同$x_n$得到的$z_n$分布靠近同一分布,但为了有效预测$y_n$,又必须产生一定的不一致。不同$x_n$对应的$z$分布越一致,通过$z$而流向$y$的差异性信息将越少,导致$q$更难利用采样的$z$预测$y$,从而促使模型忽略$x$中的冗余信息而保留预测$y$所需的关键信息。$\beta$则用于控制$z$保留$x$信息的程度,越大保留信息越少。
相较于一般的判别模型:当不把$z $视为隐变量,而变成关于$x$唯一确定的中间表示时,就是一般的判别模型。这种方式隐式地假定了表示的连续性,然而无法确保所有$z$都不是被离散地分散在表示空间中。最坏的过拟合情况下,每个$(x_n,y_n)$都孤立地确定了一个中间表示$z_n$来实现一一映射,导致无泛化。而对于使用了信息瓶颈$z$的判别模型,由于$x$仅仅确定$z$的生成分布,不同的$x_i,x_j$依然可能抽样出同一个$z$,这种模式强制这个抽样出的$z $必须共享这两个样本的相似信息并忽略不同的信息,从而表示语义的相似性被强制由线性距离控制,实现表示语义的连续性,从而显式地确定了模型的泛化。
实验
表1:信息瓶颈加成的模型和各种正则化后模型的对比。
图1:不同$\beta$、$z$维度$K$下VIB模型在MNIST上的错误率,以及两个互信息的平衡。
图2:$z$维度$K=2$时,1000张图片的$z$分布的可视化。
后续是一些对抗鲁棒的实验,不记录
深度变分信息瓶颈——Deep Variational Information Bottleneck的更多相关文章
- 变分自编码器(Variational auto-encoder,VAE)
参考: https://www.cnblogs.com/huangshiyu13/p/6209016.html https://zhuanlan.zhihu.com/p/25401928 https: ...
- 贝叶斯深度学习(bayesian deep learning)
本文简单介绍什么是贝叶斯深度学习(bayesian deep learning),贝叶斯深度学习如何用来预测,贝叶斯深度学习和深度学习有什么区别.对于贝叶斯深度学习如何训练,本文只能大致给个介绍. ...
- 深度学习概述教程--Deep Learning Overview
引言 深度学习,即Deep Learning,是一种学习算法(Learning algorithm),亦是人工智能领域的一个重要分支.从快速发展到实际应用,短短几年时间里, ...
- (转) 变分自编码器(Variational Autoencoder, VAE)通俗教程
变分自编码器(Variational Autoencoder, VAE)通俗教程 转载自: http://www.dengfanxin.cn/?p=334&sukey=72885186ae5c ...
- 深度学习加速器堆栈Deep Learning Accelerator Stack
深度学习加速器堆栈Deep Learning Accelerator Stack 通用张量加速器(VTA)是一种开放的.通用的.可定制的深度学习加速器,具有完整的基于TVM的编译器堆栈.设计了VTA来 ...
- 关于深度残差网络(Deep residual network, ResNet)
题外话: From <白话深度学习与TensorFlow> 深度残差网络: 深度残差网络的设计就是为了克服这种由于网络深度加深而产生的学习效率变低,准确率无法有效提升的问题(也称为网络退化 ...
- 深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction
转自https://zhuanlan.zhihu.com/p/25239682 过去的一段时间在深度强化学习领域投入了不少精力,工作中也在应用DRL解决业务问题.子曰:温故而知新,在进一步深入研究和应 ...
- 用信息值进行特征选择(Information Value)
Posted by c cm on January 3, 2014 特征选择(feature selection)或者变量选择(variable selection)是在建模之前的重要一步.数据接口越 ...
- 深度学习论文笔记-Deep Learning Face Representation from Predicting 10,000 Classes
来自:CVPR 2014 作者:Yi Sun ,Xiaogang Wang,Xiaoao Tang 题目:Deep Learning Face Representation from Predic ...
- 最实用的深度学习教程 Practical Deep Learning For Coders (Kaggle 冠军 Jeremy Howard 亲授)
Jeremy Howard 在业界可谓大名鼎鼎.他是大数据竞赛平台 Kaggle 的前主席和首席科学家.他本人还是 Kaggle 的冠军选手.他是美国奇点大学(Singularity Universi ...
随机推荐
- ST-SSL: 用于交通流量预测的时空自监督学习《Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction》(交通流量预测、时空异质性、自监督、数据增强)
2023年10月23日,继续论文,好困,想发疯. 论文:Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction Git ...
- SQL Server 语句日期格式查找方法
1. SQL Server中,处理日期格式和查找特定日期格式方法示例 在SQL Server中,处理日期格式和查找特定日期格式的记录是一个常见的需求.SQL Server提供了多种函数和格式选项来处理 ...
- java基础 -IO流笔记
610,文件的基础知识 文件流 输入流和输出流都是相对 java程序内存 而言 611,创建文件 在D盘下创建文件. package com.hspedu.file; import org.junit ...
- Spring —— bean生命周期
bean生命周期 生命周期:从创建到消亡的完整过程 bean生命周期:bean从创建到销毁的整体过程 bean生命周期控制:在bean创建后到销毁前做一些事情 方式一:配置控制生命周期 <b ...
- 【QT性能优化】QT性能优化之QT性能优化实战 QML优化 QT高性能 QT6系列视频课程 QT6 性能优化实战 QT高性能 QT原理源码 QML优化 GUI绘图原理源码
QT性能优化实战视频课程 QT6 Widgets高性能应用编程 1.课前考试 2.字符串优化(上) 3.字符串优化(下) 4.绘图优化(上) 5.绘图优化(下) 6.QT界面优化(上) 7.QT界面 ...
- 1 月 25 日见|Flutter Forward 活动日程表正式发布!
2023 年 1 月 25 日 (正月初四),我们将在肯尼亚首都内罗毕举办 Flutter Forward 大会,并同时开启线上直播.本次活动将展示最新的 Flutter 技术更新,包括一个主题演讲以 ...
- Kubernetes集群证书过期解决办法
问题现象 K8S集群证书过期后,会导无法创建Pod,通过kubectl get nodes也无法获取信息,甚至dashboard也无法访问. 一.确认K8S证书过期时间 查看k8s某一证书过期时间: ...
- oh-my-zsh nvm command not found
oh-my-zsh nvm command not found 如果你在使用 oh-my-zsh 并且在终端输入 nvm 命令时提示 "command not found",这可能 ...
- USB总线-Linux内核USB3.0 Hub驱动分析(十四)
1.概述 USB Hub提供了连接USB主机和USB设备的电气接口.USB Hub拥有一个上行口,至少一个下行口,上行口连接上一级的Hub的下行口或者USB主机,连接主机的为Root Hub,下行口连 ...
- USB协议详解第7讲(补充-USB帧和微帧剖析)
1.概念 (1)USB2.0帧和微帧属于物理层时间基准的概念,低速和全速下每个帧时长为1ms,高速下每个帧又分为8个微帧,即每个微帧时长为125us. (2)USB主机和设备控制器同步后,每个微帧起始 ...