很多次翻看DDPM,始终不太能理解论文中提到的\(\text{Variational Inference}\)到底是如何在这个工作中起到作用。五一假期在家,无意间又刷到徐亦达老师早些年录制的理论视频,没想到其中也有介绍这部分的内容。老师的上课方式总是娓娓道来,把每一步都讲解得很仔细。本文记录一下个人对开头问题的思考。

Background

如果需要简略地介绍一下DDPM这个工作,可能会用以下几句话简单地描述:DDPMMarkov的形式对数据(图片)“扩散过程”建模,使用神经网络进行训练拟合,学习数据的概率分布。

所以对于生成任务来说,希望从给定数据中学习到的是数据的潜在信息。比如图片生成,在给定一些图片后,模型学习到的是“正常图片长什么样子”,如:

  1. 一张包含手机正面的图片会有【手机屏幕】;
  2. 一张包含猫咪的图片会有人们观察到的猫咪模样;
  3. ...

对于图片中每个像素点和附近的像素点,进行“合理”布局,才能生成“符合人们认知的图片”。

图片生成能像常见的机器学习任务如分类任务、回归任务,能基于maximize likelihood的形式来训练么?

结论是很难,先回顾如何做maximum likelihood。给定一批数据,首先需要假定数据服从的分布,接着写出似然函数,之后直接通过解析解的形式或是梯度下降的形式,求出分布。

问题就出在假定分布这一步,没有人知道图片客观上服从什么分布。那如果使用神经网络直接拟合可以么?这好像也不现实,拿一张512*512*3的图片来说,网络输出层共有约75w的数值。

对于图片生成还有另外一个问题,世界上的图片太多了,目之所及稍做处理,皆为图片。即便使用神经网络能拟合,最后生成的图片很难存在多样性。

那目前图片生成模型都是怎么做的,比如VAE或是本文即将要介绍的Diffusion Model,它们学习的都是数据分布\(p(x)\),但直接求\(p(x)\)这么麻烦,需要怎么做?这其实也是\(\text{Variational Inference}\)的核心思想,“曲线救国”,通过引入其它分布,将原本难以优化的问题转变为可优化问题。

ELOB

先把上述提到的所有背景先抛开,研究一下\(p(x)\),看看能得到什么有意思的结论。

a. 基于条件概率分布,引入新的随机变量\(z\):\(p(x) = \frac{p(x, z)}{p(z\mid x)}\);

b. 对于两边同时取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\);

c. 右边分子分母同乘以\(q(z)\):\(\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}\)

d. 再次,对于上式左右两边求关于\(q(z)\)的期望,等式依然成立:

\[\begin{aligned}
&\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\tag{1}
\]

一系列变换后,\((1)\)式是最后的推导结果,等式右边由两个项组成。第二个项\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用来衡量两个分布之间的“距离”,性质是值不小于0

这样一来,通过\((1)\)可以得到不等式\((2)\):

\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation*}
\tag{2}
\]

\((1)\)式右边的第一项,同时也是\((2)\)式的右边项,被学者们叫做\(\text{ELBO(Evidence Lower Bound)}\)。

Objective Function

上述推导的\((2)\)式可以被视作“定理”一般的存在,即对于某个分布的对数形式,总可以找到它的下界。

那\((2)\)式可以用来做什么?在Background中提到,图片生成任务中的\(p(x)\)想要对它做maximum likelihood根本无法做起。目标依然是最大化\(p(x)\),但有了\((2)\)式,求解的目标可以转移到最大化它的下界\(\text{ELBO}\)。

这也是论文中提到的:

This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.

接下来,回到论文中,看看是如何一步步推导出DDPM的优化目标。\((3)\)式直接摘录于论文:

\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\end{equation*}
\tag{2}
\]
\[\begin{equation*}
\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right]=: L
\end{equation*}
\tag{3}
\]

下面一项项地对\((3)\) 进行拆解,并且将它与\((2)\)比对,能帮助更好地理解:

  1. \((3)\)不等号左边的\(\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right]\)进一步化简就是\(-\log p_\theta\left(\mathbf{x}_0\right)\)。其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要学习的最终目标:图像的分布,\(\theta\)是模型的参数,\(\mathbf{x}_0\)是图片;

  2. \((2)\)式的左右两边同时加上符号,\(\geq\)变为\(\leq\);

  3. 看\((3)\)不等式右边部分,\(\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]\)

    1. 很明显,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)相当于\((2)\)中引入的额外分布\(q(z)\)。对于\(z\),在生成模型中会给它一个称呼:隐变量\((\text{latent})\)。实际上,在diffusion models里,对\(\mathbf{x}_0\)加噪后的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隐变量,那不妨记作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\);

    2. \(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是关于\(\mathbf{x}_0, z\)的联合概率分布,因为选用马尔代夫链建模,那么依据马尔可夫链的性质,论文定义:

\[\begin{equation*}
\begin{aligned}
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\
p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation*}
\tag{4}
\]
  1. 将\((4)\)带入\((3)\)不等式右边的第一项,得到\(L\):
\[\begin{equation*}
\begin{aligned}
&\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\
=&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\
=&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L
\end{aligned}
\end{equation*}
\]

到目前为止,经过了很多轮的变换以及数学公式,先捋一遍,再往下。\(L\)是一个替代的优化目标,

\[\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}_0))} \iff \mathop{\arg\max}{(\ln{p}_{\theta}(\mathbf{x}_0))}
\]

接下来,论文中对\(L\)进行了重写,以下步骤直接摘录自论文\(\text{Appendix A}\)

\[\begin{equation*}
\begin{aligned}
L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\
&=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation*}
\tag{5}
\]

倒数两步的变换发生在第二项,具体依据为:

\[\begin{aligned}
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}
\end{aligned}
\quad \Rightarrow \quad
\begin{aligned}
&\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
\]

接着对\((5)\)进行改写得到最终形式\((6)\):

\[\begin{aligned}
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
&=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]
\end{aligned}
\tag{6}
\]

Summary

太好了,对于\((6)\)来说,它最起码是个可以优化的目标函数了,因为论文中定义马尔可夫链相邻状态的转变是服从高斯分布的。当然在论文中,\((6)\)还会进一步被改写,得到更加精简的\(\text{loss function}\)形式。

DDPM是应用\(\text{variational inference}\)进行优化求解的典型例子,很值得借鉴学习。

Reference

Part2: DDPM as Example of Variational Inference的更多相关文章

  1. [Bayesian] “我是bayesian我怕谁”系列 - Variational Inference

    涉及的领域可能有些生僻,骗不了大家点赞.但毕竟是人工智能的主流技术,在园子却成了非主流. 不可否认的是:乃值钱的技术,提高身价的技术,改变世界观的技术. 关于变分,通常的课本思路是: GMM --&g ...

  2. [Bayes] Variational Inference for Bayesian GMMs

    为了世界和平,为了心知肚明,决定手算一次 Variational Inference for Bayesian GMMs 目的就是达到如下的智能效果,扔进去六个高斯,最后拟合结果成了两个高斯,当然,其 ...

  3. 变分推断(Variational Inference)

    (学习这部分内容大约需要花费1.1小时) 摘要 在我们感兴趣的大多数概率模型中, 计算后验边际或准确计算归一化常数都是很困难的. 变分推断(variational inference)是一个近似计算这 ...

  4. Improved Variational Inference with Inverse Autoregressive Flow

    目录 概 主要内容 代码 Kingma D., Salimans T., Jozefowicz R., Chen X., Sutskever I. and Welling M. Improved Va ...

  5. Variational Inference with Normalizing Flow

    目录 概 主要内容 一些合适的可逆变换 代码 Rezende D., Mohamed S. Variational Inference with Normalizing Flow. ICML, 201 ...

  6. Variational Inference

    作者:孙九爷链接:https://www.zhihu.com/question/41765860/answer/101915528来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注 ...

  7. 变分推断(Variational Inference)

    变分 对于普通的函数f(x),我们可以认为f是一个关于x的一个实数算子,其作用是将实数x映射到实数f(x).那么类比这种模式,假设存在函数算子F,它是关于f(x)的函数算子,可以将f(x)映射成实数F ...

  8. 一文详解扩散模型:DDPM

    作者:京东零售 刘岩 扩散模型讲解 前沿 人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Gene ...

  9. PRML读书会第十章 Approximate Inference(近似推断,变分推断,KL散度,平均场, Mean Field )

    主讲人 戴玮 (新浪微博: @戴玮_CASIA) Wilbur_中博(1954123) 20:02:04 我们在前面看到,概率推断的核心任务就是计算某分布下的某个函数的期望.或者计算边缘概率分布.条件 ...

  10. Variational Bayes

    一.前言 变分贝叶斯方法最早由Matthew J.Beal在他的博士论文<Variational Algorithms for Approximate Bayesian Inference> ...

随机推荐

  1. 关于EasyExcel的数据导入和单sheet和多sheet导出

    读写Excel基本代码 直接复制不一定能用 实体类 @ExcelIgnore 在导出操作中不会被导出 @ExcelProperty 在导入过程中 可以根据导入模板自动匹配字段, 在导出过程中可用于设置 ...

  2. 这样封装echarts简单好用

    为什么要去封装echarts? 在我们的项目中,有很多的地方都使用了echarts图表展示数据. 在有些场景,一个页面有十多个的echarts图. 这些echarts只是展示的指标不一样. 如果我们每 ...

  3. python3各数据类型的常用方法

    python3数据类型包括: 数字.字符串str.列表list.元组tuple.字典dict.集合set.布尔bool 1.字符串(str)-可变-用"".''定义 (1)uppe ...

  4. SwitchHosts operation not permitted 解决方案--亲测有效

    SwitchHost!是帮助我们管理Hosts的工具,可以帮助我们做域名解析, 弥补了如果要修改域名还要改计算机C:\Windows\System32\drivers\etc位置下的hosts文件的弊 ...

  5. 第一推动|2023年VSCode插件最新推荐(54款)

    本文介绍前端开发领域常用的一些VSCode插件,插件是VSCode最重要的组成部分之一,本文列出了我自己在以往工作经验中积累的54款插件,个人觉得这些插件是有用或有趣的,根据它们的作用,我粗略的把它们 ...

  6. 华为 A800-9000 服务器 离线安装MindX DL

    MindX DL(昇腾深度学习组件)是支持 Atlas 800 训练服务器.Atlas 800 推理服务器的深度学习组件参考设计,提供昇腾 AI 处理器资源管理和监控.昇腾 AI 处理器优化调度.分布 ...

  7. pandas之聚合函数

    在<Python Pandas窗口函数>一节,我们重点介绍了窗口函数.我们知道,窗口函数可以与聚合函数一起使用,聚合函数指的是对一组数据求总和.最大值.最小值以及平均值的操作,本节重点讲解 ...

  8. SLBR通过自校准的定位和背景细化来去除可见的水印

    一.简要介绍   本文简要介绍了论文"Visible Watermark Removal via Self-calibrated Localization and Background Re ...

  9. 【Vue项目 + 自写java后端】尚品汇(七)后台项目 ElementUI 表单验证 + 三级联动

    ElementUI 表单验证 1 标准验证规则 Form 组件提供了表单验证的功能,只需要通过 rules 属性传入约定的验证规则,并将 Form-Item 的 prop 属性设置为需校验的字段名即可 ...

  10. Java关键字以及标识符

    Java中有许多关键字,关键字是什么意思呢? 我用自己的分析来表达一下吧. Java就是源自于生活的,我们都有自己的名字.所以它也会有许多的名字,每个名字都有各自不同的特性(作用),都是系统定义好的. ...