Generative Modeling by Estimating Gradients of the Data Distribution
概
当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.
主要内容
Langevin dynamics
对于分布\(p(x)\), 我们可以通过下列方式迭代生成
\]
其中\(\tilde{x}_0 \sim \pi(x)\)来自一个先验分布, \(z_t \sim \mathcal{N}(0, I)\). 当步长\(\epsilon \rightarrow 0\)并且\(T \rightarrow +\infty\)的时候, \(\tilde{x}_T\)可以认为是从\(p(x)\)中采样的样本.
注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.
Score Matching
通过上述的迭代可以发现, 我们只需要获得\(\nabla_x \log p(x)\)即可采样\(x\), 我们可以期望通过下面的方式, 通过一个网络\(s_{\theta}(x)\)来逼近\(\nabla_x \log p_{data}(x)\):
\]
但是在实际中, 先验\(\log p_{data}(x)\)也是未知的, 幸运的是上述公式等价于:
\]
注: 见 score matching
Denoising Score Matching
一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:
\]
当优化得足够好的时候,
\]
实际中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相当于在真实数据\(x\)上加了一个扰动, 当扰动足够小(\(\sigma\)足够小)的时候, \(q_{\sigma}(x) \approx p_{data}(x)\), 则\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).
注: 为啥期望部分要有\(p_{data}\)? 实际上上述目标和score matching依旧是等价的.
Noise Conditional Score Networks
Slow mixing of Langevin dynamics
假设\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撑集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么为\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 与\(\pi\)没有丝毫关联, 这会导致训练的结果与\(\pi\)也没有关联. 在实际中, 若\(p_1, p_2\)近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.
作者的想法是, 设计一个noise conditional score networks:
\]
给定不同的\(\sigma\)其拟合不同扰动大小的\(p_{\sigma}\), 在采样中, 首先用大一点的\(\sigma\), 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的\(\sigma\)能够为后面的采样提供更好更鲁棒的初始点.
损失函数
设定\(\sigma_i, i=1,2,\cdots, L\), 且满足:
\]
即一个等比例(缩小)的数列.
对于每个\(\sigma\)采用如下损失:
\frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{\mathcal{N}(\tilde{x}|x, \sigma I)} [\| s_{\theta} (\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \|_2^2].
\]
注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).
于是总损失为
\]
\(\lambda(\sigma_i)\)为权重系数.
Annealed Langevin dynamics
Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);
- 初始化\(x_0\);
- For \(i=1,2,\cdots, L\) do:
- \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
- For \(t=1,2,\cdots, T\) do:
- 采样\(z_t \sim \mathcal{N}(0, I)\);
- \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
- \(x_0 \leftarrow x_T\);
Output: \(x_T\).
细节
关于参数\(\lambda(\sigma)\)的选择:
作者推荐选择\(\lambda(\sigma) = \sigma^2\), 因为当优化到最优的时候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)与\(\sigma\)无关.关于\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):
对于一次Langevin dynamic, 其获得的信息为: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪声为\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)
\]
当我们按照算法中的取法时, 我们有
\|\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z}\|_2^2
&\approx\frac{\alpha_i \| s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto\frac{\|\sigma_i s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto \frac{1}{4}.
\end{array}
\]
故采用此策略能够保证SNR保持一个稳定的值.
代码
Generative Modeling by Estimating Gradients of the Data Distribution的更多相关文章
- 泡泡一分钟:GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping
张宁 GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping GEN-SLAM - 单 ...
- Statistics : Data Distribution
1.Normal distribution In probability theory, the normal (or Gaussian or Gauss or Laplace–Gauss) dist ...
- 论文笔记之:Generative Adversarial Nets
Generative Adversarial Nets NIPS 2014 摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...
- [GAN] Generative networks
中文版:https://zhuanlan.zhihu.com/p/27440393 原文版:https://www.oreilly.com/learning/generative-adversaria ...
- Generative model 和Discriminative model
学习音乐自动标注过程中设计了有关分类型模型和生成型模型的东西,特地查了相关资料,在这里汇总. http://blog.sina.com.cn/s/blog_a18c98e50101058u.html ...
- 生成模型(Generative)和判别模型(Discriminative)
生成模型(Generative)和判别模型(Discriminative) 引言 最近看文章<A survey of appearance models in visual object ...
- (转)Deep Learning Research Review Week 1: Generative Adversarial Nets
Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...
- (转)Introductory guide to Generative Adversarial Networks (GANs) and their promise!
Introductory guide to Generative Adversarial Networks (GANs) and their promise! Introduction Neural ...
- SalGAN: Visual saliency prediction with generative adversarial networks
SalGAN: Visual saliency prediction with generative adversarial networks 2017-03-17 摘要:本文引入了对抗网络的对抗训练 ...
随机推荐
- python写的多项式符号乘法
D:\>poly.py(x - 1) * (x^2 + x + 1) = x^3 - 1 1 import ply.lex as lex # pip install ply 2 import p ...
- hadoop-uber作业模式
如果作业很小,就选择和自己在同一个JVM上运行任务,与在一个节点上顺序运行这些任务相比,当application master 判断在新的容器中的分配和运行任务的开销大于并行运行它们的开销时,就会发生 ...
- 爬虫系列:使用 MySQL 存储数据
上一篇文章我们讲解了爬虫如何存储 CSV 文件,这篇文章,我们讲解如何将采集到的数据保存到 MySQL 数据库中. MySQL 是目前最受欢迎的开源关系型数据库管理系统.一个开源项目具有如此之竞争力实 ...
- 容器之分类与各种测试(四)——unordered-multiset
unordered-multiset是不定序关联式容器,其底部是通过哈希表实现功能. (ps:黑色框就是bucket,白色框即为bucket上挂载的元素) 为了提高查找效率,bucket(篮子)的数量 ...
- AI常用环境安装
torch环境 conda create --name py37 python=3.7 conda activate py37 pip install jieba==0.42.1pip install ...
- ActiveRecord教程
(一.ActiveRecord基础) ActiveRecord是Rails提供的一个对象关系映射(ORM)层,从这篇开始,我们来了解Active Record的一些基础内容,连接数据库,映射表,访问数 ...
- Linux:expr、let、for、while、until、shift、if、case、break、continue、函数、select
1.expr计算整数变量值 格式 :expr arg 例子:计算(2+3)×4的值 1.分步计算,即先计算2+3,再对其和乘4 s=`expr 2 + 3` expr $s \* 4 2.一步完成计算 ...
- matplotlib画直线图的基本用法
一 figure使用 1 import numpy as np 2 import matplotlib.pyplot as plt 3 4 # 从-3到中取50个数 5 x = np.linspac ...
- 【力扣】剑指 Offer 25. 合并两个排序的链表
输入两个递增排序的链表,合并这两个链表并使新链表中的节点仍然是递增排序的. 示例1: 输入:1->2->4, 1->3->4输出:1->1->2->3-> ...
- 为什么Redis集群有16384个槽
一.前言 我在<那些年用过的Redis集群架构(含面试解析)>一文里提到过,现在redis集群架构,redis cluster用的会比较多. 如下图所示 对于客户端请求的key,根据公式H ...