深度学习-InfoGAN论文理解笔记
在弄清楚InfoGAN之前,可以先理解一下变分推断目的以及在概率论中的应用与ELBO是什么,以及KL散度 https://blog.csdn.net/qy20115549/article/details/93074519
https://blog.csdn.net/qy20115549/article/details/86644192。
如果理解了变分推断,KL散度,ELBO,对于InfoGAN中的重要方法就可以很容易理解了。
这里首先看一下简单的对数推导为方便对InfoGAN文中的公式的阅读:
下面的笔记参阅:
https://blog.csdn.net/u011699990/article/details/71599067
https://www.cnblogs.com/zzycv/p/9312039.html
先记一下预备知识就当作复习了。
条件熵公式推导:
用另一个变量对原变量分类后, 原变量的不确定性就会减小, 不确定程度减小了就是信息增益。
互信息(Mutual Information, MI)是变量间相互依赖性的度量, 它度量两个事件集合之间的相关性。
两个离散随机变量X和Y的互信息可定义为:
p(x,y)是X和Y的联合概率分布函数, 而p(x)和p(y)分别是X和Y的边缘概率分布函数
在连续随机变量的情形下:
I(X;Y)的一些计算:
GAN
总的来说,这个生成模型就是通过两个神经网络互相之间的竞争对抗来进行训练。这两个网络中有一个是生成器,它需要将随机的噪声分布z映射到我们需要得到的真实分布x,另外一个网络就是判别器,从真实数据和生成的数据中间随机采样,判断这个数据是否是真实数据,在这里判别器D相当于一个二分类器。所以,整个优化问题就转换为一个minmax game,D(x)代表x来自真实数据的概率,在训练过程中,最大化D(x)的值,同时,最小化生成器生成数据的能力,从而达到两个网络互相竞争,互相进步,生成器生成数据的能力越来越强,判别器判别数据的能力也越来越强。
如果从表征学习的角度来看GAN模型,会发现,由于在生成器使用噪声z的时候没有加任何的限制,所以在以一种高度混合的方式使用z,z的任何一个维度都没有明显的表示一个特征,所以在数据生成过程中,我们无法得知什么样的噪声z可以用来生成数字1,什么样的噪声z可以用来生成数字3,我们对这些一无所知,这从一点程度上限制了我们对GAN的使用。
InfoGAN
基于上面的分析,作者就在生成器中除了原先的噪声z还增加了一个隐含编码c,提出了一个新的GAN模型—InfoGAN,其中Info代表互信息,它表示生成数据x与隐藏编码c之间关联程度的大小,为了使得x与c之间关联密切,所以我们需要最大化互信息的值,据此对原始GAN模型的值函数做了一点修改,相当于加了一个互信息的正则化项。是一个超参,通过之后的实验选择了一个最优值1。
上面的I(C|X)中的|应该为";" 即I(C;X), I(C;X=G(Z,C))
在实践的过程中, 发现在计算互信息的值时, 需要用到后验分布(P(C|X),但是这是一个很难采样和估计的一个值。作者想出通过定义一个辅助分布来协助估计后验分布。这就与文中第一次给出的简单的对数推导连上了。
在引入辅助分布之后,通过变分分布最大化来最大化这个互信息的下界,这里是整个推导过程。
使用Q来接近p得到了互信息的变分估计, 在论文中提到, 为了优化方便, 作者固定了隐编码的分布, 因此H(C)是一个常量.通过对G,Q的优化来最大化L(G,Q),当时,不等式取等号,从而得到最大化互信息的效果。所以InfoGAN的最优化问题中的互信息正则项替换为L(G,Q)。但是在L(G,Q)中发现,还是有个后验分布。 作者利用一个定理来消除这个后验分布, 如下:
更详细的推导如下:
定理的证明:
因此就可以得到最小最大值InfoGAN,V(D,G)减去超参数λ与互信息变分正则化项的乘积,
在InfoGAN模型中,GAN的网络结构用的是DCGAN。给生成器输入隐含编码c和噪声z,生成假的数据,从假数据和真实数据中随机采样,输入给定D进行判断,是真还是假。Q通过与D共享卷积层,可以减少计算花销。在这里,Q是一个变分分布,在神经网络中直接最大化,Q也可以视作一个判别器,输出类别c。
实验,主要有两个目标,第一个是验证互信息可以被有效的最大化,第二个是验证InfoGAN是否可以学到可分解的可解释的特征。在第一个实验中,使用MNIST数据集,隐含编码c设为包括10个类别的离散编码,均匀分布,并使用GAN中同样加了一个辅助分布Q做对比试验。实验中发现,InfoGAN中L可以很快收敛到H(c),而GAN中,生成的图片与隐含编码c之间的互信息几乎没有,说明GAN中并没有用到这个隐含编码。
第二个实验来尝试学习数据集中可分解的特征。同样的使用MNIST数据集,在这里使用了三个隐含编码,c1用十个离散数字进行编码,每个类别的概率都是0.1,c2,c3连续编码,是-2到2的均匀分布。通过实验发现,c1可以作为一个分类器,分类的错误率为5%,图片a中第二行将7识别为9,但是不是意味着c1的0-9分别代表着生成数字的0-9,这是为了可视化效果,对数据重新排序的结果。如果在常规的GAN模型中添加c1编码,发现生成的图片与c1没有明显的关联。通过观察发现,c2表示生成数字的旋转的角度,c3表示生成数字的宽度。图c显示,小的c2值表示数字向左偏,大的c2值表示数字向右偏。图d显示,c3的值越大,生成的数字越宽。
除了MNIST数据集外,还在其他数据集上做了实验。比如在一个3D的人脸数据集中,使用多个连续的编码,得到了一些不同的特征。比如图a中,可以得到人脸转向的特征,图b得到了人脸仰角大小的特征,图c得到了图片亮度的特征,图d得到人脸宽窄的特征。
在3d椅子的数据集中,图a表示可以得到不同椅子旋转角度不同的特征,图b表示可以得到不同椅子宽度不同的特征。
在一个街景的楼栋数字数据集中,也得到了不同的特征。如图a所示,可以获取这些数字图片不同的亮度,图b所示,可以区分出图片中不同的数字。
在CelebA数据集中,同样的可以通过不同的编码获取一些特征,比如人脸不同的转向角度,是否带了眼镜,发型的不同,情绪的变化。
最后提出了一些未来工作的设想,将互信息最大化应用到其他的方法,尝试学习层级特征,用来提高半监督学习,发现更高维度的数据。通过这篇paper当中的实验发现,隐含编码c学习到的内容很丰富,可以用来进行半监督学习的分类。(往真实数据中添加噪声,可以使训练更稳定)
深度学习-InfoGAN论文理解笔记的更多相关文章
- 【计算机视觉】【神经网络与深度学习】论文阅读笔记:You Only Look Once: Unified, Real-Time Object Detection
尊重原创,转载请注明:http://blog.csdn.net/tangwei2014 这是继RCNN,fast-RCNN 和 faster-RCNN之后,rbg(Ross Girshick)大神挂名 ...
- 转载-【深度学习】深入理解Batch Normalization批标准化
全文转载于郭耀华-[深度学习]深入理解Batch Normalization批标准化: 文章链接Batch Normalization: Accelerating Deep Network T ...
- 深度学习-Wasserstein GAN论文理解笔记
GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...
- 【深度学习】深入理解Batch Normalization批标准化
这几天面试经常被问到BN层的原理,虽然回答上来了,但还是感觉答得不是很好,今天仔细研究了一下Batch Normalization的原理,以下为参考网上几篇文章总结得出. Batch Normaliz ...
- 【深度学习】深入理解优化器Optimizer算法(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- 深度学习—BN的理解(一)
0.问题 机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障.那BatchNorm的作用是 ...
- 转:深度学习斯坦福cs231n 课程笔记
http://blog.csdn.net/dinosoft/article/details/51813615 前言 对于深度学习,新手我推荐先看UFLDL,不做assignment的话,一两个晚上就可 ...
- 【深度学习】深入理解ReLU(Rectifie Linear Units)激活函数
论文参考:Deep Sparse Rectifier Neural Networks (很有趣的一篇paper) Part 0:传统激活函数.脑神经元激活频率研究.稀疏激活性 0.1 一般激活函数有 ...
- 深度学习-DCGAN论文的理解笔记
训练方法DCGAN 的训练方法跟GAN 是一样的,分为以下三步: (1)for k steps:训练D 让式子[logD(x) + log(1 - D(G(z)) (G keeps still)]的值 ...
随机推荐
- P1501 [国家集训队]Tree II LCT
链接 luogu 思路 简单题 代码 #include <bits/stdc++.h> #define ls c[x][0] #define rs c[x][1] using namesp ...
- WDM驱动改可手动加卸载的NT驱动
WDM驱动改可手动加卸载的NT驱动 测试工具:osrloader 把一个WDM类型的驱动改成可动态加载/卸载,需要做以下2个修改: 1. 把SOURCES文件夹中的DRIVERTYPE=WDM去掉 2 ...
- VLC搭建RTSP服务器
实时流协议 RTSP 是在实时传输协议的基础上工作的,主要实现对多媒体播放的控制.用户对多媒体信息的播放.暂停.前进和后退等功能就是通过对实时数据流的控制来实现的. 而这些播放控制功能的实现不仅需要多 ...
- 【Swoole】php7.1安装swoole扩展
参照:https://zixuephp.net/article-430.html 1.源码编译安装,PHP版本7.1.33 2.在已经编译好安装的php7.1中安装swoole扩展. 一.下载swoo ...
- 关于uboot中的属性"u-boot,dm-pre-reloc”的意义
"u-boot,dm-pre-reloc”属性:当设置了这个属性时,则表示这个设备在重定向之前就需要使用. 当dm_init_and_scan的参数为true时,只会对带有“u-boot,d ...
- Mysql创建数据库以及用户分配权限
一.创建mysql数据库 1.创建数据库语法 --创建名称为“testdb”数据库,并设定编码集为utf8 CREATE DATABASE IF NOT EXISTS testdb DEFAULT C ...
- 接口项目servlet的一种处理方式,将异常返回给调用者【我】
接口项目servlet的一种处理方式,其他层有异常全部网上抛,抛到servlet层,将异常返回给调用者: Servlet层: private void processRequest(HttpServl ...
- 阿里云环境安装K8S步骤
1. 安装docker yum install -y docker 2. 修改 /etc/docker/daemon.json 文件并添加上 registry-mirrors 键值 $ vim /et ...
- error: RPC failed; curl 56 GnuTLS recv error (-54): Error in the pull function.
. . . . . 今天从 github 上 clone 代码的时候,出现了一个错误,重试多次后仍然出现,错误如下: >$ git clone https://github.com/BOINC/ ...
- ABAP基础篇2 数据类型
基本数据类型列表: 1.长度可变的内置类型(String.XString)1)string类型 在ABAP程序中,string类型是长度无限的字符型字段,可以和CHAR ,D,T ,I,N (F和P未 ...