DENOISING DIFFUSION IMPLICIT MODELS (DDIM)

从DDPM中我们知道,其扩散过程(前向过程、或加噪过程)被定义为一个马尔可夫过程,其去噪过程(也有叫逆向过程)也是一个马尔可夫过程。对马尔可夫假设的依赖,导致重建每一步都需要依赖上一步的状态,所以推理需要较多的步长。

\[q(x_t|x_{t-1}) := \mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1},{1-\alpha_t}I) \\
q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},{(1-\bar{\alpha}_t})I)
\]
\[\begin{align*}
q(x_{t-1}|x_t,x_0)
&\overset{Bayes}{=} \dfrac{q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0)}{q(x_t|x_0)} \\
&\overset{Markov}{=} \dfrac{q(x_t|x_{t-1})q(x_{t-1}|x_0)}{q(x_t|x_0)}
\end{align*}
\]

DDPM中对于其逆向分布的建模使用马尔可夫假设,这样做的目的是将式子中的未知项 \(q(x_t|x_{t-1},x_0)\),转化成了已知项 \(q(x_t|x_{t-1})\),最后求出 \(q(x_{t-1}|x_t,x_0)\) 的分布也是一个高斯分布 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\)。

从DDPM的结论出发,我们不妨直接假设 \(q(x_{t-1}|x_t,x_0)\) 的分布为高斯分布,在不使用马尔可夫假设的情况下,尝试求解 \(q(x_{t-1}|x_t,x_0)\) 。

由 DDPM 中 \(q(x_{t-1}|x_t,x_0)\) 的分布 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\) 可知,均值为 一个关于 \(x_t,x_0\) 的函数,方差为一个关于 \(t\) 的函数。

我们可以把 \(q(x_{t-1}|x_t,x_0)\) 设计成如下分布:

\[q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; a x_0 + b x_t,\sigma_t^2 I)
\]

这样,只要求解出 \(a,b,\sigma_t\) 这三个待定系数,即可确定 \(q(x_{t-1}|x_t,x_0)\) 的分布。

重参数化 \(q(x_{t-1}|x_t,x_0)\) :

\[x_{t-1}=a x_0 + b x_t + \sigma_t \varepsilon^{\prime}_{t-1}
\]

假设训练模型时输入噪声图片的加噪参数与DDPM完全一致

由 \(q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},(1-\bar{\alpha}_t)I)\) :

\[x_t=\sqrt{\bar{\alpha}_t}x_{0}+\sqrt{1-\bar{\alpha}_t}\varepsilon^{\prime}_{t}
\]

代入 \(x_t\) 有:

\[\begin{align*}
x_{t-1} &=a x_0 + b(\sqrt{\bar{\alpha}_t}x_{0}+\sqrt{1-\bar{\alpha}_t}\varepsilon^{\prime}_{t}) + \sigma_t \varepsilon^{\prime}_{t-1} \\
&= (a + b\sqrt{\bar{\alpha}_t}) x_0 + (b\sqrt{1-\bar{\alpha}_t}\varepsilon^{\prime}_{t} + \sigma_t \varepsilon^{\prime}_{t-1}) \\
&= (a + b\sqrt{\bar{\alpha}_t}) x_0 + (\sqrt{b^2(1-\bar{\alpha}_t)+ \sigma_t^2}) \bar{\varepsilon}_{t-1}
\end{align*}

\]

又:

\[x_{t-1}=\sqrt{\bar{\alpha}_{t-1}} x_0 + \sqrt{1-\bar{\alpha}_{t-1}} \varepsilon^{\prime}_{t-1}
\]

观察系数可以得到方程组:

\[\begin{cases}
a + b\sqrt{\bar{\alpha}_t} = \sqrt{\bar{\alpha}_{t-1}} \\
\sqrt{b^2(1-\bar{\alpha}_t)+ \sigma_t^2} = \sqrt{1-\bar{\alpha}_{t-1}}
\end{cases}
\]

三个未知数 两个方程,可以用 \(\sigma_t\) 表示 \(a,b\):

\[\begin{cases}
a = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_t} \sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}} \\
b = \sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}
\end{cases}
\]

\(a, b\) 代入 \(q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; a x_0 + b x_t,\sigma_t^2 I)\)

\[q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; \underbrace{ \left( \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_t} \sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}\right ) x_0 + (\sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}) x_t}_{\mu_q(x_t,x_0,t)},\sigma_t^2 I)
\]

\[x_t=\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0 \\
x_0 = \dfrac{1}{\sqrt{\bar{\alpha}_t}}x_t - \dfrac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \bar{\varepsilon}_0 \\
\]

代入 \(x_0\) 有:

\[\mu_q(x_t,x_0,t) = \sqrt{\bar{\alpha}_{t-1}} \dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0}{\sqrt{\bar{\alpha}_{t}}} + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \bar{\varepsilon}_0 \\
\]
\[\begin{align*}
x_{t-1} &= \mu_q(x_t,x_0,t) + \sigma_t \varepsilon_0 \\
&= \sqrt{\bar{\alpha}_{t-1}} \underbrace{\dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0}{\sqrt{\bar{\alpha}_{t}}}}_{预测的x_0} + \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \bar{\varepsilon}_0}_{x_t的方向} + \underbrace{\sigma_t \varepsilon_0}_{随机噪声扰动}
\end{align*}

\]

通过观察 \(x_{t-1}\) 的分布,我们建模采样分布为高斯分布:

\[p_\theta(x_{t-1}|x_t):=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t), \Sigma_\theta(x_t,t)I)
\]

并且均值和方差也采用相似的形式:

\[\begin{align*}
\mu_\theta(x_t,t) &= \sqrt{\bar{\alpha}_{t-1}} \dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t,t) }{\sqrt{\bar{\alpha}_{t}}} + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \epsilon_\theta(x_t,t) \\
\Sigma_\theta(x_t,t) &= \sigma_t^2
\end{align*}
\]

其中 \(\epsilon_\theta(x_t,t)\) 为预测的噪声。

此时,确定优化目标只需要 \(q(x_{t-1}|x_t,x_0)\) 和 \(p_\theta(x_{t-1}|x_t)\) 两个分布尽可能相似,使用KL散度来度量,则有:

\[\begin{align*}
&\quad \ \underset{\theta}{argmin} D_{KL}(q(x_{t-1}|x_t,x_0)||p_\theta(x_{t-1}|x_t)) \\
&=\underset{\theta}{argmin} D_{KL}(\mathcal{N}(x_{t-1};\mu_q, \Sigma_q(t))||\mathcal{N}(x_{t-1};\mu_\theta, \Sigma_q(t))) \\
&=\underset{\theta}{argmin} \dfrac{1}{2} \left[ log\dfrac{|\Sigma_q(t)|}{|\Sigma_q(t)|} - k + tr(\Sigma_q(t)^{-1}\Sigma_q(t)) + (\mu_q-\mu_\theta)^T \Sigma_q(t)^{-1} (\mu_q-\mu_\theta) \right] \\
&=\underset{\theta}{argmin} \dfrac{1}{2} \left[ 0 - k + k + (\mu_q-\mu_\theta)^T (\sigma_t^2I)^{-1} (\mu_q-\mu_\theta) \right] \\
&\overset{内积公式A^TA}{=} \underset{\theta}{argmin} \dfrac{1}{2\sigma_t^2} \left[ ||\mu_q-\mu_\theta||_2^2 \right] \\
&\overset{代入\mu_q,\mu_\theta}{=} \underset{\theta}{argmin} \dfrac{1}{2\sigma_t^2} (\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} - \dfrac{\sqrt{\bar{\alpha}_{t-1}} \sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}) \left[ ||\bar{\varepsilon}_0-\epsilon_\theta(x_t,t)||_2^2 \right]
\end{align*}
\]

恰好与DDPM的优化目标一致,所以我们可以直接复用DDPM训练好的模型。

\(p_{\theta}\) 的采样步骤则为:

\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_{t}}}}_{预测的x_0} + \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \epsilon_\theta(x_t,t)}_{x_t的方向} + \underbrace{\sigma_t \varepsilon}_{随机噪声扰动}
\]

令 \(\sigma_t=\eta \sqrt{\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\)

当 \(\eta =1\) 时,前向过程为 Markovian ,采样过程变为 DDPM 。

当 \(\eta =0\) 时,采样过程为确定过程,此时的模型 称为 隐概率模型(implicit probabilstic model)。

DDIM如何加速采样:

在 DDPM 中,基于马尔可夫链 \(t\) 与 \(t-1\) 是相邻关系,例如 \(t=100\) 则 \(t-1=99\);

在 DDIM 中,\(t\) 与 \(t-1\) 只表示前后关系,例如 \(t=100\) 时,\(t-1\) 可以是 90 也可以是 80、70,只需保证 \(t-1 < t\) 即可。

此时构建的采样子序列 \(\tau=[\tau_i,\tau_{i-1},\cdots,\tau_{1}] \ll [t,t-1,\cdots,1]\) 。

例如,原序列 \(\Tau=[100,99,98,\cdots,1]\),采样子序列为 \(\tau=[100,90,80,\cdots,1]\) 。

DDIM 采样公式为:

\[x_{\tau_{i-1}} = \sqrt{\bar{\alpha}_{\tau_{i-1}}} {\dfrac{x_{\tau_{i}}-\sqrt{1-\bar{\alpha}_{\tau_{i}}} \epsilon_\theta(x_{\tau_{i}},{\tau_{i}})}{\sqrt{\bar{\alpha}_{\tau_{i}}}}} + {\sqrt{1-\bar{\alpha}_{\tau_{i-1}}-\sigma_{\tau_{i}}^2} \epsilon_\theta(x_{\tau_{i}},{\tau_{i}})} + {\sigma_{\tau_{i}} \varepsilon}
\]

当 \(\eta= 0\) 时,DDIM 采样公式为:

\[ x_{\tau_{i-1}} = \dfrac{\sqrt{\bar{\alpha}_{\tau_{i-1}}}}{\sqrt{\bar{\alpha}_{\tau_{i}}}} x_{\tau_{i}} + \left( \sqrt{1-\bar{\alpha}_{\tau_{i-1}}} - \dfrac{\sqrt{\bar{\alpha}_{\tau_{i-1}}}}{\sqrt{\bar{\alpha}_{\tau_{i}}}} \sqrt{1-\bar{\alpha}_{\tau_{i}}} \right) \epsilon_\theta(x_{\tau_i},\tau_i)
\]

代码实现

训练过程与 DDPM 一致,代码参考上一篇文章。采样代码如下:

device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval() image_size=96
epochs = 500
batch_size = 128
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000]) # 每隔20采样一次
tau_index = list(reversed(range(0, T, 20))) #[980, 960, ..., 20, 0]
eta = 0.003 # train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod) def get_val_by_index(val, t, x_shape):
batch_t = t.shape[0]
out = val.gather(-1, t)
return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1]) def p_sample_ddim(model):
def step_denoise(model, x_tau_i, tau_i, tau_i_1):
sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape) denoise = model(x_tau_i, tau_i) if eta == 0:
sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
+ (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
* denoise
return x_tau_i_1 sigma = eta * torch.sqrt((1-get_val_by_index(alphas, tau_i, x_tau_i.shape)) * \
(1-get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)) / get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)) noise_z = torch.randn_like(x_tau_i, device=x_tau_i.device) # 整个式子由三部分组成
c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape) * denoise)
c2 = torch.sqrt(1 - get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape) - sigma) * denoise
c3 = sigma * noise_z
x_tau_i_1 = c1 + c2 + c3 return x_tau_i_1 img_pred = torch.randn((4, 3, image_size, image_size), device=device) for k in range(0, len(tau_index)):
# print(tau_index)
# 因为 tau_index 是倒序的,tau_i = k, tau_i_1 = k+1,这里不能弄反
tau_i_1 = torch.tensor([tau_index[k+1]], device=device, dtype=torch.long)
tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
img_pred = step_denoise(model, img_pred, tau_i, tau_i_1) torch.cuda.empty_cache()
if tau_index[k+1] == 0: return img_pred return img_pred with torch.no_grad():
img = p_sample_ddim(model)
img = torch.clamp(img, -1.0, 1.0) show_img_batch(img.detach().cpu())

DDIM

https://arxiv.org/pdf/2010.02502

https://github.com/ermongroup/ddim

Diffusion系列 - DDIM 公式推导 + 代码 -(三)的更多相关文章

  1. Android系列之网络(三)----使用HttpClient发送HTTP请求(分别通过GET和POST方法发送数据)

    ​[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/ ...

  2. Android系列之Fragment(三)----Fragment和Activity之间的通信(含接口回调)

    ​[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/ ...

  3. ReactiveSwift源码解析(九) SignalProducerProtocol延展中的Start、Lift系列方法的代码实现

    上篇博客我们聊完SignalProducer结构体的基本实现后,我们接下来就聊一下SignalProducerProtocol延展中的start和lift系列方法.SignalProducer结构体的 ...

  4. JavaScript 系列博客(三)

    JavaScript 系列博客(三) 前言 本篇介绍 JavaScript 中的函数知识. 函数的三种声明方法 function 命令 可以类比为 python 中的 def 关键词. functio ...

  5. 【原创 深度学习与TensorFlow 动手实践系列 - 3】第三课:卷积神经网络 - 基础篇

    [原创 深度学习与TensorFlow 动手实践系列 - 3]第三课:卷积神经网络 - 基础篇 提纲: 1. 链式反向梯度传到 2. 卷积神经网络 - 卷积层 3. 卷积神经网络 - 功能层 4. 实 ...

  6. Linux Shell系列教程之(三)Shell变量

    本文是Linux Shell系列教程的第(三)篇,更多shell教程请看:Linux Shell系列教程 Shell作为一种高级的脚本类语言,也是支持自定义变量的.今天就为大家介绍下Shell中的变量 ...

  7. 【HANA系列】【第三篇】SAP HANA XS的JavaScript安全事项

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第三篇]SAP HANA XS ...

  8. 孟老板 ListAdapter封装, 告别Adapter代码 (三)

    BaseAdapter系列 ListAdapter封装, 告别Adapter代码 (一) ListAdapter封装, 告别Adapter代码 (二) ListAdapter封装, 告别Adapter ...

  9. 《手把手教你》系列技巧篇(三十)-java+ selenium自动化测试- Actions的相关操作下篇(详解教程)

    1.简介 本文主要介绍两个在测试过程中可能会用到的功能:Actions类中的拖拽操作和Actions类中的划取字段操作.例如:需要在一堆log字符中随机划取一段文字,然后右键选择摘取功能. 2.拖拽操 ...

  10. 《手把手教你》系列技巧篇(三十一)-java+ selenium自动化测试- Actions的相关操作-番外篇(详解教程)

    1.简介 上一篇中,宏哥说的宏哥在最后提到网站的反爬虫机制,那么宏哥在自己本地做一个网页,没有那个反爬虫的机制,谷歌浏览器是不是就可以验证成功了,宏哥就想验证一下自己想法,于是写了这一篇文章,另外也是 ...

随机推荐

  1. 【Tycoon City New York】城市梦想家: 纽约 作弊键说明

    这游戏是自带快捷键作弊功能的 [Ctrl] + [Alt] + A 加10,000人口 [Ctrl] + [Alt] + C 加$1,000,000资金 [Ctrl] + [Alt] + B 加100 ...

  2. 【转载】 传统PID算法解决不了的情况,应该怎么办?

    原文地址: http://www.51hei.com/bbs/dpj-152844-1.html --------------------------------------------------- ...

  3. QT基础-弹出框(信息框,模态框,操作框)

    学习前端知识的时候就了解到让用户使用的界面一定要足够清晰,因为你永远不知道用户会以何种方式打开你开发的软件,所以莫泰提示框就很重要了.下面将会介绍几本的集中模态对话框,用来提升用户体验! 1.模态框 ...

  4. 2023 年上海市大学生程序设计竞赛 - 五月赛A,B,C

    A. 选择 多造几组数据可以发现 ​ \(dp[n] = dp[n / 2] + 1\). 假如一个序列为\(\{1,2,\cdots,n\}\),那我们从\(n/2\)后都减去\(n/2\),序列就 ...

  5. 2023 CCPC 桂林题解

    gym H. Sweet Sugar 一个经典贪心是从下到上,如果子树 \(u\) 剩下的部分(一定包含 \(u\))包含合法连通块,那么这个连通块给答案贡献 \(1\),切断 \(u\) 与 \(f ...

  6. kali常用配置

    用户须知 1.免责声明:本教程作者及相关参与人员对于任何直接或间接使用本教程内容而导致的任何形式的损失或损害,包括但不限于数据丢失.系统损坏.个人隐私泄露或经济损失等,不承担任何责任.所有使用本教程内 ...

  7. Win32 sdk 下树形控件响应鼠标单击与双击,获得选中项的名称

    //窗口过程函数INT_PTR CALLBACK myWin::myWinDlgProc(HWND dlgHwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) ...

  8. equals与hashCode关系梳理

    目录 equals用法 hashCode用法 总结 为什么一个类中需要两个比较方法 为什么重写 equals 方法时必须同时重写 hashCode 方法? Reference 这个并不是一个通用性编程 ...

  9. ChatGPT 客户端推荐

    通过按量计费的 Token 使用 ChatGPT 可以获得比免费 ChatGPT 更快的响应速度,但又不必支付昂贵的每月 20 美金订阅费用.下面是一些我个人喜欢的支持 Token 的 ChatGPT ...

  10. CSS – Font Family

    前言 font-family 虽然只是一个 CSS 属性, 但是牵连许多东西, 所以独立一篇来讲. 网站一般上会使用 Google Fonts 作为 font-family, 下面会以一个 Googl ...