Song Y. and Ermon S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems (NIPS), 2019.

当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.

主要内容

Langevin dynamics

对于分布\(p(x)\), 我们可以通过下列方式迭代生成

\[\tilde{x}_t = \tilde{x}_{t-1} + \frac{\epsilon}{2} \nabla_x \log p (\tilde{x}_{t-1}) + \sqrt{\epsilon} z_t,
\]

其中\(\tilde{x}_0 \sim \pi(x)\)来自一个先验分布, \(z_t \sim \mathcal{N}(0, I)\). 当步长\(\epsilon \rightarrow 0\)并且\(T \rightarrow +\infty\)的时候, \(\tilde{x}_T\)可以认为是从\(p(x)\)中采样的样本.

注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.

Score Matching

通过上述的迭代可以发现, 我们只需要获得\(\nabla_x \log p(x)\)即可采样\(x\), 我们可以期望通过下面的方式, 通过一个网络\(s_{\theta}(x)\)来逼近\(\nabla_x \log p_{data}(x)\):

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{p_{data}(x)} [\| s_{\theta} (x) - \nabla_x \log p_{data}(x) \|_2^2],
\]

但是在实际中, 先验\(\log p_{data}(x)\)也是未知的, 幸运的是上述公式等价于:

\[\min_{\theta} \: \mathbb{E}_{p_{data}(x)} [\mathrm{tr}(\nabla_x s_{\theta} (x)) + \frac{1}{2} \|s_{\theta}(x)\|_2^2].
\]

注: 见 score matching

Denoising Score Matching

一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{q_{\sigma}(\tilde{x}|x)p_{data}(x)} [\| s_{\theta} (\tilde{x}) - \nabla_x \log q_{\sigma}(\tilde{x}|x) \|_2^2].
\]

当优化得足够好的时候,

\[s_{\theta^*}(x) = \nabla_x \log q_{\sigma}(x), \: q_{\sigma}(\tilde{x}) := \int q_{\sigma}(\tilde{x}|x) p_{data}(x) \mathrm{d}x.
\]

实际中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相当于在真实数据\(x\)上加了一个扰动, 当扰动足够小(\(\sigma\)足够小)的时候, \(q_{\sigma}(x) \approx p_{data}(x)\), 则\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).

注: 为啥期望部分要有\(p_{data}\)? 实际上上述目标和score matching依旧是等价的.

Noise Conditional Score Networks

Slow mixing of Langevin dynamics

假设\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撑集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么为\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 与\(\pi\)没有丝毫关联, 这会导致训练的结果与\(\pi\)也没有关联. 在实际中, 若\(p_1, p_2\)近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.

作者的想法是, 设计一个noise conditional score networks:

\[s_\theta(x, \sigma),
\]

给定不同的\(\sigma\)其拟合不同扰动大小的\(p_{\sigma}\), 在采样中, 首先用大一点的\(\sigma\), 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的\(\sigma\)能够为后面的采样提供更好更鲁棒的初始点.

损失函数

设定\(\sigma_i, i=1,2,\cdots, L\), 且满足:

\[\frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1,
\]

即一个等比例(缩小)的数列.

对于每个\(\sigma\)采用如下损失:

\[\ell(\theta; \sigma) =
\frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{\mathcal{N}(\tilde{x}|x, \sigma I)} [\| s_{\theta} (\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \|_2^2].
\]

注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).

于是总损失为

\[\mathcal{L}(\theta; \{\sigma_i\}_{i=1}^L) := \frac{1}{L}\sum_{i=1}^L \lambda (\sigma_i)\ell(\theta;\sigma_i),
\]

\(\lambda(\sigma_i)\)为权重系数.

Annealed Langevin dynamics

Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);

  1. 初始化\(x_0\);
  2. For \(i=1,2,\cdots, L\) do:
    • \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
    • For \(t=1,2,\cdots, T\) do:
      • 采样\(z_t \sim \mathcal{N}(0, I)\);
      • \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
    • \(x_0 \leftarrow x_T\);

Output: \(x_T\).

细节

  1. 关于参数\(\lambda(\sigma)\)的选择:

    作者推荐选择\(\lambda(\sigma) = \sigma^2\), 因为当优化到最优的时候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)与\(\sigma\)无关.

  2. 关于\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):

对于一次Langevin dynamic, 其获得的信息为: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪声为\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)

\[\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z},
\]

当我们按照算法中的取法时, 我们有

\[\begin{array}{ll}
\|\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z}\|_2^2
&\approx\frac{\alpha_i \| s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto\frac{\|\sigma_i s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto \frac{1}{4}.
\end{array}
\]

故采用此策略能够保证SNR保持一个稳定的值.

代码

原文代码

Generative Modeling by Estimating Gradients of the Data Distribution的更多相关文章

  1. 泡泡一分钟:GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping

    张宁  GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping GEN-SLAM  - 单 ...

  2. Statistics : Data Distribution

    1.Normal distribution In probability theory, the normal (or Gaussian or Gauss or Laplace–Gauss) dist ...

  3. 论文笔记之:Generative Adversarial Nets

    Generative Adversarial Nets NIPS 2014  摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...

  4. [GAN] Generative networks

    中文版:https://zhuanlan.zhihu.com/p/27440393 原文版:https://www.oreilly.com/learning/generative-adversaria ...

  5. Generative model 和Discriminative model

    学习音乐自动标注过程中设计了有关分类型模型和生成型模型的东西,特地查了相关资料,在这里汇总. http://blog.sina.com.cn/s/blog_a18c98e50101058u.html ...

  6. 生成模型(Generative)和判别模型(Discriminative)

    生成模型(Generative)和判别模型(Discriminative) 引言    最近看文章<A survey of appearance models in visual object ...

  7. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  8. (转)Introductory guide to Generative Adversarial Networks (GANs) and their promise!

    Introductory guide to Generative Adversarial Networks (GANs) and their promise! Introduction Neural ...

  9. SalGAN: Visual saliency prediction with generative adversarial networks

    SalGAN: Visual saliency prediction with generative adversarial networks 2017-03-17 摘要:本文引入了对抗网络的对抗训练 ...

随机推荐

  1. A Child's History of England.41

    When intelligence of this new affront [hit in the face, c-o-n-frontation!] was carried to the King i ...

  2. Qt最好用评价最高的是哪个版本?

    来源: http://www.qtcn.org/bbs/read-htm-tid-89455.html /// Qt4:    4.8.7      4.X 系列终结版本 Qt5 :   5.6 LT ...

  3. Netty之ByteBuf

    本文内容主要参考<<Netty In Action>>,偏笔记向. 网络编程中,字节缓冲区是一个比较基本的组件.Java NIO提供了ByteBuffer,但是使用过的都知道B ...

  4. C++ 数组元素循环右移问题

    这道题要求不用另外的数组,并且尽量移动次数少. 算法思想:设计一个结构体存储数组数据和它应在的索引位置,再直接交换,但是这种方法不能一次性就移动完成,因此再加一个判断条件.等这个判断条件满足后就退出循 ...

  5. 解决springboot序列化 json数据到前端中文乱码问题

    前言 关于springboot乱码的问题,之前有文章已经介绍过了,这一篇算是作为补充,重点解决对象在序列化过程中出现的中文乱码的问题,以及后台报500的错误. 问题描述 spring Boot 中文返 ...

  6. Output of C++ Program | Set 11

    Predict the output of following C++ programs. Question 1 1 #include<iostream> 2 using namespac ...

  7. 解决CSV文件用Excel打开乱码问题

    这篇文章适合有一定编码基础的人看,纯手动解决乱码问题请参见: 转码保存后,重新打开即可. 转码操作如下: 编辑器->另存为->ASCII码格式文件/UTF-8含BOM格式->保存. ...

  8. Linux centos7 安装.net 环境

    其实在linux 下安装.net 环境并不复杂,但最近遇到的服务器没有外网,比较坑很多依赖都没有,记录下这次的安装过程. 一开始以为是服务器没有外网,后来发现是服务器没有配置dns,于是配置dns 第 ...

  9. shell脚本 系统状态信息查看

    一.简介 源码地址 日期:2018/6/23 介绍:显示简单的系统信息 效果图: 二.使用 适用:centos6+,ubuntu12+ 语言:中文 注意:无 下载 wget https://raw.g ...

  10. CentOS6设置开机自启动

    1.把开机启动脚本(mysqld)copy到文件夹/etc/init.d 或 /etc/rc.d/init.d 中 2.将启动程序的命令添加到 /etc/rc.d/rc.local 文件中,比如: # ...