论文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]

本文记录了我在学习 VAE 过程中的一些公式推导和思考。如果你希望从头开始学习 VAE,建议先看一下苏剑林的博客(本文末尾有链接)。

VAE 的整体框架

VAE 认为,随机变量 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由两个随机过程得到:

  1. 根据先验分布 \(p(\boldsymbol{z})\) 生成隐变量 \(\boldsymbol{z}\)。
  2. 根据条件分布 \(p(\boldsymbol{x} | \boldsymbol{z})\) 由 \(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\)。

于是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我们所需要的生成模型。

一种朴素的想法是:先用随机数生成器生成隐变量 \(\boldsymbol{z}\),然后用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从 \(\boldsymbol{z}\) 中生成出(或者说重构出) \(\boldsymbol{x}\),通过最小化重构损失来训练模型。这个想法的问题在于:我们无法找到生成的样本与原始样本之间的对应关系,重构损失算不了,无法训练。

VAE 的做法是引入后验分布 \(p(\boldsymbol{z} | \boldsymbol{x})\),训练过程变为:

  1. 采样一批原始样本 \(\boldsymbol{x}\)。
  2. 用 \(p(\boldsymbol{z} | \boldsymbol{x})\) 获得每个样本 \(\boldsymbol{x}\) 对应的隐变量 \(\boldsymbol{z}\)。
  3. 用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从隐变量 \(\boldsymbol{z}\) 中重构出 \(\boldsymbol{x}\),通过最小化重构损失来训练模型。

从这个角度来看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相当于编码器,\(p(\boldsymbol{x} | \boldsymbol{z})\) 相当于解码器,训练结束后只需要保留解码器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可。

除了重构损失以外,VAE 还有一项 KL 散度损失,希望近似的后验分布 \(q(\boldsymbol{z} | \boldsymbol{x})\) 尽量接近先验分布 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度。

变分下界的推导

现有 \(N\) 个由分布 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的样本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我们可以使用极大似然估计从这些样本中估计出分布的参数 \(\boldsymbol{\theta}\),即

\[\begin{aligned}
\boldsymbol{\theta}
& = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}).
\end{aligned}
\]

后验分布 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因为分母处的边缘分布 \(p(\boldsymbol{x})\) 积不出来。具体来说,联合分布 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的表达式非常复杂,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 这个积分找不到解析解。

需要使用变分推断解决后验分布无法计算的问题。我们使用一个形式已知的分布 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\) 来近似后验分布 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),于是有

\[\begin{aligned}
\log p(\boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\
& \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}).
\end{aligned}
\]

利用 KL 散度大于等于 0 这一特性,我们得到了对数似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一个下界 \(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),于是可以将最大化对数似然改为最大化这个下界。

这个下界可以进一步写成

\[\begin{aligned}
L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\
\end{aligned}
\]

其中的第一项是 KL 散度损失,第二项是重构损失。

KL 散度损失

使用标准正态分布作为先验分布,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\)。

使用一个由 MLP 的输出来参数化的正态分布作为近似后验分布,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\)。

选择正态分布的好处在于 KL 散度的这个积分可以写出解析解,训练时直接按照公式计算即可,无需通过采样的方式来算积分。

由于我们选择的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可:

\[\begin{aligned}
\mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)]
& = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\
& = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\
& = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\
& = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\
& = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2).
\end{aligned}
\]

解释一下倒数第三行的三个积分:

  1. \(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函数的积分,也就是 1。
  2. \(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定义,也就是 \(\sigma^2\)。
  3. \(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正态分布的二阶矩,结果为 \(\mu^2 + \sigma^2\)。

重构损失

伯努利分布模型

当 \(\boldsymbol{x}\) 是二值向量时,可以用伯努利分布(两点分布)来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即认为向量 \(\boldsymbol{x}\) 的每个维度都服从对应的相互独立的伯努利分布。使用一个 MLP 来计算各维度所对应的伯努利分布的参数,第 \(i\) 维伯努利分布的参数为 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),于是有

\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i},
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i).
\]

其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的维度。可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化交叉熵损失。

正态分布模型

当 \(\boldsymbol{x}\) 是实值向量时,可以用正态分布来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\)。使用一个 MLP 来计算正态分布的参数,于是有

\[\begin{aligned}
p(\boldsymbol{x}|\boldsymbol{z})
& = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\
& = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\
& = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right),
\end{aligned}
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}.
\]

很多时候我们会假设 \(\sigma_i^2\) 是一个常数,于是 MLP 只需要输出均值参数 \(\boldsymbol{\mu}\) 即可。此时有

\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2.
\]

可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化 MSE 损失。

重参数化技巧

需要使用重参数化技巧解决采样 \(z\) 时不可导的问题。解决的思路是先从无参数分布中采样一个 \(\varepsilon\),再通过变换得到 \(z\)。

从 \(N(\mu, \sigma^2)\) 中采样一个 \(z\),相当于先从 \(N(0, 1)\) 中采样一个 \(\varepsilon\),然后令 \(z = \mu + \varepsilon\cdot\sigma\)。

相关知识

技巧,通过取对数把乘除变成加减:

\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b.
\]

随机变量的函数的期望:

\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x,
\]

利用此公式可以将积分改写成期望的形式,这样就可以用采样的方式计算积分了(蒙特卡罗积分法)。

条件概率密度的定义:

\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)},
\]

此处的 \(p\) 并不是概率而是概率密度函数,但是这个公式在形式上跟条件概率公式是一样的。

参考资料

苏剑林的 VAE 系列博客:

15 分钟了解变分推理:

变分自编码器(VAE)公式推导的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 变分推断到变分自编码器(VAE)

    EM算法 EM算法是含隐变量图模型的常用参数估计方法,通过迭代的方法来最大化边际似然. 带隐变量的贝叶斯网络 给定N 个训练样本D={x(n)},其对数似然函数为: 通过最大化整个训练集的对数边际似然 ...

  3. 再谈变分自编码器VAE:从贝叶斯观点出发

    链接:https://kexue.fm/archives/5343

  4. (转) 变分自编码器(Variational Autoencoder, VAE)通俗教程

    变分自编码器(Variational Autoencoder, VAE)通俗教程 转载自: http://www.dengfanxin.cn/?p=334&sukey=72885186ae5c ...

  5. 变分自编码器(Variational Autoencoder, VAE)通俗教程

    原文地址:http://www.dengfanxin.cn/?p=334 1. 神秘变量与数据集 现在有一个数据集DX(dataset, 也可以叫datapoints),每个数据也称为数据点.我们假定 ...

  6. VAE变分自编码器

    我在学习VAE的时候遇到了很多问题,很多博客写的不太好理解,因此将很多内容重新进行了整合. 我自己的学习路线是先学EM算法再看的变分推断,最后学VAE,自我感觉这个线路比较好理解. 一.首先我们来宏观 ...

  7. 基于图嵌入的高斯混合变分自编码器的深度聚类(Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedding, DGG)

    基于图嵌入的高斯混合变分自编码器的深度聚类 Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedd ...

  8. VAE变分自编码器公式推导

    VAE变分推导依赖数学公式 (1)贝叶斯公式:\(p(z|x) = \frac{p(x|z)p(z)}{p(x)}\) (2)边缘概率公式:\(p(x) =\int{p(x,z)}dz\) (3)KL ...

  9. 变分自编码器(Variational auto-encoder,VAE)

    参考: https://www.cnblogs.com/huangshiyu13/p/6209016.html https://zhuanlan.zhihu.com/p/25401928 https: ...

  10. 基于变分自编码器(VAE)利用重建概率的异常检测

    本文为博主翻译自:Jinwon的Variational Autoencoder based Anomaly Detection using Reconstruction Probability,如侵立 ...

随机推荐

  1. [网络/Linux]网络嗅探工具——nmap

    1 nmap 简介 Nmap 即 Network Mapper,最早是Linux下的网络扫描和嗅探工具包. nmap是网络扫描和主机检测的工具,用nmap进行信息收集和检测漏洞,功能有: 检测存活主机 ...

  2. IIS 部署.NET CORE 项目 出现 HTTP 错误 500.19 - Internal Server Error

    当出现这个错误时是因为服务器上没有.NET CORE对应的SDK以及运行时文件,我的.NET CORE版本是2.2,下载的就是2.2对应的文件. 附上.NET CORE2.2版本的下载链接 下载 .N ...

  3. 人群定向SQL表

    SET FOREIGN_KEY_CHECKS=0; -- ---------------------------- -- Table structure for rc_throng -- ------ ...

  4. 全新跨平台版本.NET敏捷开发框架-RDIFramework.NET5.0震撼发布

    RDIFramework.NET,基于全新.NET Framework与.NET Core的快速信息化系统敏捷开发.整合框架,给用户和开发者最佳的.Net框架部署方案.为企业快速构建跨平台.企业级的应 ...

  5. 2023成都.NET线下技术沙龙圆满结束

    2023年4月15日周六,由MASA技术团队和成都.NET俱乐部共同主办的2023年成都.NET线下技术沙龙活动在成都市世纪城新会展中心知域空间举行,共计报名人数90多人,实际到场60多人,13:30 ...

  6. 轻量化3D文件格式转换HOOPS Exchange新特性

    BIM与AEC市场发展现状 近年来BIM(建筑信息模型)和AEC(建筑.工程和施工)市场一直保持着持续增长.2014 年全球 BIM 软件市场价值 27.6 亿美元,而到 2022年,预期到达115. ...

  7. for of 和 for in 的区别

    1 var arr = ["f", "6", 3, "a", 7]; 2 var obj = { name: "shun" ...

  8. 机器视觉基本理论(opencv)

    1. 什么是图像采样 采样是按照某种时间间隔或空间间隔,将空间上连续的图像变换成离散点的操作称为图像采样 2. 什么是图像分变率 采样 得到的离散图像的尺寸称为图像分辨率.分辨率是数字图像可辨别的最小 ...

  9. MySQL事务和锁实战篇

    文章目录 MySQL事务和锁 事务 事务的控制语句 事务隔离级别设置 脏读 不可重复读 幻读 锁机制 InnoDB的行级锁 锁实战 死锁 总结 MySQL事务和锁 事务 说到关系型的数据库的事务,相信 ...

  10. values_list() 元组形式显示查询结果

    values_list() 元组形式显示查询结果 name,age为数据库的两个列 Student.objects.values_list('name','age') values_list() 元组 ...