Part2: DDPM as Example of Variational Inference
很多次翻看DDPM,始终不太能理解论文中提到的\(\text{Variational Inference}\)到底是如何在这个工作中起到作用。五一假期在家,无意间又刷到徐亦达老师早些年录制的理论视频,没想到其中也有介绍这部分的内容。老师的上课方式总是娓娓道来,把每一步都讲解得很仔细。本文记录一下个人对开头问题的思考。
Background
如果需要简略地介绍一下DDPM这个工作,可能会用以下几句话简单地描述:DDPM以Markov的形式对数据(图片)“扩散过程”建模,使用神经网络进行训练拟合,学习数据的概率分布。
所以对于生成任务来说,希望从给定数据中学习到的是数据的潜在信息。比如图片生成,在给定一些图片后,模型学习到的是“正常图片长什么样子”,如:
- 一张包含手机正面的图片会有【手机屏幕】;
- 一张包含猫咪的图片会有人们观察到的猫咪模样;
- ...
对于图片中每个像素点和附近的像素点,进行“合理”布局,才能生成“符合人们认知的图片”。
图片生成能像常见的机器学习任务如分类任务、回归任务,能基于maximize likelihood的形式来训练么?
结论是很难,先回顾如何做maximum likelihood。给定一批数据,首先需要假定数据服从的分布,接着写出似然函数,之后直接通过解析解的形式或是梯度下降的形式,求出分布。
问题就出在假定分布这一步,没有人知道图片客观上服从什么分布。那如果使用神经网络直接拟合可以么?这好像也不现实,拿一张512*512*3的图片来说,网络输出层共有约75w的数值。
对于图片生成还有另外一个问题,世界上的图片太多了,目之所及稍做处理,皆为图片。即便使用神经网络能拟合,最后生成的图片很难存在多样性。
那目前图片生成模型都是怎么做的,比如VAE或是本文即将要介绍的Diffusion Model,它们学习的都是数据分布\(p(x)\),但直接求\(p(x)\)这么麻烦,需要怎么做?这其实也是\(\text{Variational Inference}\)的核心思想,“曲线救国”,通过引入其它分布,将原本难以优化的问题转变为可优化问题。
ELOB
先把上述提到的所有背景先抛开,研究一下\(p(x)\),看看能得到什么有意思的结论。
a. 基于条件概率分布,引入新的随机变量\(z\):\(p(x) = \frac{p(x, z)}{p(z\mid x)}\);
b. 对于两边同时取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\);
c. 右边分子分母同乘以\(q(z)\):\(\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}\)
d. 再次,对于上式左右两边求关于\(q(z)\)的期望,等式依然成立:
&\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\tag{1}
\]
一系列变换后,\((1)\)式是最后的推导结果,等式右边由两个项组成。第二个项\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用来衡量两个分布之间的“距离”,性质是值不小于0。
这样一来,通过\((1)\)可以得到不等式\((2)\):
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation*}
\tag{2}
\]
\((1)\)式右边的第一项,同时也是\((2)\)式的右边项,被学者们叫做\(\text{ELBO(Evidence Lower Bound)}\)。
Objective Function
上述推导的\((2)\)式可以被视作“定理”一般的存在,即对于某个分布的对数形式,总可以找到它的下界。
那\((2)\)式可以用来做什么?在Background中提到,图片生成任务中的\(p(x)\)想要对它做maximum likelihood根本无法做起。目标依然是最大化\(p(x)\),但有了\((2)\)式,求解的目标可以转移到最大化它的下界\(\text{ELBO}\)。
这也是论文中提到的:
This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.
接下来,回到论文中,看看是如何一步步推导出DDPM的优化目标。\((3)\)式直接摘录于论文:
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\end{equation*}
\tag{2}
\]
\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right]=: L
\end{equation*}
\tag{3}
\]
下面一项项地对\((3)\) 进行拆解,并且将它与\((2)\)比对,能帮助更好地理解:
\((3)\)不等号左边的\(\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right]\)进一步化简就是\(-\log p_\theta\left(\mathbf{x}_0\right)\)。其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要学习的最终目标:图像的分布,\(\theta\)是模型的参数,\(\mathbf{x}_0\)是图片;
\((2)\)式的左右两边同时加上符号,\(\geq\)变为\(\leq\);
看\((3)\)不等式右边部分,\(\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]\)
很明显,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)相当于\((2)\)中引入的额外分布\(q(z)\)。对于\(z\),在生成模型中会给它一个称呼:隐变量\((\text{latent})\)。实际上,在
diffusion models里,对\(\mathbf{x}_0\)加噪后的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隐变量,那不妨记作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\);\(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是关于\(\mathbf{x}_0, z\)的联合概率分布,因为选用马尔代夫链建模,那么依据马尔可夫链的性质,论文定义:
\begin{aligned}
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\
p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation*}
\tag{4}
\]
- 将\((4)\)带入\((3)\)不等式右边的第一项,得到\(L\):
\begin{aligned}
&\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\
=&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\
=&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L
\end{aligned}
\end{equation*}
\]
到目前为止,经过了很多轮的变换以及数学公式,先捋一遍,再往下。\(L\)是一个替代的优化目标,
\]
接下来,论文中对\(L\)进行了重写,以下步骤直接摘录自论文\(\text{Appendix A}\)
\begin{aligned}
L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\
&=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation*}
\tag{5}
\]
倒数两步的变换发生在第二项,具体依据为:
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}
\end{aligned}
\quad \Rightarrow \quad
\begin{aligned}
&\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
\]
接着对\((5)\)进行改写得到最终形式\((6)\):
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
&=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]
\end{aligned}
\tag{6}
\]
Summary
太好了,对于\((6)\)来说,它最起码是个可以优化的目标函数了,因为论文中定义马尔可夫链相邻状态的转变是服从高斯分布的。当然在论文中,\((6)\)还会进一步被改写,得到更加精简的\(\text{loss function}\)形式。
DDPM是应用\(\text{variational inference}\)进行优化求解的典型例子,很值得借鉴学习。
Reference
Part2: DDPM as Example of Variational Inference的更多相关文章
- [Bayesian] “我是bayesian我怕谁”系列 - Variational Inference
涉及的领域可能有些生僻,骗不了大家点赞.但毕竟是人工智能的主流技术,在园子却成了非主流. 不可否认的是:乃值钱的技术,提高身价的技术,改变世界观的技术. 关于变分,通常的课本思路是: GMM --&g ...
- [Bayes] Variational Inference for Bayesian GMMs
为了世界和平,为了心知肚明,决定手算一次 Variational Inference for Bayesian GMMs 目的就是达到如下的智能效果,扔进去六个高斯,最后拟合结果成了两个高斯,当然,其 ...
- 变分推断(Variational Inference)
(学习这部分内容大约需要花费1.1小时) 摘要 在我们感兴趣的大多数概率模型中, 计算后验边际或准确计算归一化常数都是很困难的. 变分推断(variational inference)是一个近似计算这 ...
- Improved Variational Inference with Inverse Autoregressive Flow
目录 概 主要内容 代码 Kingma D., Salimans T., Jozefowicz R., Chen X., Sutskever I. and Welling M. Improved Va ...
- Variational Inference with Normalizing Flow
目录 概 主要内容 一些合适的可逆变换 代码 Rezende D., Mohamed S. Variational Inference with Normalizing Flow. ICML, 201 ...
- Variational Inference
作者:孙九爷链接:https://www.zhihu.com/question/41765860/answer/101915528来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注 ...
- 变分推断(Variational Inference)
变分 对于普通的函数f(x),我们可以认为f是一个关于x的一个实数算子,其作用是将实数x映射到实数f(x).那么类比这种模式,假设存在函数算子F,它是关于f(x)的函数算子,可以将f(x)映射成实数F ...
- 一文详解扩散模型:DDPM
作者:京东零售 刘岩 扩散模型讲解 前沿 人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Gene ...
- PRML读书会第十章 Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )
主讲人 戴玮 (新浪微博: @戴玮_CASIA) Wilbur_中博(1954123) 20:02:04 我们在前面看到,概率推断的核心任务就是计算某分布下的某个函数的期望.或者计算边缘概率分布.条件 ...
- Variational Bayes
一.前言 变分贝叶斯方法最早由Matthew J.Beal在他的博士论文<Variational Algorithms for Approximate Bayesian Inference> ...
随机推荐
- ISCTF 2022
Re SigninReverse ida 64 位 打开程序,即可获得flag ISCTF{27413241-9eab-41e2-aca1-88fe8b525956} ezbase # coding= ...
- 用户地址管理---新增、设置默认地址、根据id查询所有地址、查询默认地址、查询指定用户的全部地址
导入用户地址簿相关功能代码 需求分析: 地址簿,指的是移动端消费者用户的地址信息,用户登录成功后可以维护自己的地址信息.同一个用户可以有多个地址信息,但是只能有一个默认地址. 用户的地址信息会存储在a ...
- Python——高级数据类型(七)
1. 列表数据类型的声明与访问 # coding=utf-8 #列表数据类型的声明与访问 my_list =[1,2,3,4,5] # 列表中的元素 print (my_list) # 0 1 2 3 ...
- CSS 基础属性篇组成及作用
#### 学习目标- css属性和属性值的定义- css文本属性- css列表属性- css背景属性- css边框属性- css浮动属性##### 一.css属性和属性值的定义>属性:属性是指定 ...
- 【Visual Leak Detector】QT 中 VLD 输出解析(四)
说明 使用 VLD 内存泄漏检测工具辅助开发时整理的学习笔记. 目录 说明 1. 使用方式 2. 测试代码 3. 使用 32 bit 编译器时的输出 4. 使用 64 bit 编译器时的输出 5. 输 ...
- 拒绝“爆雷”!GaussDB(for MySQL)新上线了这个功能
摘要:智能把控大数据量查询,防患系统奔溃于未然. 本文分享自华为云社区<拒绝"爆雷"!GaussDB(for MySQL)新上线了这个功能>,作者:GaussDB 数据 ...
- 基于OCR进行Bert独立语义纠错实践
摘要:本案例我们利用视频字幕识别中的文字检测与识别模型,增加预训练Bert进行纠错 本文分享自华为云社区<Bert特调OCR>,作者:杜甫盖房子. 做这个项目的初衷是发现图比较糊/检测框比 ...
- 教程 - 在 Vue3+Ts 中引入 CesiumJS 的最佳实践@2023
目录 1. 本篇适用范围与目的 1.1. 适用范围 1.2. 目的 2. 牛刀小试 - 先看到地球 2.1. 创建 Vue3 - TypeScript 工程并安装 cesium 2.2. 清理不必要的 ...
- Rust中的derive属性详解
1. Rust中的derive是什么? 在Rust语言中,derive是一个属性,它可以让编译器为一些特性提供基本的实现.这些特性仍然可以手动实现,以获得更复杂的行为. 2. derive的出现解决了 ...
- C#模拟C++模板特化对类型的值的支持
概述 C++的模板相比于C#,有很多地方都更加的灵活(虽然代价是降低了编译速度),比如C++支持变长参数模板.支持枚举.int等类型的值作为模板参数. C++支持枚举.int等类型的值作为模板参数,为 ...