Wasserstein GAN最新进展:从weight clipping到gradient penalty,更加先进的Lipschitz限制手法
前段时间,Wasserstein GAN以其精巧的理论分析、简单至极的算法实现、出色的实验效果,在GAN研究圈内掀起了一阵热潮(对WGAN不熟悉的读者,可以参考我之前写的介绍文章:令人拍案叫绝的Wasserstein GAN - 知乎专栏)。但是很多人(包括我们实验室的同学)到了上手跑实验的时候,却发现WGAN实际上没那么完美,反而存在着训练困难、收敛速度慢等问题。其实,WGAN的作者Martin Arjovsky不久后就在reddit上表示他也意识到了这个问题,认为关键在于原设计中Lipschitz限制的施加方式不对,并在新论文中提出了相应的改进方案:
首先回顾一下WGAN的关键部分——Lipschitz限制是什么。WGAN中,判别器D和生成器G的loss函数分别是:
(公式1)
(公式2)
公式1表示判别器希望尽可能拉高真样本的分数,拉低假样本的分数,公式2表示生成器希望尽可能拉高假样本的分数。
Lipschitz限制则体现为,在整个样本空间
上,要求判别器函数D(x)梯度的Lp-norm不大于一个有限的常数K:
(公式3)
直观上解释,就是当输入的样本稍微变化后,判别器给出的分数不能发生太过剧烈的变化。在原来的论文中,这个限制具体是通过weight clipping的方式实现的:每当更新完一次判别器的参数之后,就检查判别器的所有参数的绝对值有没有超过一个阈值,比如0.01,有的话就把这些参数clip回 [-0.01, 0.01] 范围内。通过在训练过程中保证判别器的所有参数有界,就保证了判别器不能对两个略微不同的样本给出天差地别的分数值,从而间接实现了Lipschitz限制。
然而weight clipping的实现方式存在两个严重问题:
第一,如公式1所言,判别器loss希望尽可能拉大真假样本的分数差,然而weight clipping独立地限制每一个网络参数的取值范围,在这种情况下我们可以想象,最优的策略就是尽可能让所有参数走极端,要么取最大值(如0.01)要么取最小值(如-0.01)!为了验证这一点,作者统计了经过充分训练的判别器中所有网络参数的数值分布,发现真的集中在最大和最小两个极端上:

这样带来的结果就是,判别器会非常倾向于学习一个简单的映射函数(想想看,几乎所有参数都是正负0.01,都已经可以直接视为一个二值神经网络了,太简单了)。而作为一个深层神经网络来说,这实在是对自身强大拟合能力的巨大浪费!判别器没能充分利用自身的模型能力,经过它回传给生成器的梯度也会跟着变差。
在正式介绍gradient penalty之前,我们可以先看看在它的指导下,同样充分训练判别器之后,参数的数值分布就合理得多了,判别器也能够充分利用自身模型的拟合能力:

第二个问题,weight clipping会导致很容易一不小心就梯度消失或者梯度爆炸。原因是判别器是一个多层网络,如果我们把clipping threshold设得稍微小了一点,每经过一层网络,梯度就变小一点点,多层之后就会指数衰减;反之,如果设得稍微大了一点,每经过一层网络,梯度变大一点点,多层之后就会指数爆炸。只有设得不大不小,才能让生成器获得恰到好处的回传梯度,然而在实际应用中这个平衡区域可能很狭窄,就会给调参工作带来麻烦。相比之下,gradient penalty就可以让梯度在后向传播的过程中保持平稳。论文通过下图体现了这一点,其中横轴代表判别器从低到高第几层,纵轴代表梯度回传到这一层之后的尺度大小(注意纵轴是对数刻度),c是clipping threshold:

说了这么多,gradient penalty到底是什么?
前面提到,Lipschitz限制是要求判别器的梯度不超过K,那我们何不直接设置一个额外的loss项来体现这一点呢?比如说:
(公式4)
不过,既然判别器希望尽可能拉大真假样本的分数差距,那自然是希望梯度越大越好,变化幅度越大越好,所以判别器在充分训练之后,其梯度norm其实就会是在K附近。知道了这一点,我们可以把上面的loss改成要求梯度norm离K越近越好,效果是类似的:
(公式5)
究竟是公式4好还是公式5好,我看不出来,可能需要实验验证,反正论文作者选的是公式5。接着我们简单地把K定为1,再跟WGAN原来的判别器loss加权合并,就得到新的判别器loss:
(公式6)
这就是所谓的gradient penalty了吗?还没完。公式6有两个问题,首先是loss函数中存在梯度项,那么优化这个loss岂不是要算梯度的梯度?一些读者可能对此存在疑惑,不过这属于实现上的问题,放到后面说。
其次,3个loss项都是期望的形式,落到实现上肯定得变成采样的形式。前面两个期望的采样我们都熟悉,第一个期望是从真样本集里面采,第二个期望是从生成器的噪声输入分布采样后,再由生成器映射到样本空间。可是第三个分布要求我们在整个样本空间
上采样,这完全不科学!由于所谓的维度灾难问题,如果要通过采样的方式在图片或自然语言这样的高维样本空间中估计期望值,所需样本量是指数级的,实际上没法做到。
所以,论文作者就非常机智地提出,我们其实没必要在整个样本空间上施加Lipschitz限制,只要重点抓住生成样本集中区域、真实样本集中区域以及夹在它们中间的区域就行了。具体来说,我们先随机采一对真假样本,还有一个0-1的随机数:
(公式7)
然后在
和
的连线上随机插值采样:
(公式8)
把按照上述流程采样得到的
所满足的分布记为
,就得到最终版本的判别器loss:
(公式9)
这就是新论文所采用的gradient penalty方法,相应的新WGAN模型简称为WGAN-GP。我们可以做一个对比:
- weight clipping是对样本空间全局生效,但因为是间接限制判别器的梯度norm,会导致一不小心就梯度消失或者梯度爆炸;
- gradient penalty只对真假样本集中区域、及其中间的过渡地带生效,但因为是直接把判别器的梯度norm限制在1附近,所以梯度可控性非常强,容易调整到合适的尺度大小。
论文还讲了一些使用gradient penalty时需要注意的配套事项,这里只提一点:由于我们是对每个样本独立地施加梯度惩罚,所以判别器的模型架构中不能使用Batch Normalization,因为它会引入同个batch中不同样本的相互依赖关系。如果需要的话,可以选择其他normalization方法,如Layer Normalization、Weight Normalization和Instance Normalization,这些方法就不会引入样本之间的依赖。论文推荐的是Layer Normalization。
实验表明,gradient penalty能够显著提高训练速度,解决了原始WGAN收敛缓慢的问题:

虽然还是比不过DCGAN,但是因为WGAN不存在平衡判别器与生成器的问题,所以会比DCGAN更稳定,还是很有优势的。不过,作者凭什么能这么说?因为下面的实验体现出,在各种不同的网络架构下,其他GAN变种能不能训练好,可以说是一件相当看人品的事情,但是WGAN-GP全都能够训练好,尤其是最下面一行所对应的101层残差神经网络:

剩下的实验结果中,比较厉害的是第一次成功做到了“纯粹的”的文本GAN训练!我们知道在图像上训练GAN是不需要额外的有监督信息的,但是之前就没有人能够像训练图像GAN一样训练好一个文本GAN,要么依赖于预训练一个语言模型,要么就是利用已有的有监督ground truth提供指导信息。而现在WGAN-GP终于在无需任何有监督信息的情况下,生成出下图所示的英文字符序列:

它是怎么做到的呢?我认为关键之处是对样本形式的更改。以前我们一般会把文本这样的离散序列样本表示为sequence of index,但是它把文本表示成sequence of probability vector。对于生成样本来说,我们可以取网络softmax层输出的词典概率分布向量,作为序列中每一个位置的内容;而对于真实样本来说,每个probability vector实际上就蜕化为我们熟悉的onehot vector。
但是如果按照传统GAN的思路来分析,这不是作死吗?一边是hard onehot vector,另一边是soft probability vector,判别器一下子就能够区分它们,生成器还怎么学习?没关系,对于WGAN来说,真假样本好不好区分并不是问题,WGAN只是拉近两个分布之间的Wasserstein距离,就算是一边是hard onehot另一边是soft probability也可以拉近,在训练过程中,概率向量中的有些项可能会慢慢变成0.8、0.9到接近1,整个向量也会接近onehot,最后我们要真正输出sequence of index形式的样本时,只需要对这些概率向量取argmax得到最大概率的index就行了。
新的样本表示形式+WGAN的分布拉近能力是一个“黄金组合”,但除此之外,还有其他因素帮助论文作者跑出上图的效果,包括:
- 文本粒度为英文字符,而非英文单词,所以字典大小才二三十,大大减小了搜索空间
- 文本长度也才32
- 生成器用的不是常见的LSTM架构,而是多层反卷积网络,输入一个高斯噪声向量,直接一次性转换出所有32个字符
上面第三点非常有趣,因为它让我联想到前段时间挺火的语言学科幻电影《降临》:

里面的外星人“七肢怪”所使用的语言跟人类不同,人类使用的是线性的、串行的语言,而“七肢怪”使用的是非线性的、并行的语言。“七肢怪”在跟主角交流的时候,都是一次性同时给出所有的语义单元的,所以说它们其实是一些多层反卷积网络进化出来的人工智能生命吗?


开完脑洞,我们回过头看,不得不承认这个实验的setup实在过于简化了,能否扩展到更加实际的复杂场景,也会是一个问题。但是不管怎样,生成出来的结果仍然是突破性的。
最后说回gradient penalty的实现问题。loss中本身包含梯度,优化loss就需要求梯度的梯度,这个功能并不是现在所有深度学习框架的标配功能,不过好在Tensorflow就有提供这个接口——tf.gradients。开头链接的GitHub源码中就是这么写的:
# interpolates就是随机插值采样得到的图像,gradients就是loss中的梯度惩罚项
gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
对于我这样的PyTorch党就非常不幸了,高阶梯度的功能还在开发,感兴趣的PyTorch党可以订阅这个GitHub的pull request:Autograd refactor,如果它被merged了话就可以在最新版中使用高阶梯度的功能实现gradient penalty了。
但是除了等待我们就没有别的办法了吗?其实可能是有的,我想到了一种近似方法来实现gradient penalty,只需要把微分换成差分:
(公式10)
也就是说,我们仍然是在分布
上随机采样,但是一次采两个,然后要求它们的连线斜率要接近1,这样理论上也可以起到跟公式9一样的效果,我自己在MNIST+MLP上简单验证过有作用,PyTorch党甚至Tensorflow党都可以尝试用一下。
作者:郑华滨
链接:https://www.zhihu.com/question/52602529/answer/158727900
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
Wasserstein GAN最新进展:从weight clipping到gradient penalty,更加先进的Lipschitz限制手法的更多相关文章
- W-GAN系 (Wasserstein GAN、 Improved WGAN)
学习总结于国立台湾大学 :李宏毅老师 WGAN前作:Towards Principled Methods for Training Generative Adversarial Networks W ...
- (转) Read-through: Wasserstein GAN
Sorta Insightful Reviews Projects Archive Research About In a world where everyone has opinions, on ...
- 关于Wasserstein GAN的一些笔记
这篇笔记基于上一篇<关于GAN的一些笔记>. 1 GAN的缺陷 由于 $P_G$ 和 $P_{data}$ 它们实际上是 high-dim space 中的 low-dim manifol ...
- paper 91:边缘检测近期最新进展的讨论
VALSE QQ群对边缘检测近期最新进展的讨论,内容整理如下: 1)推荐一篇deep learning的文章,该文章大幅度提高了edge detection的精度,在bsds上,将edge detec ...
- paper 90:人脸检测研究2015最新进展
搜集整理了2004~2015性能最好的人脸检测的部分资料,欢迎交流和补充相关资料. 1:人脸检测性能 1.1 人脸检测测评 目前有两个比较大的人脸测评网站: 1:Face Detection Data ...
- Generative Adversarial Nets[Wasserstein GAN]
本文来自<Wasserstein GAN>,时间线为2017年1月,本文可以算得上是GAN发展的一个里程碑文献了,其解决了以往GAN训练困难,结果不稳定等问题. 1 引言 本文主要思考的是 ...
- Graph 卷积神经网络:概述、样例及最新进展
http://www.52ml.net/20031.html [新智元导读]Graph Convolutional Network(GCN)是直接作用于图的卷积神经网络,GCN 允许对结构化数据进行端 ...
- Wasserstein GAN
在GAN的相关研究如火如荼甚至可以说是泛滥的今天,一篇新鲜出炉的arXiv论文<Wasserstein GAN>却在Reddit的Machine Learning频道火了,连Goodfel ...
- SQL on Hadoop系统的最新进展(1)
转自:http://blog.jobbole.com/47892/ 为什么非要把SQL放到Hadoop上? SQL易于使用.那为什么非得基于Hadoop呢?the robust and scalabl ...
随机推荐
- Python数据类型之数字
数字(数值) 整数 :123 (int型) 浮点数: 0.25(带小数点的数字即为浮点数,Float型) 布尔值:False,True(即0和1,bool型) 复数 (暂无资料,complex型) 整 ...
- Python文件读取和数据处理
一.python文件读取 1.基本操作 读取文件信息时要注意文件编码,文件编码有UFT-8.ASCII或UTF-16等. 不过在python中最为常用的是UTF-8,所以如果不特别说明就默认UTF-8 ...
- dubbo学习笔记:快速搭建
搭建一个简单的dubbo服务 参考地址: dubbo官网:http://dubbo.apache.org/zh-cn/docs/user/references/registry/zookeeper.h ...
- iddler抓包过程以及fiddler抓包手机添加代理后连不上网解决办法
转载自:https://blog.csdn.net/m0_37554415/article/details/80434477,感谢博主的热心分享 1.(1)电脑端打开安装好的的fiddler,打开To ...
- html个人简历
https://gitee.com/aijiawei3344/codes/g8piyjc3kb7nav4whqd2r79 <!DOCTYPE html> <html> < ...
- (转)mysql主从切换步骤
原文:http://6226001001.blog.51cto.com/9243584/1723273 1> 正常切换 1)从服务器检查SHOW PROCESSLIST语句的输出,直到你看到Ha ...
- springmvc执行原理及自定义mvc框架
springmvc是spring的一部分,也是一个优秀的mvc框架,其执行原理如下: (1)浏览器提交请求经web容器(比如tomcat)转发到中央调度器dispatcherServlet. (2)中 ...
- 如何为 Go 设计一个通用的日志包
需求 一个通用的日志包,应该满足以下几个需求: 兼容 log.Logger,标准库大量使用了 log.Logger 作为其错误内容的输出通道,比如 net/http.Server.ErrorLog,所 ...
- package-info类解读
类不能带有public.private访问权限.package-info.java再怎么特殊,也是一个类文件,也会被编译成package-info.class,但是在package-info.java ...
- ubuntu工具安装
smplayer sudo add-apt-repository ppa:rvm/smplayer sudo apt-get update sudo apt-get install smplayer ...