Song Y. and Ermon S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems (NIPS), 2019.

当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.

主要内容

Langevin dynamics

对于分布\(p(x)\), 我们可以通过下列方式迭代生成

\[\tilde{x}_t = \tilde{x}_{t-1} + \frac{\epsilon}{2} \nabla_x \log p (\tilde{x}_{t-1}) + \sqrt{\epsilon} z_t,
\]

其中\(\tilde{x}_0 \sim \pi(x)\)来自一个先验分布, \(z_t \sim \mathcal{N}(0, I)\). 当步长\(\epsilon \rightarrow 0\)并且\(T \rightarrow +\infty\)的时候, \(\tilde{x}_T\)可以认为是从\(p(x)\)中采样的样本.

注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.

Score Matching

通过上述的迭代可以发现, 我们只需要获得\(\nabla_x \log p(x)\)即可采样\(x\), 我们可以期望通过下面的方式, 通过一个网络\(s_{\theta}(x)\)来逼近\(\nabla_x \log p_{data}(x)\):

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{p_{data}(x)} [\| s_{\theta} (x) - \nabla_x \log p_{data}(x) \|_2^2],
\]

但是在实际中, 先验\(\log p_{data}(x)\)也是未知的, 幸运的是上述公式等价于:

\[\min_{\theta} \: \mathbb{E}_{p_{data}(x)} [\mathrm{tr}(\nabla_x s_{\theta} (x)) + \frac{1}{2} \|s_{\theta}(x)\|_2^2].
\]

注: 见 score matching

Denoising Score Matching

一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{q_{\sigma}(\tilde{x}|x)p_{data}(x)} [\| s_{\theta} (\tilde{x}) - \nabla_x \log q_{\sigma}(\tilde{x}|x) \|_2^2].
\]

当优化得足够好的时候,

\[s_{\theta^*}(x) = \nabla_x \log q_{\sigma}(x), \: q_{\sigma}(\tilde{x}) := \int q_{\sigma}(\tilde{x}|x) p_{data}(x) \mathrm{d}x.
\]

实际中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相当于在真实数据\(x\)上加了一个扰动, 当扰动足够小(\(\sigma\)足够小)的时候, \(q_{\sigma}(x) \approx p_{data}(x)\), 则\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).

注: 为啥期望部分要有\(p_{data}\)? 实际上上述目标和score matching依旧是等价的.

Noise Conditional Score Networks

Slow mixing of Langevin dynamics

假设\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撑集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么为\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 与\(\pi\)没有丝毫关联, 这会导致训练的结果与\(\pi\)也没有关联. 在实际中, 若\(p_1, p_2\)近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.

作者的想法是, 设计一个noise conditional score networks:

\[s_\theta(x, \sigma),
\]

给定不同的\(\sigma\)其拟合不同扰动大小的\(p_{\sigma}\), 在采样中, 首先用大一点的\(\sigma\), 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的\(\sigma\)能够为后面的采样提供更好更鲁棒的初始点.

损失函数

设定\(\sigma_i, i=1,2,\cdots, L\), 且满足:

\[\frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1,
\]

即一个等比例(缩小)的数列.

对于每个\(\sigma\)采用如下损失:

\[\ell(\theta; \sigma) =
\frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{\mathcal{N}(\tilde{x}|x, \sigma I)} [\| s_{\theta} (\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \|_2^2].
\]

注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).

于是总损失为

\[\mathcal{L}(\theta; \{\sigma_i\}_{i=1}^L) := \frac{1}{L}\sum_{i=1}^L \lambda (\sigma_i)\ell(\theta;\sigma_i),
\]

\(\lambda(\sigma_i)\)为权重系数.

Annealed Langevin dynamics

Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);

  1. 初始化\(x_0\);
  2. For \(i=1,2,\cdots, L\) do:
    • \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
    • For \(t=1,2,\cdots, T\) do:
      • 采样\(z_t \sim \mathcal{N}(0, I)\);
      • \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
    • \(x_0 \leftarrow x_T\);

Output: \(x_T\).

细节

  1. 关于参数\(\lambda(\sigma)\)的选择:

    作者推荐选择\(\lambda(\sigma) = \sigma^2\), 因为当优化到最优的时候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)与\(\sigma\)无关.

  2. 关于\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):

对于一次Langevin dynamic, 其获得的信息为: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪声为\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)

\[\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z},
\]

当我们按照算法中的取法时, 我们有

\[\begin{array}{ll}
\|\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z}\|_2^2
&\approx\frac{\alpha_i \| s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto\frac{\|\sigma_i s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto \frac{1}{4}.
\end{array}
\]

故采用此策略能够保证SNR保持一个稳定的值.

代码

原文代码

Generative Modeling by Estimating Gradients of the Data Distribution的更多相关文章

  1. 泡泡一分钟:GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping

    张宁  GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping GEN-SLAM  - 单 ...

  2. Statistics : Data Distribution

    1.Normal distribution In probability theory, the normal (or Gaussian or Gauss or Laplace–Gauss) dist ...

  3. 论文笔记之:Generative Adversarial Nets

    Generative Adversarial Nets NIPS 2014  摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...

  4. [GAN] Generative networks

    中文版:https://zhuanlan.zhihu.com/p/27440393 原文版:https://www.oreilly.com/learning/generative-adversaria ...

  5. Generative model 和Discriminative model

    学习音乐自动标注过程中设计了有关分类型模型和生成型模型的东西,特地查了相关资料,在这里汇总. http://blog.sina.com.cn/s/blog_a18c98e50101058u.html ...

  6. 生成模型(Generative)和判别模型(Discriminative)

    生成模型(Generative)和判别模型(Discriminative) 引言    最近看文章<A survey of appearance models in visual object ...

  7. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  8. (转)Introductory guide to Generative Adversarial Networks (GANs) and their promise!

    Introductory guide to Generative Adversarial Networks (GANs) and their promise! Introduction Neural ...

  9. SalGAN: Visual saliency prediction with generative adversarial networks

    SalGAN: Visual saliency prediction with generative adversarial networks 2017-03-17 摘要:本文引入了对抗网络的对抗训练 ...

随机推荐

  1. 关于stm32不常用的中断,如何添加, 比如timer10 timer11等

    首先可以从keil中找到 比如找到定时器11的溢出中断,如上图是26 然后,配置定时器11 溢出中断的时候,我就在:下面填上这个变量. 之后要写中断服务函数,也就是发生中断后要跳转到的函数. 需要知道 ...

  2. Oracle中的索引

    1.Oracle 索引简介      在Oracle数据库中,存储的每一行数据都有一个rowID来标识.当Oracle中存储着大量的数据时,意味着有大量的rowID,此时想要快速定位指定的rowID, ...

  3. 【Linux】【Shell】【Basic】变量与数据类型

    1. 变量: 1.1. 局部变量:作用域是函数的生命周期:在函数结束时被自动销毁: 定义局部变量的方法:local VARIABLE=VALUE 1.2. 本地变量:作用域是运行脚本的shell进程的 ...

  4. ebs 初始化登陆

    BEGIN fnd_global.APPS_INITIALIZE(user_id => youruesr_id, esp_id => yourresp_id, resp_appl_id = ...

  5. archive后upload to app store时遇到app id不可用的问题

    问题如下图 出现此问题的原因有两种: 1.此app id在AppStore中已经存在,也就是说你使用别人注册的app ID ,  如果是这样,你只能更换app ID 2.此app ID是自己的,突然之 ...

  6. 重新整理 .net core 实践篇——— UseEndpoints中间件[四十八]

    前言 前文已经提及到了endponint 是怎么匹配到的,也就是说在UseRouting 之后的中间件都能获取到endpoint了,如果能够匹配到的话,那么UseEndpoints又做了什么呢?它是如 ...

  7. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

  8. Windows 10 彻底关闭 Antimalware Service Executable 降低内存占用

    概述 最近给内网的一台电脑安装 Windows 10 专业版系统,由于此电脑不会涉及到不安全因素,所以杀毒软件非必须. 以最大限度节省系统资源考虑,默认安装的 Micoroft Defender 占用 ...

  9. 工作簿拆分(Excel代码集团)

    一个工作簿中包括N个工作表,将各个工作表拆分成工作簿. 工作表数量不定,表内内容不限,拆分后保存于当前文件夹内. Sub Sample() Dim MySheetsCount As Long For ...

  10. 建立资源的方法(Project)

    <Project2016 企业项目管理实践>张会斌 董方好 编著 终于,进入第5章资源计划编制了,所以就不能还在任务工作表里厮混了是吧,那就先进入资源工作表吧:[任务]>[甘特图]& ...