High-Resolution Image Synthesis with Latent Diffusion Models

论文背景

LDM是Stable Diffusion模型的奠基性论文

于2022年6月在CVPR上发表



传统生成模型具有局限性:

  • 扩散模型(DM)通过逐步去噪生成图像,质量优于GAN,但直接在像素空间操作导致高计算开销。
  • 随着分辨率提升,扩散模型的优化和推理成本呈指数级增长,限制了实际应用

如DDPM生成的图像分辨率普遍不超过256×256,而LDM生成的图像分辨率可以超过1024×1024.

而LDM通过将扩散过程迁移至潜在空间,解决了传统模型的计算瓶颈,同时保持生成质量与灵活性

论文框架方法

论文中框架示意图如图所示:



在训练阶段:

  • 预训练自动编码器(AE)和条件生成编码器(如clip)
  • 输入图片x,经过自动编码器压缩到隐空间ε(x)=z
  • 随机采样时间步T,对Z进行加噪到\(Z_{T}\)
  • 对右边框里的条件进行条件编码\(\tau_\theta(y)\)和\(Z_{T}\)一起输入UNet网络中
  • 进行交叉注意力计算,其中\(Z_{T}\)作为Q向量,\(\tau_\theta(y)\)作为K,V向量计算注意力,这样做是让图像的每个位置根据文本的语义来决定关注哪些部分
  • 最后Unet输出两个向量,一个是无条件预测噪声,一个是文本预测噪声。

无条件预测噪声输入是空字符串

  • 使用CFG计算最终预测噪声\(\epsilon_{\text{guided}}(z_t, t, \tau_\theta(y)) = \epsilon_\theta(z_t, t,\tau_\theta(y)) + s \cdot (\epsilon_\theta(z_t, t, \tau_\theta(y)) - \epsilon_\theta(z_t, t, \varnothing))\)
  • 使用损失函数进行反向传播计算

在生成阶段:

  • 以随机噪声\(Z_{T}\)作为起点
  • 输入文本作为条件,编码后一起进入Unet进行交叉注意力计算
  • 输出预测噪声\(\epsilon_{\text{guided}}(z_t, t, \tau_\theta(y))\)
  • 使用调度器进行逐步去噪计算(如DDPM,DDIM)成为\(Z_{T-1}\)
  • 重复以上过程,直到Z
  • 通过自动编码器的解码器部分把Z迁移到像素空间,D(z),即生成图像

交叉注意力机制中的维度变换

图像编码后变成 C=4, H'=64, W'=64,展平后作为Q(\(z_t \Rightarrow Q \in \mathbb{R}^{(H'W') \times d}\)),文本通过编码器的编码表示为\(c = [t_1, t_2, ..., t_L] \Rightarrow \text{Embedding} \in \mathbb{R}^{L \times d}\),K和V表示为\(K, V \in \mathbb{R}^{L \times d}\),计算注意力权重\(A = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d}} \right) \in \mathbb{R}^{(H'W') \times L}\),输出为\(\text{Attention}(Q, K, V) = A \cdot V \in \mathbb{R}^{(H'W') \times d}\)

        # 潜在空间输入(prepare_latents生成)
latents.shape = (batch_size * num_images_per_prompt, 4, H//8, W//8) # 文本嵌入处理(encode_prompt输出)
prompt_embeds.shape = (batch_size, max_sequence_length, embedding_dim) # IP适配器图像嵌入处理
image_embeds[0].shape = (batch_size * num_images_per_prompt, num_images, emb_dim) # UNet输入/输出维度
latent_model_input.shape = [batch*2, 4, H//8, W//8] # 当启用CFG时
noise_pred.shape = [batch*2, 4, H//8, W//8] # UNet输出噪声预测
假设参数设置
prompt = "一只坐在月球上的猫"
height = 512
width = 512
num_images_per_prompt = 1
guidance_scale = 7.5
batch_size = 1 # 根据prompt长度自动确定 # 关键计算步骤演示
# ---------------------------
# 步骤1:潜在空间(latents)维度计算
latents_shape = (
batch_size * num_images_per_prompt, # 1*1=1
4, # UNet输入通道数
height // 8, # 512/8=64
width // 8 # 512/8=64
)
print(f"潜在空间维度: {latents_shape}") # -> (1, 4, 64, 64) # 步骤2:文本编码维度(假设使用CLIP模型)
prompt_embeds_shape = (
batch_size,
77, # CLIP最大序列长度
768 # CLIP文本编码维度
)
print(f"文本嵌入维度: {prompt_embeds_shape}") # -> (1, 77, 768) # 步骤3:CFG处理后的嵌入
if guidance_scale > 1:
prompt_embeds = torch.cat([negative_embeds, positive_embeds])
print(f"CFG嵌入维度: {prompt_embeds.shape}") # -> (2, 77, 768) # 步骤4:UNet输入维度(假设启用CFG)
latent_model_input = torch.cat([latents] * 2)
print(f"UNet输入维度: {latent_model_input.shape}") # -> (2, 4, 64, 64) # 步骤5:噪声预测输出
noise_pred = unet(latent_model_input, ...)[0]
print(f"噪声预测维度: {noise_pred.shape}") # -> (2, 4, 64, 64) # 步骤6:CFG调整后的噪声
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(f"调整后噪声维度: {noise_pred.shape}") # -> (1, 4, 64, 64) # 最终输出图像
image = vae.decode(latents / vae.config.scaling_factor)[0]
print(f"输出图像维度: {image.shape}") # -> (1, 3, 512, 512)
def cross_attention(query, key, value):
# 输入维度说明
# query: 来自潜在噪声 [batch=2, 4*64*64=16384] → 投影为 [2, 16384, 768]
# key/value: 来自文本嵌入 [2, 77, 768] # 步骤1:计算注意力分数
attention_scores = torch.matmul(
query, # [2, 16384, 768]
key.transpose(-1, -2) # [2, 768, 77] → 转置后维度
) # 矩阵乘法结果 → [2, 16384, 77] # 步骤2:计算注意力权重
attention_probs = torch.softmax(
attention_scores, # [2, 16384, 77]
dim=-1 # 对最后一个维度(文本标记维度)做归一化
) # 保持维度 [2, 16384, 77] # 步骤3:应用注意力到value
output = torch.matmul(
attention_probs, # [2, 16384, 77]
value # [2, 77, 768]
) # 结果维度 → [2, 16384, 768] # 步骤4:重塑为潜在空间维度
output = output.view(2, 4, 64, 64, 768) # 恢复空间结构
output = output.permute(0, 4, 1, 2, 3) # [2, 768, 4, 64, 64]
output = self.to_out(output) # 通过最后的线性层投影回4通道
return output # [2, 4, 64, 64]

数据集以及指标介绍

数据集介绍

CelebA-HQ 256 × 256数据集,是一个大规模的人脸属性数据集,拥有超过200K张名人图片,每张图片都有40个属性注释(如身份,年龄、表情、发型等)。

从Flickr网站爬取的人脸数据集集合,涵盖多样化的年龄、种族、表情、配饰(如眼镜、帽子)等属性

这两个数据集都是LSUN(大规模场景理解)数据集的子集,两个数据集分别表示教堂和卧室场景的数据,包含教堂建筑的不同视角、结构和环境条件,覆盖多样化的卧室场景,包括不同装修风格、家具布局和光照条件

指标介绍

IS分数介绍

Inception Score 的定义为:

\(IS(G) = \exp \left( \mathbb{E}_{x \sim p_g} \left[ D_{KL} ( p(y|x) \| p(y) ) \right] \right)\)

x~pg:生成图像样本来自生成模型的分布 。

p(y|x):通过预训练分类器(如Inception v3)对生成图像的类别预测概率分布。

p(y):预测类别的边缘分布。类别可以是猫,狗,猪等诸如此类的动物。

其中如果生成图像明确、质量高,则p(y|x)的熵就会比较低,如果生成图像比较多样,则p(y)的熵就会较高,体现在公式中则IS分数会较高。

FID分数介绍

主要是计算生成图像分布和真实图像分布在特征空间中的距离

公式\(\text{FID} = \| \mu_r - \mu_g \|_2^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2 (\Sigma_r \Sigma_g)^{\frac{1}{2}})\)

\(\mu_r,\Sigma_r\):真实图像分布的均值和协方差矩阵。

\(\mu_g,\Sigma_g\):生成图像分布的均值和协方差矩阵。

\(\| \mu_r - \mu_g \|_2^2\):欧几里得距离的平方。

\(\text{Tr}\):矩阵的迹。

\((\Sigma_r \Sigma_g)^{\frac{1}{2}}\):协方差矩阵的乘积的平方根。

两个分布的均值和协方差越低,FID越低,生成图像质量越接近生成的图像

prec和recall

这里的指标和一般理解的不一样。

会先用Inception网络分别提取真实图像和生成图像的特征点

用集合的角度解释:

Precision ≈ 生成图像中,有多少落在真实图像分布的“支持区域”里(真实性)

Recall ≈ 真实图像中,有多少被生成图像的“支持区域”覆盖(多样性)

实验分析

研究不同下采样因子f对生成图像质量和训练效率的影响

下采样因子:指的是自动编码器中的参数。

可以看到下采样因子为4或8时表现最好。因为如果因子过小,会导致维度高,计算缓慢,因子过大,会损失很多信息,导致最后生成图像生成质量较差

后续的实验将基于此展开

在这个实验里可以看到,LDM在CelebA-HQ中取得了最优的FID分数,在其他数据集上的表现也是中规中矩。

这个实验里展示了LDM在类别生成任务中的表现,可以看到使用cfg引导的LDM展现出了非常优秀的性能,在FID和IS分数上表现优异,虽然recall略低,但是使用的参数量也大幅减少了。

stable diffusion论文解读的更多相关文章

  1. 使用 LoRA 进行 Stable Diffusion 的高效参数微调

    LoRA: Low-Rank Adaptation of Large Language Models 是微软研究员引入的一项新技术,主要用于处理大模型微调的问题.目前超过数十亿以上参数的具有强能力的大 ...

  2. 论文解读(KP-GNN)《How Powerful are K-hop Message Passing Graph Neural Networks》

    论文信息 论文标题:How Powerful are K-hop Message Passing Graph Neural Networks论文作者:Jiarui Feng, Yixin Chen, ...

  3. itemKNN发展史----推荐系统的三篇重要的论文解读

    itemKNN发展史----推荐系统的三篇重要的论文解读 本文用到的符号标识 1.Item-based CF 基本过程: 计算相似度矩阵 Cosine相似度 皮尔逊相似系数 参数聚合进行推荐 根据用户 ...

  4. CVPR2019 | Mask Scoring R-CNN 论文解读

    Mask Scoring R-CNN CVPR2019 | Mask Scoring R-CNN 论文解读 作者 | 文永亮 研究方向 | 目标检测.GAN 推荐理由: 本文解读的是一篇发表于CVPR ...

  5. AAAI2019 | 基于区域分解集成的目标检测 论文解读

    Object Detection based on Region Decomposition and Assembly AAAI2019 | 基于区域分解集成的目标检测 论文解读 作者 | 文永亮 学 ...

  6. Gaussian field consensus论文解读及MATLAB实现

    Gaussian field consensus论文解读及MATLAB实现 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 一.Introduction ...

  7. zz扔掉anchor!真正的CenterNet——Objects as Points论文解读

    首发于深度学习那些事 已关注写文章   扔掉anchor!真正的CenterNet——Objects as Points论文解读 OLDPAN 不明觉厉的人工智障程序员 ​关注他 JustDoIT 等 ...

  8. NIPS2018最佳论文解读:Neural Ordinary Differential Equations

    NIPS2018最佳论文解读:Neural Ordinary Differential Equations 雷锋网2019-01-10 23:32     雷锋网 AI 科技评论按,不久前,NeurI ...

  9. [论文解读] 阿里DIEN整体代码结构

    [论文解读] 阿里DIEN整体代码结构 目录 [论文解读] 阿里DIEN整体代码结构 0x00 摘要 0x01 文件简介 0x02 总体架构 0x03 总体代码 0x04 模型基类 4.1 基本逻辑 ...

  10. 【抓取】6-DOF GraspNet 论文解读

    [抓取]6-DOF GraspNet 论文解读 [注]:本文地址:[抓取]6-DOF GraspNet 论文解读 若转载请于明显处标明出处. 前言 这篇关于生成抓取姿态的论文出自英伟达.我在读完该篇论 ...

随机推荐

  1. redis bind protected-mode

    概要 redis bind.protected-mode 配置 安装并启动 yum install -y redis systemctl enable --now redis # 使用 redis-s ...

  2. Open-Sora 2.0 重磅开源!

    潞晨科技正式推出 Open-Sora 2.0 -- 一款全新开源的 SOTA 视频生成模型,仅 20 万美元(224 张 GPU)成功训练商业级 11B 参数视频生成大模型.开发高性能的视频生成模型通 ...

  3. angular双向数据绑定踩坑记:

    在angular中使用ngModel时出现了一个报错error NG8002: Can't bind to 'ngModel' since it isn't a known property of ' ...

  4. linux部署go项目

    直接部署: 1.将程序所需要的文件如配置文件和生成的可执行文件拷贝到linux中 2.直接执行./main命令,启动程序 (main是go编译生成的可执行文件) 如果报Permission denie ...

  5. mysql8导入myslq5 报错

    打开sql文件替换 我的数据库编码是utf8mb4,如果你的数据库编码是别的,替换成你自己的编码. utf8mb4_0900_ai_ci替换为utf8mb4_general_ci

  6. HTTP 特性

    HTTP 常见到版本有 HTTP/1.1,HTTP/2.0,HTTP/3.0,不同版本的 HTTP 特性是不一样的. 这一章主要针对 HTTP/1.1 展开,最突出的优点是「简单.灵活和易于扩展.应用 ...

  7. MFC编程中与编码方式有关的宏定义的使用

    1 多字节字符集:char *strcpy(char *strDestination, const char *strSource); Unicode字符集:wchar_t *wcscpy(wchar ...

  8. Delphi 多线程使用

    1. 定义线程类 type TMyThread = class(TThread) private { Private declarations } fPos:Integer; // 变量 protec ...

  9. 【JVM之内存与垃圾回收篇】垃圾回收概述

    垃圾回收概述 概念 这次我们主要关注的是黄色部分,内存的分配与回收 垃圾收集 垃圾收集,不是 Java 语言的伴生产物.早在 1960 年,第一门开始使用内存动态分配和垃圾收集技术的 Lisp 语言诞 ...

  10. ASP.NET Core 全球化和本地化

    留备后观... Globalization and localization in ASP.NET Core 体验 ASP.NET Core 中的多语言支持(Localization)