Song Y. and Ermon S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems (NIPS), 2019.

当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.

主要内容

Langevin dynamics

对于分布\(p(x)\), 我们可以通过下列方式迭代生成

\[\tilde{x}_t = \tilde{x}_{t-1} + \frac{\epsilon}{2} \nabla_x \log p (\tilde{x}_{t-1}) + \sqrt{\epsilon} z_t,
\]

其中\(\tilde{x}_0 \sim \pi(x)\)来自一个先验分布, \(z_t \sim \mathcal{N}(0, I)\). 当步长\(\epsilon \rightarrow 0\)并且\(T \rightarrow +\infty\)的时候, \(\tilde{x}_T\)可以认为是从\(p(x)\)中采样的样本.

注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.

Score Matching

通过上述的迭代可以发现, 我们只需要获得\(\nabla_x \log p(x)\)即可采样\(x\), 我们可以期望通过下面的方式, 通过一个网络\(s_{\theta}(x)\)来逼近\(\nabla_x \log p_{data}(x)\):

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{p_{data}(x)} [\| s_{\theta} (x) - \nabla_x \log p_{data}(x) \|_2^2],
\]

但是在实际中, 先验\(\log p_{data}(x)\)也是未知的, 幸运的是上述公式等价于:

\[\min_{\theta} \: \mathbb{E}_{p_{data}(x)} [\mathrm{tr}(\nabla_x s_{\theta} (x)) + \frac{1}{2} \|s_{\theta}(x)\|_2^2].
\]

注: 见 score matching

Denoising Score Matching

一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{q_{\sigma}(\tilde{x}|x)p_{data}(x)} [\| s_{\theta} (\tilde{x}) - \nabla_x \log q_{\sigma}(\tilde{x}|x) \|_2^2].
\]

当优化得足够好的时候,

\[s_{\theta^*}(x) = \nabla_x \log q_{\sigma}(x), \: q_{\sigma}(\tilde{x}) := \int q_{\sigma}(\tilde{x}|x) p_{data}(x) \mathrm{d}x.
\]

实际中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相当于在真实数据\(x\)上加了一个扰动, 当扰动足够小(\(\sigma\)足够小)的时候, \(q_{\sigma}(x) \approx p_{data}(x)\), 则\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).

注: 为啥期望部分要有\(p_{data}\)? 实际上上述目标和score matching依旧是等价的.

Noise Conditional Score Networks

Slow mixing of Langevin dynamics

假设\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撑集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么为\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 与\(\pi\)没有丝毫关联, 这会导致训练的结果与\(\pi\)也没有关联. 在实际中, 若\(p_1, p_2\)近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.

作者的想法是, 设计一个noise conditional score networks:

\[s_\theta(x, \sigma),
\]

给定不同的\(\sigma\)其拟合不同扰动大小的\(p_{\sigma}\), 在采样中, 首先用大一点的\(\sigma\), 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的\(\sigma\)能够为后面的采样提供更好更鲁棒的初始点.

损失函数

设定\(\sigma_i, i=1,2,\cdots, L\), 且满足:

\[\frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1,
\]

即一个等比例(缩小)的数列.

对于每个\(\sigma\)采用如下损失:

\[\ell(\theta; \sigma) =
\frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{\mathcal{N}(\tilde{x}|x, \sigma I)} [\| s_{\theta} (\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \|_2^2].
\]

注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).

于是总损失为

\[\mathcal{L}(\theta; \{\sigma_i\}_{i=1}^L) := \frac{1}{L}\sum_{i=1}^L \lambda (\sigma_i)\ell(\theta;\sigma_i),
\]

\(\lambda(\sigma_i)\)为权重系数.

Annealed Langevin dynamics

Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);

  1. 初始化\(x_0\);
  2. For \(i=1,2,\cdots, L\) do:
    • \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
    • For \(t=1,2,\cdots, T\) do:
      • 采样\(z_t \sim \mathcal{N}(0, I)\);
      • \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
    • \(x_0 \leftarrow x_T\);

Output: \(x_T\).

细节

  1. 关于参数\(\lambda(\sigma)\)的选择:

    作者推荐选择\(\lambda(\sigma) = \sigma^2\), 因为当优化到最优的时候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)与\(\sigma\)无关.

  2. 关于\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):

对于一次Langevin dynamic, 其获得的信息为: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪声为\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)

\[\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z},
\]

当我们按照算法中的取法时, 我们有

\[\begin{array}{ll}
\|\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z}\|_2^2
&\approx\frac{\alpha_i \| s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto\frac{\|\sigma_i s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto \frac{1}{4}.
\end{array}
\]

故采用此策略能够保证SNR保持一个稳定的值.

代码

原文代码

Generative Modeling by Estimating Gradients of the Data Distribution的更多相关文章

  1. 泡泡一分钟:GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping

    张宁  GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping GEN-SLAM  - 单 ...

  2. Statistics : Data Distribution

    1.Normal distribution In probability theory, the normal (or Gaussian or Gauss or Laplace–Gauss) dist ...

  3. 论文笔记之:Generative Adversarial Nets

    Generative Adversarial Nets NIPS 2014  摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...

  4. [GAN] Generative networks

    中文版:https://zhuanlan.zhihu.com/p/27440393 原文版:https://www.oreilly.com/learning/generative-adversaria ...

  5. Generative model 和Discriminative model

    学习音乐自动标注过程中设计了有关分类型模型和生成型模型的东西,特地查了相关资料,在这里汇总. http://blog.sina.com.cn/s/blog_a18c98e50101058u.html ...

  6. 生成模型(Generative)和判别模型(Discriminative)

    生成模型(Generative)和判别模型(Discriminative) 引言    最近看文章<A survey of appearance models in visual object ...

  7. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  8. (转)Introductory guide to Generative Adversarial Networks (GANs) and their promise!

    Introductory guide to Generative Adversarial Networks (GANs) and their promise! Introduction Neural ...

  9. SalGAN: Visual saliency prediction with generative adversarial networks

    SalGAN: Visual saliency prediction with generative adversarial networks 2017-03-17 摘要:本文引入了对抗网络的对抗训练 ...

随机推荐

  1. Learning Spark中文版--第三章--RDD编程(2)

    Common Transformations and Actions   本章中,我们浏览了Spark中大多数常见的transformation(转换)和action(开工).在包含特定数据类型的RD ...

  2. Spring Cloud中使用Eureka

    一.创建00-eurekaserver-8000 (1)创建工程 创建一个Spring Initializr工程,命名为00-eurekaserver-8000,仅导入Eureka Server依赖即 ...

  3. @PropertySource配置的用法

    功能 加载指定的属性文件(*.properties)到 Spring 的 Environment 中.可以配合 @Value 和@ConfigurationProperties 使用. @Proper ...

  4. maven高级学习

    上一篇<maven是什么>介绍了最初级的maven学习,今天就趁着周末的大好时光一起学习下maven的高级知识吧. 1.maven工程要导入jar包的坐标,就必须要考虑解决jar冲突 1) ...

  5. 网络通信引擎ICE的使用

    ICE是一种网络通信引擎,在javaWeb的开发中可以用于解决局域网内部服务器端与客户端之间的网络通信问题.即可以在 1.在服务器和客户端都安装好ICE 2.服务器端(java)在java项目中引入I ...

  6. shell脚本 监控网卡信息

    一.简介 源码地址 日期:2018/6/22 介绍:显示实时输入输出流量 效果图: 二.使用 适用:centos6+ 语言:英文 注意:无 下载 wget https://raw.githubuser ...

  7. [BUUCTF]PWN——xdctf2015_pwn200

    xdctf2015_pwn200 附件 步骤 例行检查,32位程序,开启了nx保护 本地试运行一下程序,看看大概的情况 32位ida载入,习惯性的检索程序里的字符串,没有发现什么铭感的地方,直接看ma ...

  8. Linux驱动实践:一起来梳理中断的前世今生(附代码)

    作 者:道哥,10+年嵌入式开发老兵,专注于:C/C++.嵌入式.Linux. 关注下方公众号,回复[书籍],获取 Linux.嵌入式领域经典书籍:回复[PDF],获取所有原创文章( PDF 格式). ...

  9. 日历的种类(Project)

    <Project2016 企业项目管理实践>张会斌 董方好 编著   日历有三种:标准日历.24小时日历和夜班日历. 设置的位置在[项目]>[属性]>[更改工作时间]>[ ...

  10. EhCache简单入门

    一 介绍 EhCache 是一个纯Java的进程内缓存框架,具有快速.精干等特点,是Hibernate中默认CacheProvider.Ehcache是一种广泛使用的开源Java分布式缓存.主要面向通 ...