从DDPM到DDIM(四) 预测噪声与后处理

前情回顾

下图展示了DDPM的双向马尔可夫模型。

训练目标。最大化证据下界等价于最小化以下损失函数:

\[\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2 \overline{\alpha}_{t-1}}{\left(1-\overline{\alpha}_t\right)^2} \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2\right] \tag{1}
\]

推理过程。推理过程利用马尔可夫链蒙特卡罗方法。

\[\begin{aligned}
\mathbf{x}_{t-1} &\sim p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) , \sigma^2 \left(t\right) \mathbf{I}) \\
\mathbf{x}_{t-1} &= \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \bm{\epsilon} \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \bm{\epsilon}
\end{aligned} \tag{2}
\]

1、预测噪声

  上一篇文章我们提到,扩散模型的神经网络用于预测 \(\mathbf{x}_{0}\),然而DDPM并不是这样做的,而是用神经网络预测噪声。这也是DDPM 第一个字母 D(Denoising)的含义。为什么采用预测噪声的参数化方法?DDPM作者在原文中提到去噪分数匹配(denoising score matching, DSM),并说这样训练和DSM是等价的。可见应该是收了DSM的启发。另外一个解释我们一会来讲。

  按照上一篇文章的化简技巧,对于神经网络的预测输出 \(\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\),也可以进行进一步参数化(parameterization):

已知:

\[\begin{aligned}
\mathbf{x}_{t} = \sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}
\end{aligned} \tag{3}
\]

于是:

\[\begin{aligned}
\mathbf{x}_{0} = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \bm{\epsilon}
\end{aligned} \tag{4}
\]
\[\begin{aligned}
\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)
\end{aligned} \tag{5}
\]

这里我们解释以下为什么采用预测噪声的方式的第二个原因。从(4)(5)两式可见,噪声项可以看作是 \(\mathbf{x}_{0}\) 与 \(\mathbf{x}_{t}\) 的残差项。回顾经典的Resnet结构:

\[\left[\mathbf{y}=\mathbf{x}+\mathcal{F}\left(\mathbf{x}, W_i\right)\right]
\]

Resnet也是用神经网络学习的残差项。DDPM采用预测噪声的方法和Resnet残差学习由异曲同工之妙。

  下面我们将(3)(4)两式代入(1)式,继续化简,有:

\[\begin{aligned}
\Vert\tilde{\mathbf{x}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\mathbf{x}_0\Vert_2^2 &= \frac{1 - \overline{\alpha}_t}{\overline{\alpha}_t} \Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\bm{\epsilon}\Vert_2^2
\end{aligned}
\]

注意 \(\overline{\alpha}_t\) = \(\overline{\alpha}_{t-1} \alpha_t\)于是可以得出新的优化方程:

\[\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\operatorname{argmin}} \sum_{t=1}^T \frac{1}{2 \sigma^2(t)} \frac{\left(1-\alpha_t\right)^2}{\left(1-\overline{\alpha}_t\right) \alpha}_t \mathbb{E}_{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left[\Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)-\bm{\epsilon}\Vert_2^2\right] \tag{6}
\]

(6) 式表示,我们的神经网络 \(\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)\) 被用于预测最初始的噪声 \(\bm{\epsilon}\)。忽略掉前面的系数,对应的训练算法如下:


Algorithm 3 . Training a Deniosing Diffusion Probabilistic Model. (Version: Predict noise)

Repeat the following steps until convergence.

  • For every image \(\mathbf{x}_0\) in your training dataset \(\mathbf{x}_0 \sim q\left(\mathbf{x}_0\right)\)
  • Pick a random time step \(t \sim \text{Uniform}[1, T]\).
  • Generate normalized Gaussian random noise \(\bm{\epsilon} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\)
  • Take gradient descent step on
\[\nabla_{\boldsymbol{\theta}} \Vert\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\sqrt{\overline{\alpha}_t} \mathbf{x}_{0} + \sqrt{1 - \overline{\alpha}_t} \bm{\epsilon}, t\right)-\bm{\epsilon}\Vert_2^2
\]

You can do this in batches, just like how you train any other neural networks. Note that, here, you are training one denoising network \(\tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\) for all noisy conditions.


推理的过程依然从马尔可夫链蒙特卡洛(MCMC)开始,因为这里是预测噪声,而推理的过程中也需要加噪声,为了区分,我们将推理过程中添加的噪声用 \(\mathbf{z} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\) 来表示。推理过程中每次推理的噪声 \(\mathbf{z}\) 都是不同的,但训练过程中要拟合的最初的目标噪声 \(\bm{\epsilon}\) 是相同的

\[\begin{aligned}
\mathbf{x}_{t-1} &\sim p_{\theta}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t}\right) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) , \sigma^2 \left(t\right) \mathbf{I}) \\
\mathbf{x}_{t-1} &= \tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \mathbf{z} \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) + \sigma \left(t\right) \mathbf{z}
\end{aligned} \tag{7}
\]

将(5)式代入:

\[\begin{aligned}
\tilde{\bm{\mu}}_{\theta}\left(\mathbf{x}_{t}, t\right) &= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \tilde{\mathbf{x}}_{\theta} \left(\mathbf{x}_{t}, t\right) \\
&= \frac{\left( 1 - \overline{\alpha}_{t-1} \right) \sqrt{\alpha_t}}{\left( 1 - \overline{\alpha}_{t} \right)} \mathbf{x}_{t} + \frac{\left(1 - \alpha_t\right) \sqrt{\overline{\alpha}_{t-1}}}{\left( 1 - \overline{\alpha}_{t} \right)} \left( \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{\sqrt{1 - \overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) \right) \\
&= \text{some algebra calculation} \\
&= \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)
\end{aligned}
\]

所以推理的表达式为:

\[\begin{aligned}
\mathbf{x}_{t-1} &= \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) + \sigma \left(t\right) \mathbf{z}
\end{aligned} \tag{7}
\]

下面可以写出采用拟合噪声策略的推理算法:


Algorithm 4 . Inference on a Deniosing Diffusion Probabilistic Model. (Version: Predict noise)

You give us a white noise vector \(\mathbf{x}_T \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\)

Repeat the following for \(t = T, T − 1, ... , 1\).

  • Generate \(\mathbf{z} \sim \mathcal{N} \left(\mathbf{0}, \mathbf{I}\right)\) if \(t > 1\) else \(\mathbf{z} = \mathbf{0}\)
\[\mathbf{x}_{t-1} = \frac{1}{\sqrt{\overline{\alpha}_t}} \mathbf{x}_{t} + \frac{1 - \alpha_t}{ \sqrt{ \left( 1 - \overline{\alpha}_{t} \right)\alpha}_t} \tilde{\bm{\epsilon}}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) + \sigma \left(t\right) \mathbf{z}
\]

Return \(\mathbf{x}_{0}\)


2、后处理

首先要注意到,在推理算法的最后一步,生成图像的时候,并没有添加噪声,而是直接采用预测的均值作为 \(\mathcal{x}_0\) 的估计值。

另外,生成的图像原本是归一化到 \([-1, 1]\) 之间的,所以要反归一化到 \([0, 255]\)。这里比较简单,直接看 diffusers 库中的代码:


image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image) if not return_dict:
return (image,) def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images] return pil_images

3、总结

  我们最初的目标是估计图像的概率分布,采用极大似然估计法,求 \(\log p\left(\mathbf{x}_0\right)\)。但是直接求解,很难求:

\[\begin{aligned}
p\left(\mathbf{x}_0\right) = \int p\left(\mathbf{x}_{0:T}\right) d \mathbf{x}_{1:T} \\
\end{aligned} \\
\]

  而且 \(p\left(\mathbf{x}_{0:T}\right)\) 也不知道。于是我们选择估计它的证据下界。在计算证据下界的过程中,我们解析了双向马尔可夫链中的很多分布和变量,最终推导出证据下界的表达式,以KL散度的方式来表示。这样做本质上是用已知的分布 \(q\left(\mathbf{x}_{1:T} | \mathbf{x}_{0}\right)\) 来对未知的分布做逼近。这其实是 变分推断 的思想。变分法是寻找一个函数使得这个函数最能满足条件,而变分推断是寻找一个分布使之更加逼近已知的分布。

  于是我们而在高斯分布的假设下,KL散度恰好等价于二范数的平方。最大似然估计等价于最小化二范数loss。之后就顺理成章地推导出了训练方法,并根据马尔可夫链蒙特卡洛推导出推理算法。关于变分推断和马尔可夫链蒙特卡洛相关的知识,读者可以自行查找,有时间我也会写篇文章来介绍。

  以上就是DDPM的全部内容了,我用了四篇文章对DDPM进行了详细推导,写文章的过程中也弄懂了自己之前不懂的一些细节。我的最大的感受是,初学者千万不要相信诸如《一文读懂DDPM》之类的文章,如果要真正搞懂DDPM,只有自己把所有公式手推一边才是正道。

下一篇我们开始介绍DDPM的一个经典的推理加速方法:DDIM

从DDPM到DDIM(四) 预测噪声与后处理的更多相关文章

  1. 一文详解扩散模型:DDPM

    作者:京东零售 刘岩 扩散模型讲解 前沿 人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Gene ...

  2. 目标跟踪之卡尔曼滤波---理解Kalman滤波的使用预测

    Kalman滤波简介 Kalman滤波是一种线性滤波与预测方法,原文为:A New Approach to Linear Filtering and Prediction Problems.文章推导很 ...

  3. 【Deep Learning】DDPM

    DDPM 1. 大致流程 1.1 宏观流程 1.2 训练过程 1.3 推理过程 2. 对比GAN 2.1 GAN流程 2.2 相比GAN优点 训练过程更稳定,损失函数指向性更强(loss数值大小指示训 ...

  4. 【滤波】标量Kalman滤波的过程分析和证明及C实现

    摘要: 标量Kalman滤波的过程分析和证明及C实现,希望能够帮助入门的小白,同时得到各位高手的指教.并不涉及其他Kalman滤波方法. 本文主要参考自<A Introduction to th ...

  5. 理解Kalman滤波的使用

    Kalman滤波简介 Kalman滤波是一种线性滤波与预测方法,原文为:A New Approach to Linear Filtering and Prediction Problems.文章推导很 ...

  6. paper 2:图像处理常用的Matlab函数汇总

    一 图像的读写 1 imread imread函数用于读入各种图像文件,如:a=imread('e:\w01.tif') 注:计算机E盘上要有w01相应的.tif文件. 2 imwrite imwri ...

  7. matlab中图像处理常见用法

    一. 读写图像文件 1. imread imread函数用于读入各种图像文件,如:a=imread('e:/w01.tif') 注:计算机E盘上要有w01相应的.tif文件. 2. imwrite i ...

  8. matlab 对图像操作的函数概览

    转自博客:http://blog.163.com/fei_lai_feng/blog/static/9289962200991713415422/ 一. 读写图像文件 1. imread imread ...

  9. 通俗理解kalman filter原理

    [哲学思想]即使我们对真相(真值)一无所知,我们任然可以通过研究事物规律,历史信息,当前观测而能尽可能靠近真相(真值). [线性预测模型]温度的变化是线性规律的,已知房间温度真值每小时上升1度左右(用 ...

  10. 【Unity Shader】2D动态云彩

    写在前面 赶在年前写一篇文章.之前翻看2015年的SIGGRAPH Course(关于渲染的可以去selfshadow的博客里找到,很全)的时候看到了关于体积云的渲染.这个课程讲述了开发者为游戏< ...

随机推荐

  1. linux系统重要文件和目录说明

    系统信息相关文件 /etc/issue 记录操作系统版本 head /etc/issue /proc/cpuinfo 记录cpu信息 cat /proc/cpuinfo /proc/meminfo 记 ...

  2. navicat 如何调整查询区域字体大小

    Navicat是一套快速.可靠和全面的数据库管理工具,专门用于简化数据库管理和降低管理成本.Navicat图形界面直观,提供简便的管理方法,设计和操作MySQL.MariaDB.SQL Server. ...

  3. foxy rviz2 "rviz_common/Time"报错问题

    报错内容 The class required for this panel, 'rviz_common/Time', could not be loaded. Error: According to ...

  4. 鸿蒙HarmonyOS实战-ArkTS语言基础类库(容器类库)

    前言 容器类库是指一组用于存储和管理数据的数据结构和算法.它们提供了各种不同类型的容器,如数组.链表.树.图等,以及相关的操作和功能,如查找.插入.删除.排序等. 容器类库还可以包含其他数据结构和算法 ...

  5. lxl学长讲课笔记

    lxl 学长讲课笔记 常数种可能性的状态 通过预先处理多种状态的信息,从而快速的转换状态. 经典操作:flip. 分析信息的思路 利用线段树 利用线段树的时候,如何合并两个分支区间的信息,我们需要有如 ...

  6. oppo、一加 android14 chrome116内核 input @click不触发

    // 兼容Chrome内核116及以上版本中配置disabled的input组件无法触发并冒泡click事件 .uni-input-input:disabled { pointer-events: n ...

  7. (四)Redis 缓存应用、淘汰机制

    1.缓存应用 一个系统中不同层面数据访问速度不一样,以计算机为例,CPU.内存和磁盘这三层的访问速度从几十 ns 到 100ns,再到几 ms,性能的差异很大,如果每次 CPU 处理数据时都要到磁盘读 ...

  8. 《Javscript实用教程》目录

    图书购买地址: 京东:<Javscript实用教程> 当当:<Javscript实用教程> 天猫:<Javscript实用教程> 注:本书提供源码和ppt课件,下载 ...

  9. CentOS8 安装ansible

    # 安装epel扩展源 yum install https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.noarch.rpm -y # ...

  10. poj1338 ugly number 题解 打表

    类似的题目有HDU1058 humble number(翻译下来都是丑陋的数字). Description Ugly numbers are numbers whose only prime fact ...