论文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]

本文记录了我在学习 VAE 过程中的一些公式推导和思考。如果你希望从头开始学习 VAE,建议先看一下苏剑林的博客(本文末尾有链接)。

VAE 的整体框架

VAE 认为,随机变量 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由两个随机过程得到:

  1. 根据先验分布 \(p(\boldsymbol{z})\) 生成隐变量 \(\boldsymbol{z}\)。
  2. 根据条件分布 \(p(\boldsymbol{x} | \boldsymbol{z})\) 由 \(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\)。

于是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我们所需要的生成模型。

一种朴素的想法是:先用随机数生成器生成隐变量 \(\boldsymbol{z}\),然后用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从 \(\boldsymbol{z}\) 中生成出(或者说重构出) \(\boldsymbol{x}\),通过最小化重构损失来训练模型。这个想法的问题在于:我们无法找到生成的样本与原始样本之间的对应关系,重构损失算不了,无法训练。

VAE 的做法是引入后验分布 \(p(\boldsymbol{z} | \boldsymbol{x})\),训练过程变为:

  1. 采样一批原始样本 \(\boldsymbol{x}\)。
  2. 用 \(p(\boldsymbol{z} | \boldsymbol{x})\) 获得每个样本 \(\boldsymbol{x}\) 对应的隐变量 \(\boldsymbol{z}\)。
  3. 用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 从隐变量 \(\boldsymbol{z}\) 中重构出 \(\boldsymbol{x}\),通过最小化重构损失来训练模型。

从这个角度来看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相当于编码器,\(p(\boldsymbol{x} | \boldsymbol{z})\) 相当于解码器,训练结束后只需要保留解码器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可。

除了重构损失以外,VAE 还有一项 KL 散度损失,希望近似的后验分布 \(q(\boldsymbol{z} | \boldsymbol{x})\) 尽量接近先验分布 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度。

变分下界的推导

现有 \(N\) 个由分布 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的样本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我们可以使用极大似然估计从这些样本中估计出分布的参数 \(\boldsymbol{\theta}\),即

\[\begin{aligned}
\boldsymbol{\theta}
& = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\
& = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}).
\end{aligned}
\]

后验分布 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因为分母处的边缘分布 \(p(\boldsymbol{x})\) 积不出来。具体来说,联合分布 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的表达式非常复杂,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 这个积分找不到解析解。

需要使用变分推断解决后验分布无法计算的问题。我们使用一个形式已知的分布 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\) 来近似后验分布 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),于是有

\[\begin{aligned}
\log p(\boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\
& = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\
& \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}).
\end{aligned}
\]

利用 KL 散度大于等于 0 这一特性,我们得到了对数似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一个下界 \(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),于是可以将最大化对数似然改为最大化这个下界。

这个下界可以进一步写成

\[\begin{aligned}
L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\
& = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\
& = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\
\end{aligned}
\]

其中的第一项是 KL 散度损失,第二项是重构损失。

KL 散度损失

使用标准正态分布作为先验分布,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\)。

使用一个由 MLP 的输出来参数化的正态分布作为近似后验分布,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\)。

选择正态分布的好处在于 KL 散度的这个积分可以写出解析解,训练时直接按照公式计算即可,无需通过采样的方式来算积分。

由于我们选择的是各分量独立的多元正态分布,因此只需要推导一元正态分布的情形即可:

\[\begin{aligned}
\mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)]
& = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\
& = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\
& = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\
& = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\
& = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\
& = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2).
\end{aligned}
\]

解释一下倒数第三行的三个积分:

  1. \(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函数的积分,也就是 1。
  2. \(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定义,也就是 \(\sigma^2\)。
  3. \(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正态分布的二阶矩,结果为 \(\mu^2 + \sigma^2\)。

重构损失

伯努利分布模型

当 \(\boldsymbol{x}\) 是二值向量时,可以用伯努利分布(两点分布)来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即认为向量 \(\boldsymbol{x}\) 的每个维度都服从对应的相互独立的伯努利分布。使用一个 MLP 来计算各维度所对应的伯努利分布的参数,第 \(i\) 维伯努利分布的参数为 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),于是有

\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i},
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i).
\]

其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的维度。可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化交叉熵损失。

正态分布模型

当 \(\boldsymbol{x}\) 是实值向量时,可以用正态分布来建模 \(p(\boldsymbol{x}|\boldsymbol{z})\)。使用一个 MLP 来计算正态分布的参数,于是有

\[\begin{aligned}
p(\boldsymbol{x}|\boldsymbol{z})
& = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\
& = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\
& = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right),
\end{aligned}
\]
\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}.
\]

很多时候我们会假设 \(\sigma_i^2\) 是一个常数,于是 MLP 只需要输出均值参数 \(\boldsymbol{\mu}\) 即可。此时有

\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2.
\]

可见此时最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等价于最小化 MSE 损失。

重参数化技巧

需要使用重参数化技巧解决采样 \(z\) 时不可导的问题。解决的思路是先从无参数分布中采样一个 \(\varepsilon\),再通过变换得到 \(z\)。

从 \(N(\mu, \sigma^2)\) 中采样一个 \(z\),相当于先从 \(N(0, 1)\) 中采样一个 \(\varepsilon\),然后令 \(z = \mu + \varepsilon\cdot\sigma\)。

相关知识

技巧,通过取对数把乘除变成加减:

\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b.
\]

随机变量的函数的期望:

\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x,
\]

利用此公式可以将积分改写成期望的形式,这样就可以用采样的方式计算积分了(蒙特卡罗积分法)。

条件概率密度的定义:

\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)},
\]

此处的 \(p\) 并不是概率而是概率密度函数,但是这个公式在形式上跟条件概率公式是一样的。

参考资料

苏剑林的 VAE 系列博客:

15 分钟了解变分推理:

变分自编码器(VAE)公式推导的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 变分推断到变分自编码器(VAE)

    EM算法 EM算法是含隐变量图模型的常用参数估计方法,通过迭代的方法来最大化边际似然. 带隐变量的贝叶斯网络 给定N 个训练样本D={x(n)},其对数似然函数为: 通过最大化整个训练集的对数边际似然 ...

  3. 再谈变分自编码器VAE:从贝叶斯观点出发

    链接:https://kexue.fm/archives/5343

  4. (转) 变分自编码器(Variational Autoencoder, VAE)通俗教程

    变分自编码器(Variational Autoencoder, VAE)通俗教程 转载自: http://www.dengfanxin.cn/?p=334&sukey=72885186ae5c ...

  5. 变分自编码器(Variational Autoencoder, VAE)通俗教程

    原文地址:http://www.dengfanxin.cn/?p=334 1. 神秘变量与数据集 现在有一个数据集DX(dataset, 也可以叫datapoints),每个数据也称为数据点.我们假定 ...

  6. VAE变分自编码器

    我在学习VAE的时候遇到了很多问题,很多博客写的不太好理解,因此将很多内容重新进行了整合. 我自己的学习路线是先学EM算法再看的变分推断,最后学VAE,自我感觉这个线路比较好理解. 一.首先我们来宏观 ...

  7. 基于图嵌入的高斯混合变分自编码器的深度聚类(Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedding, DGG)

    基于图嵌入的高斯混合变分自编码器的深度聚类 Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedd ...

  8. VAE变分自编码器公式推导

    VAE变分推导依赖数学公式 (1)贝叶斯公式:\(p(z|x) = \frac{p(x|z)p(z)}{p(x)}\) (2)边缘概率公式:\(p(x) =\int{p(x,z)}dz\) (3)KL ...

  9. 变分自编码器(Variational auto-encoder,VAE)

    参考: https://www.cnblogs.com/huangshiyu13/p/6209016.html https://zhuanlan.zhihu.com/p/25401928 https: ...

  10. 基于变分自编码器(VAE)利用重建概率的异常检测

    本文为博主翻译自:Jinwon的Variational Autoencoder based Anomaly Detection using Reconstruction Probability,如侵立 ...

随机推荐

  1. python类详解

    类封装 继承 多态一静态属性1.静态变量和静态方法都属于类的静态成员,它们与普通的成员变量和成员方法不同,静态变量和静态方法只属于定义它们的类,而不属于某一个对象.2.静态变量和静态方法都可以通过类名 ...

  2. 【Vue】Vuex

    Vuex简介 概念: 专门在Vue中实现集中式状态(数据)管理的一个Vue插件,对vue应用中多个组件的共享状态进行集中管理(读.写),也是一种适用于任意组件间的通信方式. 什么时候用Vuex ①多个 ...

  3. Kurator v0.3.0版本发布

    摘要:2023年4月8日,Kurator正式发布v0.3.0版本. 本文分享自华为云社区<华为云 Kurator v0.3.0 版本发布!集群舰队助力分布式云统一管理>,作者:云容器大未来 ...

  4. C# 一个List 分成多个List

    /// <summary>        /// 一个List拆分多个List        /// </summary>        /// <param name= ...

  5. css实现水平垂直居中的几种方法

    一,已知宽高 1 <style> 2 #box { 3 height: 400px; 4 width: 400px; 5 border: 1px solid grey; 6 positio ...

  6. Linux 内存管理 pt.1

    哈喽大家好,我是咸鱼 今天我们来学习一下 Linux 操作系统核心之一:内存 跟 CPU 一样,内存也是操作系统最核心的功能之一,内存主要用来存储系统和程序的指令.数据.缓存等 关于内存的学习,我会尽 ...

  7. Ffmpeg 视频压缩的几个关键参数

    Ffmpeg的视频操作官网文档:https://ffmpeg.org/ffmpeg-filters.html#Video-Filters 视频压缩用到的参数主要为以下几个: 文件路径:-i 输入文件的 ...

  8. Python 函数及参数的使用

    函数 带名字的代码块,用于完成具体的工作 关键字def定义一个函数,定义函数名,括号内是需要完成任务所需要的信息,最后定义冒号结尾 缩进构成函数体 函数调用,依次指定函数名以及冒号括起来的必要信息 d ...

  9. CF1477F Nezzar and Chocolate Bars 题解

    题意: 有一根长为 \(1\) 的巧克力,已经被切了 \(m-1\) 刀被分成 \(m\) 分,接下来每次在整根长度为 \(1\) 的巧克力上均匀随机一个点切一刀,求每一小段巧克力长度均小于一个给定值 ...

  10. Azure Devops上模版化K8s部署

    在2022年我们终于完成了主要业务系统上K8s的计划,在这里总结下我们上K8s时候的模版工程. 前提条件 本文不讨论K8s是什么,什么是容器化,为什么需要容器化,什么是微服务等这些基础内容,这些到处说 ...