最近,人工智能社区在开发更大、更高性能的语言模型方面取得了显著的进展,例如 Falcon 40B、LLaMa-2 70B、Falcon 40B、MPT 30B; 以及在图像领域的模型,如 SD2.1 和 SDXL 。这些进步无疑推动了人工智能的发展,使其具有高度多功能和最先进的图像生成和语言理解能力。然而,在我们惊叹于这些模型的强大和复杂性之余,必须认识到一个日益增长的需求: 使人工智能模型体量更小、运行更高效、更易于访问,特别是通过开源它们来共建生态。

Segmind,我们一直致力于如何使生成式 AI 更快、更便宜。去年,我们开源了我们加速的 SD-WebUI 库 voltaML,它是一个基于 AITemplate/TensorRT 的推理加速库,推理速度提高了 4-6 倍。为了继续实现使生成模型更快、更小、更便宜的目标,我们正在开源我们压缩的 SD 模型:SD-Small 和 SD-Tiny 的权重和训练代码。预训练的检查点可在 Hugging Face 上获取。

知识蒸馏

我们的新压缩模型已经经过知识蒸馏 (KD) 技术的训练,这项工作主要基于 这篇论文。作者描述了一种块移除知识蒸馏方法,其中一些 UNet 层被移除,学生模型权重被训练。使用论文中描述的 KD 方法,我们能够使用 diffusers 库训练两个压缩模型; Small(微小版本)Tiny(极小版本),分别比基础模型少 35% 和 55% 的参数,同时实现与基础模型相当的图像保真度。我们已经在这个 repo 中开源了我们的蒸馏代码,并将预训练检查点上传到了 Hugging Face

知识蒸馏训练神经网络类似于老师一步一步指导学生。一个大的老师模型 (teacher model) 预先在大量数据上训练,然后一个较小的模型在较小的数据集上训练,以模仿大模型的输出并在数据集上进行经典训练。

在这种特殊类型的知识蒸馏中,学生模型被训练来完成从纯噪声恢复图像的正常扩散任务,但同时,模型被迫与更大的老师模型的输出匹配。输出匹配发生在 U-nets 的每个块,因此模型质量基本保持不变。所以,使用前面的类比,我们可以说,在这种蒸馏过程中,学生不仅会试图从问题和答案中学习,还会从老师的答案以及逐步得到答案的方法中学习。我们在损失函数中有 3 个组成部分来实现这一点,首先是目标图像隐变量和生成图像隐变量之间的传统损失。其次是老师生成的图像隐变量和学生生成的图像隐变量之间的损失。最后,也是最重要的组成部分,是特征级损失,即老师和学生每个块输出之间的损失。

结合所有这些构成了知识蒸馏训练。下面是论文中描述的用于 KD 的块移除 UNet 架构。

图片来自 Shinkook 等人的 论文 “On Architectural Compression of Text-to-Image Diffusion Models”。

我们以 Realistic-Vision 4.0 为基础老师模型,并在LAION Art Aesthetic 数据集 上训练,图像分数高于 7.5,因为它们具有高质量的图像描述。与论文不同,我们选择分别为 Small 和 Tiny 模式训练两个模型,分别在 1M 张图像上进行 100K 步和 125K 步的训练。蒸馏训练的代码可以在 这里 找到。

模型使用

模型可以通过 diffusers 中的 DiffusionPipeline 来使用。

from diffusers import DiffusionPipeline
import torch pipeline = DiffusionPipeline.from_pretrained("segmind/small-sd", torch_dtype=torch.float16)
prompt = "Portrait of a pretty girl"
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
image = pipeline(prompt, negative_prompt = negative_prompt).images[0]
image.save("my_image.png")

推理延迟方面的速度表现

我们观察到,蒸馏模型比原始基础模型快了一倍。基准测试代码可以在 这里 找到。

潜在的局限性

蒸馏模型处于早期阶段,输出可能还不具备生产水平的质量。这些模型可能不是最好的通用模型,它们最好用作针对特定概念/风格进行微调或 LoRA 训练。蒸馏模型目前还不太擅长组合性或多概念。

在人像数据集上微调 SD-tiny 模型

我们已经在 Realistic Vision v4.0 模型生成的人像图像上微调了我们的 sd-tiny 模型。下面是使用的微调参数。

原版参数 中文释义
Steps: 131000 步数: 131000
Learning rate: 1e-4 学习率: 1e-4
Batch size: 32 批量大小: 32
Gradient accumulation steps: 4 梯度累积步数: 4
Image resolution: 768 图像分辨率: 768
Dataset size: 7k images 数据集大小: 7 千张图像
Mixed precision: fp16 混合精度: fp16

我们能够产生接近原始模型产生的图像质量,参数减少了近 40%,下面的样本结果不言自明:

微调基础模型的代码可以在 这里 找到。

LoRA 训练

在蒸馏模型上进行 LoRA 训练的一个优点是训练更快。下面是我们在蒸馏模型上对一些抽象概念进行的第一个 LoRA 训练的一些图像。LoRA 训练的代码可以在 这里 找到。

结论

我们邀请开源社区帮助我们改进并实现这些蒸馏 SD 模型的更广泛采用。用户可以加入我们的 Discord 服务器,在那里我们将宣布这些模型的最新更新,发布更多的检查点和一些令人兴奋的新 LoRAs。如果你喜欢我们的工作,请在我们的 Github 上点一下 star。


英文原文: https://hf.co/blog/sd_distillation

原文作者: Yatharth Gupta

译者: innovation64

审校/排版: zhongdongy (阿东)

开源 SD-Small 和 SD-Tiny 知识蒸馏代码与权重的更多相关文章

  1. 知识蒸馏(Distillation)

    蒸馏神经网络取名为蒸馏(Distill),其实是一个非常形象的过程. 我们把数据结构信息和数据本身当作一个混合物,分布信息通过概率分布被分离出来.首先,T值很大,相当于用很高的温度将关键的分布信息从原 ...

  2. Deeplearning知识蒸馏

    Deeplearning知识蒸馏 merge paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, ...

  3. 【论文考古】知识蒸馏 Distilling the Knowledge in a Neural Network

    论文内容 G. Hinton, O. Vinyals, and J. Dean, "Distilling the Knowledge in a Neural Network." 2 ...

  4. 开源囧事4:你们这些卖代码的能不能留自己的QQ号?留我QQ号干嘛?

    缘起于开源项目 从 2017 年开始,陆陆续续写了一些开源项目放到开源网站里,都是一些实战项目,给大家练练手.有基础整合的demo,有 Spring Boot 博客项目,有 Spring Boot 商 ...

  5. 【DKNN】Distilling the Knowledge in a Neural Network 第一次提出神经网络的知识蒸馏概念

    原文链接 小样本学习与智能前沿 . 在这个公众号后台回复"DKNN",即可获得课件电子资源. 文章已经表明,对于将知识从整体模型或高度正则化的大型模型转换为较小的蒸馏模型,蒸馏非常 ...

  6. Android 存储到SD卡,获取SD的大小及可用空间

    使用Sdcard注意事项:     1.权限问题:             <uses-permission android:name="android.permission.WRIT ...

  7. eclipse中如何向开源中国(码云)上传代码

    摘要 本文将介绍如何将本地的项目提交到开源中国上去,过程比较详细,实现起来很简单.由于自己也算是一个新手,所以没有做过多的解释,只是单纯的描述了该如何去做.   1.在开源中国上面新建一个空项目 到这 ...

  8. 分享一个开源的JavaScript统计图表库,40行代码实现专业统计图表

    提升程序员工作效率的工具/技巧推荐系列 推荐一个功能强大的文件搜索工具SearchMyFiles 介绍一个好用的免费流程图和UML绘制软件-Diagram Designer 介绍Windows任务管理 ...

  9. 95)PHP,文件上传知识和代码

    首先是知识总结: 上传: 从浏览器端传输的到服务器端. 请求时: 数据从浏览器端传输到服务器端. 可见: 上传,发生在浏览器向服务器发出请求过程中. 文件,对于浏览器来讲,就是表单中的一个特殊类型的数 ...

  10. wndows程序设计之书籍知识与代码摘录-封装一个类似printf的messagebox

    //----------------------------------------- //本程序展示了如何实现MessageBoxPrintf函数 //本函数能像printf那样格式化输出 //摘录 ...

随机推荐

  1. pg_enterprise_views偶然发现的PG神仙插件!

    一直从事数据库相关的工作,对于PG而言最大的问题其实是在运维管理方面,其缺乏有效且直观成体系的系统表,苦觅良久,今日在PG官网中发现了一款新收录的免费插件,其提供了数十张系统表,内容涵盖了从操作系统到 ...

  2. Rocky 9 Linux 平台 vim 9.0 源码包编译安装踩坑记录

    目录 vim 9.0 部署准备环境 vim 9.0 源码包正式部署 vim 9.0 初体验 plug-vim 安装插件 在上一篇 <vim入门实战> 篇,我并没有介绍 Linux 平台源码 ...

  3. c#构建具有用户认证与管理的socks5代理服务端

    Socks 协议是一种代理 (Proxy) 协议, 例如我们所熟知的 Shdowsocks 便是 Socks 协议的一个典型应用程序, Socks 协议有多个版本, 目前最新的版本为 5, 其协议标准 ...

  4. 7-8 估值一亿的AI核心代码

    题目描述: 以上图片来自新浪微博. 本题要求你实现一个稍微更值钱一点的 AI 英文问答程序,规则是: 无论用户说什么,首先把对方说的话在一行中原样打印出来: 消除原文中多余空格:把相邻单词间的多个空格 ...

  5. Kali下压缩解压缩命令大全zip,tar,tar.gz,tar.bz2(转)

    转自http://blog.csdn.net/yangjin_unique/article/details/7824852 tar 解包:tar xvf FileName.tar 打包:tar cvf ...

  6. 为什么会出现 setTimeout 倒计时误差

    setTimeout 倒计时误差的出现主要与 JavaScript 的事件循环机制和计时器的执行方式有关. 在 JavaScript 中,事件循环是用于管理和调度代码执行的机制.setTimeout ...

  7. 如何卸载 python setup.py install 安装的包?

    当我们半自动安装某些 python 包时,总是存在很多依赖关系的问题,而这些问题还是很难避免的,所以,当我们安装一个不确定的包的时候,最好提前收集一些相关资料,或者请教他人,同时最好把安装过程都记录下 ...

  8. bugku_EasyMath

    bugku_EasyMath 题目描述 简单的数学题 from Crypto.Util.number import getPrime, bytes_to_long from secret import ...

  9. Python生成指定大小的文件

    转载请注明出处️ 作者:测试蔡坨坨 原文链接:caituotuo.top/400bd75c.html 你好,我是测试蔡坨坨. 在日常测试工作中,我们经常需要对上传的文件大小进行测试,例如:一个文件上传 ...

  10. @Inherited元注解的使用

    @Inherited注解标记其他的注解用于指明标记的注解是可以被自动继承的. 注意:此注解只对注解标记的超类有效,对接口是无效的. 示例: 先声明两个用@Inherited标记的注解,@Name和@A ...