更加好的排版: https://www.big-yellow-j.top/posts/2025/06/17/CM.html

前面已经介绍了扩散模型,在最后的结论里面提到一点:扩散模型往往需要多步才能生成较为满意的图像。不过现在有一种新的方式来加速(旨在通过少数迭代步骤)生成图像:一致性模型(consistency model),因此这里主要是介绍一致性模型(consistency model)基本原理以及代码实践,值得注意的是本文不会过多解释数学原理,数学原理推导可以参考:

介绍一致性模型之前需要了解几个知识:在传统的扩散模型中无论是加噪还是解噪过程都是随机的,在论文[1]中(也就是CM作者宋博士的另外一篇论文)将这个随机过程(也就是随机微分方程SDE)转化成“固定的”过程(也就是常微分方程ODE),只有过程可控才能保证下面公式成立。

一致性模型(Consistency Model)

其中ODE(常微分方程),在传统的扩散模型(Diffusion Models, DM)中,前向过程是从原始图像 \(x_0\)开始,不断添加噪声,经过 \(T\)步得到高斯噪声图像 \(x_T\)。反向过程(如 DDPM)通常通过训练一个逐步去噪的模型,将 \(x_T\)逐步还原为 \(x_0\) ,每一步估计一个中间状态,因此推理成本高(需迭代 T 步)。而在 Consistency Models(CM) 中,模型训练时引入了 Consistency Regularization,使得模型在不同的时间步 \(t\)都能一致地预测干净图像。这样在推理时,无需迭代多步,而是可以通过一个单一函数\(f(x ,t)\) 直接将任意噪声图像\(x_t\) 还原为目标图像\(x_0\) 。这大大减少了推理时间,实现了一步(或少数几步)生成。

一致性模型(consistency model)在论文[2]里面主要是通过使用常微分方程角度出发进行解释的。Consistency Model 在 Diffusion Model 的基础上,新增了一个约束:从某个样本到某个噪声的加噪轨迹上的每一个点,都可以经过一个函数 \(f\) 映射为这条轨迹的起点(也就是通过扩散处理的图像在不同的时间 \(t\)都可以直接转化为最开始的图像 \(x_0\)),用数学描述就是:\(f:(x_t, t)\rightarrow x_\epsilon\),换言之就是需要满足: \(f(x_t,t)=f(x_{t^\prime},t^\prime)\) 其中 \(t,t^\prime \in [\epsilon,T]\),正如论文里面的图片描述:

要满足上面的计算关系,作者在论文里面定义如下的等式关系:

\[f_\theta(x,t)=c_{skip}(t)x+ c_{out}(t)F_\theta(x,t)
\]

其中等式需要满足:\(c_{skip}(\epsilon)=1,c_{out}(\epsilon)=0\) (\(c_{skip}(t)=\frac{\sigma_{data}^2}{(t- \epsilon)^2+ \sigma_{data}^2}\), \(c_{out}(t)=\frac{\sigma_{data}(t-\epsilon)}{\sqrt{\sigma_{data}^2+ t^2}}\)),随着解噪过程(时间从:\(T \rightarrow \epsilon\) 其中 \(c_{skip}\) 的值逐渐增大,也就是当前的解噪图像占比权重增加),其中我的 \(F_\theta\) 就是我们的神经网络模型(比如Unet)。既然使用了神经网络那么必定就需要设计一个损失函数,在论文里面作者设计的损失函数为:两个时间步之间生成得到的图像距离通过最小化这个值(比如说 \(\Vert x_{t+1} - x_t \Vert_2\))来优化模型参数。作者对于模型训练给出两种训练方式

直接通过蒸馏模型进行优化

通过直接蒸馏的方式对模型参数进行优化,其中设计的损失函数为:

\[\mathcal{L}_{CD}^N(\boldsymbol{\theta},\boldsymbol{\theta}^-;\phi) = \mathbb{E}[\lambda(t_n)d(\boldsymbol{f}_{\boldsymbol{\theta}}(\mathbf{x}_{t_{n+1}},t_{n+1}),\boldsymbol{f}_{\boldsymbol{\theta}^-}(\hat{\mathbf{x}}_{t_n}^{\boldsymbol{\phi}},t_n))]
\]

其中 \(d\)代表距离(比如\(l_1\)或者 \(l_2\))对于上面公式中几个参数:\(\theta, \theta^-\),其中 \(\hat{x}_{t_n}^\phi\) 代表的是一个预训练的 score model。虽然在CM中损失函数设计上一下子又3个模型,但是实际训练过程中更新的只有一个参数:\(\theta\)。另外一个参数是直接通过:$\theta^- \leftarrow \mu \theta^-+ (1-\mu \theta) $ 通过指数滑动平均方式进行训练。而另外一个参数 \(\phi\)是一个确定的函数直接通过ODE solver来进行计算得到,比如在论文[3]的使用的欧拉求解法:

\[\hat{x}_{t_n}^\phi= x_{t_{n+1}}- (t_n- t_{n+1})t_{n+1}\nabla_{x_{t_{n+1}}}\log p_{t_{n+1}}(x_{t_{n+1}})
\]

欧拉法: \(y_{n+1}= y_n+h*f(t_n, y_n)\) 其中h代表时间步长,f代表当前导数估计。不过值得进一步了解的是,在DL中大部分函数都是直接通过神经网络进行“估算的”,也就是说对于上面的 \(\nabla_{x_{t_{n+1}}}\log p_{t_{n+1}} \textcolor{red}{≈} s_\theta(x_{t_{n+1}},t_{n+1})\) 其中 \(s_\theta\)代表的是训练好的去噪网络。

那么这样一来整个过程就变成了:

回顾整个过程(直接借鉴上面的流程图),算法比较简单(只不过背后的数学原理蛮复杂),简单描述上面过程就是:对于输入图片通过加噪处理之后得到加噪的图像,损失函数设计就是直接通过计算相邻的两步之间的“距离”最小,对于 \(x_{t_{n+1}}\) 我们是已经知道的,但是对于当前时间 \(t_n\) 是未知的,因此可以直接通过ODE solver的方式去进行估计,而后再去计算loss并且更新参数,对于students模型参数 \(\theta\)就可以直接通过计算梯度而后进行更新参数,而对于教师模型参数 \(\theta^-\) 可以直接可通过EMA进行更新

直接训练模型进行优化

直接训练模型进行优化,其中具体的过程为:

LCM/LCM-Lora

潜在一致性模型(Latent Consistency Model)[4]以及LCM-Lora[5](LCM的Lora优微调)通过再latent space中使用一致性模型(stable diffusion model通过VAE将图像进行压缩到latent sapce而后通过DF模型训练并且最后再通过VAE decoder输出),在LCM中主要提出两点:

1、Skipping-Step:因为在最开始的CM中计算两个相邻的时间步之间的loss由于时间步过于接近,就会导致loss很小,因此通过跳步解决这个问题,这样loss就会变成:\(d(f(x_{t_{n+\textcolor{red}{k}}}, t_{n+\textcolor{red}{k}}), f(x_{t_n}, t_n))\)。

2、引入Classifier-free guidance (CFG) 那么整个loss计算就会变成:\(d(f(x_{t_{n+\textcolor{red}{k}}}, \textcolor{red}{w}+ \textcolor{red}{c}, t_{n+\textcolor{red}{k}}), f(x_{t_n}, \textcolor{red}{w}+ \textcolor{red}{c}+ t_n))\),公式中c代表文本,对于CFG而言其实就是一个改进的ODE solver(见下面算法流程中的蓝色部分)

对于LCD算法流程,其中蓝色部分为LCM所修改的内容:

对于最后得到的实验结果分析:

  • 不同的k对结果的影响

在DPM-solver++和DPM-Solver中基本只需要 2000 步迭代,LCM 4 步采样的 FID 就已经基本收敛了

  • 不同的Guidance Scale对结果的影响

LCM 作者用不同 LCM 的迭代次数与不同 Guidance Scale 做了对比。发现 \(w\) 增加有助于提升 CLIP Score,但是损失了 FID 指标(即多样性)的表现。另外,LCM 迭代次数为 2、4、8 时,CLIP Score 和 FID 相差都不大,说明了 LCM 的蒸馏性能确实非常强悍,两步前向的效果可能都足够好了,只是一步前向的结果还差些。

总得来说,在LCM中主要是做了如下几点改进:1、使用skipping-step来“拉大”相邻点之间的距离计算;2、改进了ODE solver。

代码操作[6]

直接阅读最后总结

直接使用diffuser里面给的案例使用LCM/LCM-lora,代码分析如下:

1、时间步处理以及 \(c_{skip}\) 和 \(c_{out}\)

在代码中实现Skipping-Step因此在代码中处理方法为:首先通过DDPM(因为在代码中使用的教师模型是:stable-diffusion-v1-5/stable-diffusion-v1-5 而他使用的就是DDPM方式)而后计算得到topk(1000//30=33),而后通过构建随机索引(通过DDIM采样步:50)从DDIM的time steps((np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1)中进行索引。而后计算边界缩放:\(c_{skip}(t)=\frac{\sigma_{data}^2}{t^2+ \sigma_{data}^2}\), \(c_{out}(t)=\frac{t}{\sqrt{\sigma_{data}^2+ t^2}}\)

noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision)
...
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(
start_timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(
timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

2、加噪、模型处理

图像通过VAE处理之后然后会直接通过noise_scheduler.add_noise处理,最后通过模型预测得到noise_pred,因为加噪过程是 \(z_t=\sqrt{\alpha_t}x_0+ \sqrt{1-\alpha_t}\epsilon\) 通过模型得到了 \(z_t\)因此反推(get_predicted_original_sample)得到加噪前图像 \(x_0\),最后的模型预测就是:\(f_\theta(x,t)=c_{skip}(t)x+ c_{out}(t)F_\theta(x,t)\)

noise = torch.randn_like(latents)
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) # 6. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds") # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample pred_x_0 = get_predicted_original_sample(
noise_pred,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
) model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0

3、计算CFG、计算loss

因为LCM是一个文本引导的模型,因此在CFG计算中:

其中就存在计算有条件文本\(c\) 和无条件文本 \(\emptyset\)(直接用空文本uncond_input_ids = tokenizer([""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77).input_ids.to(accelerator.device)uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]进行表示即可)代码。最后计算loss值,更新参数并且通过EMA更新教师模型参数:

with torch.no_grad():
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) with autocast_ctx:
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
).sample
pred_x_0 = get_predicted_original_sample(
target_noise_pred,
timesteps,
x_prev,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
target = c_skip * x_prev + c_out * pred_x_0 # 10. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
loss = torch.mean(
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)

总结

总的来说consistency model作为一种diffusion model生成(区别与DDPM/DDIM)加速操作,在理论上首先将随机生成过程变成“确定”过程,这样一来生成就是确定的,从 \(T\rightarrow t_0\) 所有的点都在“一条线”上等式 \(f(x_t,t)=f(x_{t^\prime},t^\prime)\) 其中 \(t,t^\prime \in [\epsilon,T]\) 成立那么就保证了模型不需要再去不断依靠 \(t+1\) 生成内容去推断 \(t\)时刻内容(具体可以参考算法流程图)。而后续的LCM/LCM-Lora/TCD[7]则是基于CM的原理进行改进,回顾一下LCM的过程,理解代码(参考Huggingface)操作:

(LCM蒸馏)训练过程中主要使用了3个模型:1、teacher_model;2、unet;3、student_model。其中后面两个模型是相同的,第一个模型可以直接使用训练好的SD模型。

1、首先是构建跳步迭代过程,而后去计算(公式-1) \(c_{skip}\)和 \(c_{out}\)以及:\(f_\theta(x,t)=c_{skip}(t)x+ c_{out}(t)F_\theta(x,t)\) 对于其中的\(x\) 以及 \(F_\theta\) (代码中)分别表示的是\(t\)时刻 模型预测的输出和 \(t-1\)时刻的噪声图像(加噪反推可以直接计算出来)这样一来就得到( unet模型计算 ):model_pred(对应公式中的:\(f_\theta(z_{t_{n+k}},w,c,t_{n+k})\))

2、对于ODE solver计算就和公式-1计算过程相似,只不过需要区分有文本编码和没有文本编码两种,最后得到(这个过程通过教师模型 teacher_model 处理)pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) 以及 pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) 而后再通过扩散解噪过程得到\(t\)时刻的(ODE Solver结果): x_prev(对应公式中的:\(z_{t_n}^{\Psi,w}\))

3、计算loss通过student_model处理 x_prev而后反推得到 pred_x_0再去计算\(f_\theta(x,t)=c_{skip}(t)x+ c_{out}(t)F_\theta(x,t)\)(c_skip * x_prev + c_out * pred_x_0)得到target。最后去计算model_pred和 target的loss值,而后去通过EMA更新 student_model参数(通过unet来更新他的参数)

参考


  1. https://arxiv.org/pdf/2011.13456

  2. https://arxiv.org/abs/2303.01469

  3. https://arxiv.org/pdf/2406.14548v2

  4. https://arxiv.org/abs/2310.04378

  5. https://arxiv.org/abs/2311.05556

  6. https://github.com/huggingface/diffusers/tree/main/examples/consistency_distillation

  7. https://arxiv.org/abs/2402.19159

深入浅出了解生成模型-4:一致性模型(consistency model)的更多相关文章

  1. 一致性模型(consistency model)

    比如下面的例子: 一行X值在节点M和节点N上有副本 客户端A在节点M上写入行X的值 一段时间后,客户端B在节点N上读取行X的值 一致性模型所要做的就是决定客户端B能否看到客户端A写的值.一致性模型分为 ...

  2. [翻译]内存一致性模型 --- memory consistency model

    I will just give the analogy with which I understand memory consistency models (or memory models, fo ...

  3. CAP原理、一致性模型、BASE理论和ACID特性

    CAP原理 在理论计算机科学中,CAP定理(CAP theorem),又被称作布鲁尔定理(Brewer's theorem),它指出对于一个分布式计算系统来说,不可能同时满足以下三点: 一致性(Con ...

  4. Flink流处理(五)- 状态与一致性模型

    状态(State)与一致性模型 接下来我们转向另一个在流处理中十分重要的点:状态(state).状态在数据处理中是无处不在的.为了产生一个结果,函数一般会聚合某个时间段内(或是一定数量的)events ...

  5. java内存模型-顺序一致性

    数据竞争与顺序一致性保证 当程序未正确同步时,就会存在数据竞争.java 内存模型规范对数据竞争的定义如下: 在一个线程中写一个变量, 在另一个线程读同一个变量, 而且写和读没有通过同步来排序. 当代 ...

  6. Actor模型浅析 一致性和隔离性

    一.Actor模型介绍 在单核 CPU 发展已经达到一个瓶颈的今天,要增加硬件的速度更多的是增加 CPU 核的数目.而针对这种情况,要使我们的程序运行效率提高,那么也应该从并发方面入手.传统的多线程方 ...

  7. TensorFlow笔记四:从生成和保存模型 -> 调用使用模型

    TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...

  8. .net架构设计读书笔记--第三章 第8节 域模型简介(Introducing Domain Model)

    一.数据--行为转变     很长的时间,典型的分析方法或多或少是以下两种,第一,收集需求并做一些分析,找出有关实体 (例如,客户. 订单. 产品) 和进程来实现. 第二,手持这种理解你尝试推断一个物 ...

  9. CAP理论下对比ACID模型与BASE模型

    CAP介绍 Consistency(一致性), 数据一致更新,所有数据变动都是同步的.比如网购,库存减少的同时资金增多.Availability(可用性), 好的响应性能.比如支付操作10ms内响应用 ...

  10. ER图/模型转换为关系模型

    ER图中的主要成分是实体类型和联系类型,转换规则就是如何把实体类型.联系类型转换成关系模式. 1. 二元联系转换 规则1.1(实体类型的转换):将每个实体类型转换成一个关系模式,实体的属性即为关系模式 ...

随机推荐

  1. 【Linux】5.7 Shell test命令

    Shell test 命令 Shell中的 test 命令用于检查某个条件是否成立,它可以进行数值.字符和文件三个方面的测试. 1. 数值测试 参数 说明 -eq 等于则为真 -ne 不等于则为真 - ...

  2. SQL 和 PL/SQL 的区别

    不经意看到2个ORA错误,一个提示PL/SQL ORA-错误,另一个提示SQL ORA-错误,好奇这2货啥区别?留爪. PL/SQL也是一种程序语言,叫做过程化SQL语言(Procedural Lan ...

  3. 二叉树专题学习(C++版) 基础的上机题

    前言: 由于二叉树这一章的题型比较多,涉及到的递归程序也较多,所以单开一个随笔来记录这个学习过程,希望对读者有帮助. 理论知识基础 在二叉树的选择题中,常常会涉及到对于最多或最少结点.最大或最小高度. ...

  4. Tomcat知识点整理

    从学习起就开始接触tomcat, 解压, 点击运行, 然后放文件夹里面会自动部署, 可以通过ip访问. 在这里主要记录一些tomcat相关的知识点 配置文件解析(留位置) server.xml/web ...

  5. php 常用bc函数

    bcadd - 加法,2个任意精度数字的加法计算bcsub - 减法bcmul - 乘法bcdiv - 除法bcpow - 乘方bcmod - 取模bcsqrt - 求二次方根bccomp - 比较两 ...

  6. SQL 强化练习 (八)

    继续练习写sql, 不能停下来. 今天还额外对 Excel 拼接 sql 语句做了一个代码实现, 逻辑是蛮简单的, 发现其实很多东西都是蛮简单的, 只要一点点去做, 明白逻辑过后, 慢慢去调试, 都是 ...

  7. 实现高质量视频通话的javascript技巧与方法

    @charset "UTF-8"; .markdown-body { line-height: 1.75; font-weight: 400; font-size: 15px; o ...

  8. vue之sync

    在 Vue 中,.sync 是一个用于实现双向数据绑定的特殊修饰符.它允许父组件通过一种简洁的方式向子组件传递一个 prop,并在子组件中修改这个 prop 的值,然后将修改后的值反馈回父组件,实现双 ...

  9. Add Two Numbers--LeetCode进阶路②

    题目描述: You are given two non-empty linked lists representing two non-negative integers. The digits ar ...

  10. IPO——LeetCode⑫

    //原题链接https://leetcode.com/problems/ipo/ 题目描述 Suppose LeetCode will start its IPO soon. In order to ...