深度学习-InfoGAN论文理解笔记
在弄清楚InfoGAN之前,可以先理解一下变分推断目的以及在概率论中的应用与ELBO是什么,以及KL散度 https://blog.csdn.net/qy20115549/article/details/93074519
https://blog.csdn.net/qy20115549/article/details/86644192。
如果理解了变分推断,KL散度,ELBO,对于InfoGAN中的重要方法就可以很容易理解了。
这里首先看一下简单的对数推导为方便对InfoGAN文中的公式的阅读:

下面的笔记参阅:
https://blog.csdn.net/u011699990/article/details/71599067
https://www.cnblogs.com/zzycv/p/9312039.html
先记一下预备知识就当作复习了。
条件熵公式推导:

用另一个变量对原变量分类后, 原变量的不确定性就会减小, 不确定程度减小了就是信息增益。
互信息(Mutual Information, MI)是变量间相互依赖性的度量, 它度量两个事件集合之间的相关性。
两个离散随机变量X和Y的互信息可定义为:

p(x,y)是X和Y的联合概率分布函数, 而p(x)和p(y)分别是X和Y的边缘概率分布函数
在连续随机变量的情形下:

I(X;Y)的一些计算:

GAN
总的来说,这个生成模型就是通过两个神经网络互相之间的竞争对抗来进行训练。这两个网络中有一个是生成器,它需要将随机的噪声分布z映射到我们需要得到的真实分布x,另外一个网络就是判别器,从真实数据和生成的数据中间随机采样,判断这个数据是否是真实数据,在这里判别器D相当于一个二分类器。所以,整个优化问题就转换为一个minmax game,D(x)代表x来自真实数据的概率,在训练过程中,最大化D(x)的值,同时,最小化生成器生成数据的能力,从而达到两个网络互相竞争,互相进步,生成器生成数据的能力越来越强,判别器判别数据的能力也越来越强。

如果从表征学习的角度来看GAN模型,会发现,由于在生成器使用噪声z的时候没有加任何的限制,所以在以一种高度混合的方式使用z,z的任何一个维度都没有明显的表示一个特征,所以在数据生成过程中,我们无法得知什么样的噪声z可以用来生成数字1,什么样的噪声z可以用来生成数字3,我们对这些一无所知,这从一点程度上限制了我们对GAN的使用。

InfoGAN
基于上面的分析,作者就在生成器中除了原先的噪声z还增加了一个隐含编码c,提出了一个新的GAN模型—InfoGAN,其中Info代表互信息,它表示生成数据x与隐藏编码c之间关联程度的大小,为了使得x与c之间关联密切,所以我们需要最大化互信息的值,据此对原始GAN模型的值函数做了一点修改,相当于加了一个互信息的正则化项。是一个超参,通过之后的实验选择了一个最优值1。

上面的I(C|X)中的|应该为";" 即I(C;X), I(C;X=G(Z,C))

在实践的过程中, 发现在计算互信息的值时, 需要用到后验分布(P(C|X),但是这是一个很难采样和估计的一个值。作者想出通过定义一个辅助分布来协助估计后验分布。这就与文中第一次给出的简单的对数推导连上了。


在引入辅助分布之后,通过变分分布最大化来最大化这个互信息的下界,这里是整个推导过程。


使用Q来接近p得到了互信息的变分估计, 在论文中提到, 为了优化方便, 作者固定了隐编码的分布, 因此H(C)是一个常量.通过对G,Q的优化来最大化L(G,Q),当时,不等式取等号,从而得到最大化互信息的效果。所以InfoGAN的最优化问题中的互信息正则项替换为L(G,Q)。但是在L(G,Q)中发现,还是有个后验分布。 作者利用一个定理来消除这个后验分布, 如下:

更详细的推导如下:

定理的证明:

因此就可以得到最小最大值InfoGAN,V(D,G)减去超参数λ与互信息变分正则化项的乘积,

在InfoGAN模型中,GAN的网络结构用的是DCGAN。给生成器输入隐含编码c和噪声z,生成假的数据,从假数据和真实数据中随机采样,输入给定D进行判断,是真还是假。Q通过与D共享卷积层,可以减少计算花销。在这里,Q是一个变分分布,在神经网络中直接最大化,Q也可以视作一个判别器,输出类别c。

实验,主要有两个目标,第一个是验证互信息可以被有效的最大化,第二个是验证InfoGAN是否可以学到可分解的可解释的特征。在第一个实验中,使用MNIST数据集,隐含编码c设为包括10个类别的离散编码,均匀分布,并使用GAN中同样加了一个辅助分布Q做对比试验。实验中发现,InfoGAN中L可以很快收敛到H(c),而GAN中,生成的图片与隐含编码c之间的互信息几乎没有,说明GAN中并没有用到这个隐含编码。

第二个实验来尝试学习数据集中可分解的特征。同样的使用MNIST数据集,在这里使用了三个隐含编码,c1用十个离散数字进行编码,每个类别的概率都是0.1,c2,c3连续编码,是-2到2的均匀分布。通过实验发现,c1可以作为一个分类器,分类的错误率为5%,图片a中第二行将7识别为9,但是不是意味着c1的0-9分别代表着生成数字的0-9,这是为了可视化效果,对数据重新排序的结果。如果在常规的GAN模型中添加c1编码,发现生成的图片与c1没有明显的关联。通过观察发现,c2表示生成数字的旋转的角度,c3表示生成数字的宽度。图c显示,小的c2值表示数字向左偏,大的c2值表示数字向右偏。图d显示,c3的值越大,生成的数字越宽。

除了MNIST数据集外,还在其他数据集上做了实验。比如在一个3D的人脸数据集中,使用多个连续的编码,得到了一些不同的特征。比如图a中,可以得到人脸转向的特征,图b得到了人脸仰角大小的特征,图c得到了图片亮度的特征,图d得到人脸宽窄的特征。

在3d椅子的数据集中,图a表示可以得到不同椅子旋转角度不同的特征,图b表示可以得到不同椅子宽度不同的特征。

在一个街景的楼栋数字数据集中,也得到了不同的特征。如图a所示,可以获取这些数字图片不同的亮度,图b所示,可以区分出图片中不同的数字。

在CelebA数据集中,同样的可以通过不同的编码获取一些特征,比如人脸不同的转向角度,是否带了眼镜,发型的不同,情绪的变化。

最后提出了一些未来工作的设想,将互信息最大化应用到其他的方法,尝试学习层级特征,用来提高半监督学习,发现更高维度的数据。通过这篇paper当中的实验发现,隐含编码c学习到的内容很丰富,可以用来进行半监督学习的分类。(往真实数据中添加噪声,可以使训练更稳定)
深度学习-InfoGAN论文理解笔记的更多相关文章
- 【计算机视觉】【神经网络与深度学习】论文阅读笔记:You Only Look Once: Unified, Real-Time Object Detection
尊重原创,转载请注明:http://blog.csdn.net/tangwei2014 这是继RCNN,fast-RCNN 和 faster-RCNN之后,rbg(Ross Girshick)大神挂名 ...
- 转载-【深度学习】深入理解Batch Normalization批标准化
全文转载于郭耀华-[深度学习]深入理解Batch Normalization批标准化: 文章链接Batch Normalization: Accelerating Deep Network T ...
- 深度学习-Wasserstein GAN论文理解笔记
GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...
- 【深度学习】深入理解Batch Normalization批标准化
这几天面试经常被问到BN层的原理,虽然回答上来了,但还是感觉答得不是很好,今天仔细研究了一下Batch Normalization的原理,以下为参考网上几篇文章总结得出. Batch Normaliz ...
- 【深度学习】深入理解优化器Optimizer算法(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- 深度学习—BN的理解(一)
0.问题 机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障.那BatchNorm的作用是 ...
- 转:深度学习斯坦福cs231n 课程笔记
http://blog.csdn.net/dinosoft/article/details/51813615 前言 对于深度学习,新手我推荐先看UFLDL,不做assignment的话,一两个晚上就可 ...
- 【深度学习】深入理解ReLU(Rectifie Linear Units)激活函数
论文参考:Deep Sparse Rectifier Neural Networks (很有趣的一篇paper) Part 0:传统激活函数.脑神经元激活频率研究.稀疏激活性 0.1 一般激活函数有 ...
- 深度学习-DCGAN论文的理解笔记
训练方法DCGAN 的训练方法跟GAN 是一样的,分为以下三步: (1)for k steps:训练D 让式子[logD(x) + log(1 - D(G(z)) (G keeps still)]的值 ...
随机推荐
- 在IDEA编辑器中建立Spring Cloud的子项目包(构建微服务)
本文介绍在IDEA编辑器中建立Spring Cloud的子项目包 总共分为5个包: 外层使用maven quickstart建立,子modules直接选择了springboot
- SpringBoot 的一些学习资源
很多Java Web开发者目前常用的技术还是SpringBoot,想要工作效率更,在刚入门不久时有必要全面了解一下它的功能特性,而高效学习的方法,除了在官网学习外,还可以看下网上已有的不错的教程.刚看 ...
- 洛谷P2580 于是他错误的点名开始了 题解
qwq!为什么!木有非结构体非指针的题解怎么阔以!所以, 我来辽~咻咻咻~ 题面 来分析, 我们可以先建一棵树,来存储整个名单, 然后再判断 ; i <= n; i++) { root = ; ...
- CCF 201803-3 URL映射
CCF 201803-3 URL映射 试题编号: 201803-3 试题名称: URL映射 时间限制: 1.0s 内存限制: 256.0MB 问题描述: 问题描述 URL 映射是诸如 Django. ...
- Leetcode: 43. 接雨水
题目描述: 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水. 上面是由数组 [0,1,0,2,1,0,1,3,2,1,2,1] 表示的高度图,在这种情 ...
- windows 安装 Composer 报错
错误信息如下: 解决方法: 在 extension = php_snmp.dll 前加上 ";" 然后重启 php 即可安装
- css3特效插件wow.js
在使用css3写特效的时候,会遇到比较麻烦的就是css3代码需要大量的调试,但是现在有了wow.js,让写特效变得简单了很多. wow.js官网 https://www.delac.io/wow/in ...
- rabbitMQ消息队列 – Message方法解析
消息的创建由AMQPMessage对象来创建$message = new AMQPMessage("消息内容");是不是很简单. 后边是一个数组.可以对消息进行一些特殊配置$mes ...
- ubuntu 各压缩文件解压命令大全
.tar 解包:tar xvf xxx.tar 打包:tar cvf xxx.tar DirName (注:tar是打包,不是压缩!) .gz 解压1:gunzip FileName.gz 解压2:g ...
- linux编译qt
1.使用QtCreator新建工程,windows和linux都可以,这样才有.pro文件 2.在linux中进入工程目录,生成makefile: /home/5.9.2/gcc_64/bin/qma ...