很多次翻看DDPM,始终不太能理解论文中提到的\(\text{Variational Inference}\)到底是如何在这个工作中起到作用。五一假期在家,无意间又刷到徐亦达老师早些年录制的理论视频,没想到其中也有介绍这部分的内容。老师的上课方式总是娓娓道来,把每一步都讲解得很仔细。本文记录一下个人对开头问题的思考。

Background

如果需要简略地介绍一下DDPM这个工作,可能会用以下几句话简单地描述:DDPMMarkov的形式对数据(图片)“扩散过程”建模,使用神经网络进行训练拟合,学习数据的概率分布。

所以对于生成任务来说,希望从给定数据中学习到的是数据的潜在信息。比如图片生成,在给定一些图片后,模型学习到的是“正常图片长什么样子”,如:

  1. 一张包含手机正面的图片会有【手机屏幕】;
  2. 一张包含猫咪的图片会有人们观察到的猫咪模样;
  3. ...

对于图片中每个像素点和附近的像素点,进行“合理”布局,才能生成“符合人们认知的图片”。

图片生成能像常见的机器学习任务如分类任务、回归任务,能基于maximize likelihood的形式来训练么?

结论是很难,先回顾如何做maximum likelihood。给定一批数据,首先需要假定数据服从的分布,接着写出似然函数,之后直接通过解析解的形式或是梯度下降的形式,求出分布。

问题就出在假定分布这一步,没有人知道图片客观上服从什么分布。那如果使用神经网络直接拟合可以么?这好像也不现实,拿一张512*512*3的图片来说,网络输出层共有约75w的数值。

对于图片生成还有另外一个问题,世界上的图片太多了,目之所及稍做处理,皆为图片。即便使用神经网络能拟合,最后生成的图片很难存在多样性。

那目前图片生成模型都是怎么做的,比如VAE或是本文即将要介绍的Diffusion Model,它们学习的都是数据分布\(p(x)\),但直接求\(p(x)\)这么麻烦,需要怎么做?这其实也是\(\text{Variational Inference}\)的核心思想,“曲线救国”,通过引入其它分布,将原本难以优化的问题转变为可优化问题。

ELOB

先把上述提到的所有背景先抛开,研究一下\(p(x)\),看看能得到什么有意思的结论。

a. 基于条件概率分布,引入新的随机变量\(z\):\(p(x) = \frac{p(x, z)}{p(z\mid x)}\);

b. 对于两边同时取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\);

c. 右边分子分母同乘以\(q(z)\):\(\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}\)

d. 再次,对于上式左右两边求关于\(q(z)\)的期望,等式依然成立:

\[\begin{aligned}
&\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\tag{1}
\]

一系列变换后,\((1)\)式是最后的推导结果,等式右边由两个项组成。第二个项\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用来衡量两个分布之间的“距离”,性质是值不小于0

这样一来,通过\((1)\)可以得到不等式\((2)\):

\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation*}
\tag{2}
\]

\((1)\)式右边的第一项,同时也是\((2)\)式的右边项,被学者们叫做\(\text{ELBO(Evidence Lower Bound)}\)。

Objective Function

上述推导的\((2)\)式可以被视作“定理”一般的存在,即对于某个分布的对数形式,总可以找到它的下界。

那\((2)\)式可以用来做什么?在Background中提到,图片生成任务中的\(p(x)\)想要对它做maximum likelihood根本无法做起。目标依然是最大化\(p(x)\),但有了\((2)\)式,求解的目标可以转移到最大化它的下界\(\text{ELBO}\)。

这也是论文中提到的:

This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.

接下来,回到论文中,看看是如何一步步推导出DDPM的优化目标。\((3)\)式直接摘录于论文:

\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\end{equation*}
\tag{2}
\]
\[\begin{equation*}
\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right]=: L
\end{equation*}
\tag{3}
\]

下面一项项地对\((3)\) 进行拆解,并且将它与\((2)\)比对,能帮助更好地理解:

  1. \((3)\)不等号左边的\(\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right]\)进一步化简就是\(-\log p_\theta\left(\mathbf{x}_0\right)\)。其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要学习的最终目标:图像的分布,\(\theta\)是模型的参数,\(\mathbf{x}_0\)是图片;

  2. \((2)\)式的左右两边同时加上符号,\(\geq\)变为\(\leq\);

  3. 看\((3)\)不等式右边部分,\(\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]\)

    1. 很明显,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)相当于\((2)\)中引入的额外分布\(q(z)\)。对于\(z\),在生成模型中会给它一个称呼:隐变量\((\text{latent})\)。实际上,在diffusion models里,对\(\mathbf{x}_0\)加噪后的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隐变量,那不妨记作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\);

    2. \(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是关于\(\mathbf{x}_0, z\)的联合概率分布,因为选用马尔代夫链建模,那么依据马尔可夫链的性质,论文定义:

\[\begin{equation*}
\begin{aligned}
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\
p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation*}
\tag{4}
\]
  1. 将\((4)\)带入\((3)\)不等式右边的第一项,得到\(L\):
\[\begin{equation*}
\begin{aligned}
&\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\
=&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\
=&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L
\end{aligned}
\end{equation*}
\]

到目前为止,经过了很多轮的变换以及数学公式,先捋一遍,再往下。\(L\)是一个替代的优化目标,

\[\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}_0))} \iff \mathop{\arg\max}{(\ln{p}_{\theta}(\mathbf{x}_0))}
\]

接下来,论文中对\(L\)进行了重写,以下步骤直接摘录自论文\(\text{Appendix A}\)

\[\begin{equation*}
\begin{aligned}
L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\
&=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation*}
\tag{5}
\]

倒数两步的变换发生在第二项,具体依据为:

\[\begin{aligned}
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}
\end{aligned}
\quad \Rightarrow \quad
\begin{aligned}
&\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
\]

接着对\((5)\)进行改写得到最终形式\((6)\):

\[\begin{aligned}
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
&=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]
\end{aligned}
\tag{6}
\]

Summary

太好了,对于\((6)\)来说,它最起码是个可以优化的目标函数了,因为论文中定义马尔可夫链相邻状态的转变是服从高斯分布的。当然在论文中,\((6)\)还会进一步被改写,得到更加精简的\(\text{loss function}\)形式。

DDPM是应用\(\text{variational inference}\)进行优化求解的典型例子,很值得借鉴学习。

Reference

Part2: DDPM as Example of Variational Inference的更多相关文章

  1. [Bayesian] “我是bayesian我怕谁”系列 - Variational Inference

    涉及的领域可能有些生僻,骗不了大家点赞.但毕竟是人工智能的主流技术,在园子却成了非主流. 不可否认的是:乃值钱的技术,提高身价的技术,改变世界观的技术. 关于变分,通常的课本思路是: GMM --&g ...

  2. [Bayes] Variational Inference for Bayesian GMMs

    为了世界和平,为了心知肚明,决定手算一次 Variational Inference for Bayesian GMMs 目的就是达到如下的智能效果,扔进去六个高斯,最后拟合结果成了两个高斯,当然,其 ...

  3. 变分推断(Variational Inference)

    (学习这部分内容大约需要花费1.1小时) 摘要 在我们感兴趣的大多数概率模型中, 计算后验边际或准确计算归一化常数都是很困难的. 变分推断(variational inference)是一个近似计算这 ...

  4. Improved Variational Inference with Inverse Autoregressive Flow

    目录 概 主要内容 代码 Kingma D., Salimans T., Jozefowicz R., Chen X., Sutskever I. and Welling M. Improved Va ...

  5. Variational Inference with Normalizing Flow

    目录 概 主要内容 一些合适的可逆变换 代码 Rezende D., Mohamed S. Variational Inference with Normalizing Flow. ICML, 201 ...

  6. Variational Inference

    作者:孙九爷链接:https://www.zhihu.com/question/41765860/answer/101915528来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注 ...

  7. 变分推断(Variational Inference)

    变分 对于普通的函数f(x),我们可以认为f是一个关于x的一个实数算子,其作用是将实数x映射到实数f(x).那么类比这种模式,假设存在函数算子F,它是关于f(x)的函数算子,可以将f(x)映射成实数F ...

  8. 一文详解扩散模型:DDPM

    作者:京东零售 刘岩 扩散模型讲解 前沿 人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Gene ...

  9. PRML读书会第十章 Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )

    主讲人 戴玮 (新浪微博: @戴玮_CASIA) Wilbur_中博(1954123) 20:02:04 我们在前面看到,概率推断的核心任务就是计算某分布下的某个函数的期望.或者计算边缘概率分布.条件 ...

  10. Variational Bayes

    一.前言 变分贝叶斯方法最早由Matthew J.Beal在他的博士论文<Variational Algorithms for Approximate Bayesian Inference> ...

随机推荐

  1. Agora Flat:在线教室的开源初体验

    开发者其实很多时候都非常向往开源,开源领域的大佬也特别多,我们谈不上有多资深,也是一边探索一边做.同时,也希望可以借这次机会把我们摸索到的一些经验分享给大家. 01 Flat 是什么 Flat 是一个 ...

  2. 关于MySQL建立库表时大写自动转换为小写的解决方案

    mysql 5.6以上windows对大小写敏感要在my.ini中的[mysqld]下面设置lower_case_table_names = 2 网上有的要改成0 亲测报错 [○・`Д´・ ○]

  3. CAS乐观锁(原子操作)

    更多内容,前往 IT-BLOG 锁主要分为两种:乐观锁和悲观锁,而 synchronized 就属于一种悲观锁,每次在操作数据前都会加锁.乐观锁是指:乐观的认为自己在操作数据时,别人不会对当前数据进行 ...

  4. 记一次 .NET 某企业 ERP网站系统 崩溃分析

    一:背景 1. 讲故事 前段时间收到了一个朋友的求助,说他的ERP网站系统会出现偶发性崩溃,找了好久也没找到是什么原因,让我帮忙看下,其实崩溃好说,用 procdump 自动抓一个就好,拿到 dump ...

  5. vue 之 computed方法自带缓存踩坑1

    使用场景:ant-vue 穿梭框使用 页面使用computed方法处理组织结构数据,退出页面时,对加载数据做了set null 操作,再次进入页面时,穿梭框只显示数据,无法做左右穿梭功能. 原因:co ...

  6. [JavaScript]Base64 ←→ 图像

    1 Base64 → 图像 [demo1] document.getElementById('img').setAttribute( 'src', ' ...

  7. 互联网常用API收集

    百度车联网API:http://lbsyun.baidu.com/index.php?title=car

  8. Android Studio 样式和主题背景

    样式和主题背景 转载自   Styles and Themes  |  Android Developers 借助 Android 中的样式和主题背景,您可以将应用设计的细节与界面的结构和行为分开,其 ...

  9. Yolov8离谱报错

    YoloV8离谱报错 ​ 今天下午给一个研究生小姐姐跑数据集,用的是yolov8在恒源云上租的4070的GPU服务器,跑垃圾分类数据集(https://blog.csdn.net/m0_5488250 ...

  10. Nacos Prometheus Grafana

    目录 运维篇:springboot与微服务组件nacos Linux服务器部署springboot项目 Springboot启动服务指定参数 Linux & Win 监控运行中的服务 Prom ...