Smiling & Weeping

                ---- 一生拥有自由和爱,是我全部的野心

1. 环境准备

%pip install diffusers
from huggingface_hub import notebook_login

# 登录huggingface
notebook_login()
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image

def show_images(x):
"""给定一批图像,创建一个网格并将其转换成PIL"""
x = x*0.5 + 0.5
grid = torchvision.utils.make_grid(x)
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1)*255
grad_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
return grad_im

def make_grid(images, size=64):
"""给定一个PIL图像列表,将他们叠加成一行以便查看"""
output_im = Image.new("RGB", (size*len(images), size))
for i, im in enumerate(images):
out_im.paste(im.resize((size, size)), (i*size, 0))
return output_im

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
from diffusers import DDPMPipeline, StableDiffusionPipeline

model_id = "sd-dreambooth-library/mr-potato-head"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
prompt = "a cute anime characters using 8K resolution"
image = pipe(prompt, num_inference_steps=50, guidance_scale=5.5).images[0]
image

Diffusers核心API:

  • 管线:从高层次设计的多种类函数,便于部署的方式实现,能够快速利用预训练的主流扩散模型来生成样本。
  • 模型:在训练新的扩散模型时需要用到的网络结构。
  • 调度器:在推理过程中使用多种不同的技巧来从噪声中生成图像,同时可以生成训练过程中所需的“带噪”图像。
import torchvision
from datasets import load_dataset
from torchvision import transforms
from diffusers import DDPMScheduler
from diffusers import DDPMPipeline, StableDiffusionPipeline dataset = load_dataset('lowres/anime', split="train") image_size = 256
batch_size = 8 preprocess = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.5], [0.5]),
]) def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images} dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  

xb = next(iter(train_dataloader))['images'].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8*256, 256), resample=Image.NEAREST)

# 定义调度器
from diffusers import DDPMScheduler noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.rand_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noise X Shape", noisy_xb.shape)
show_images(noisy_xb).resize((8*64, 64), resample=Image.NEAREST)

  

from diffusers import UNet2DModel

model = UNet2DModel(
sample_size=image_size, # 目标图像的分辨率
in_channels=3,
out_channels=3,
layers_per_block=2, # 每一个UNet块中的ResNet层数
block_out_channels=(64, 128, 128, 256),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # 带有空域维度的self-att的ResNet下采样模块
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # 带有空域维度的self-att的ResNet上采样模块
"UpBlock2D",
"UpBlock2D",
),
) model = model.to(device)
with torch.no_grad():
model_pred = model(noisy_xb, timesteps).sample model_pred.shape

 训练



# 设定噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2") # 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) losses = [] # 定义损失函数
loss_fn = torch.nn.MSELoss() for epoch in range(45):
for step, batch in enumerate(train_dataloader):
# 未添加噪声的数据(clean data)
clean_data = batch['images'].to(device) # 生成噪声
noise = torch.randn(clean_data.shape).to(device)
bs = clean_data.shape[0] # 为每张图片随机采样一个时间步
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs, ), device=device).long() # 噪声数据
# 根据每个时间步的噪声幅度(迭代次数),向清晰的图片中添加噪声
noisy_data = noise_scheduler.add_noise(clean_data, noise, timesteps) # 获得预测模型
pred_data = model(noisy_data, timesteps, return_dict=False)[0] # 计算损失
loss = loss_fn(pred_data, clean_data)
loss.backward()
losses.append(loss.item()) # 迭代模型参数
optimizer.step()
optimizer.zero_grad() if (epoch+1) % 5 == 0:
loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
print(f"Epoch: {epoch+1}, loss: {loss_last_epoch}")
torch.save(model.state_dict(), 'save.pt')

 绘制损失图线

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

Diffusers实战的更多相关文章

  1. SSH实战 · 唯唯乐购项目(上)

    前台需求分析 一:用户模块 注册 前台JS校验 使用AJAX完成对用户名(邮箱)的异步校验 后台Struts2校验 验证码 发送激活邮件 将用户信息存入到数据库 激活 点击激活邮件中的链接完成激活 根 ...

  2. GitHub实战系列汇总篇

    基础: 1.GitHub实战系列~1.环境部署+创建第一个文件 2015-12-9 http://www.cnblogs.com/dunitian/p/5034624.html 2.GitHub实战系 ...

  3. MySQL 系列(四)主从复制、备份恢复方案生产环境实战

    第一篇:MySQL 系列(一) 生产标准线上环境安装配置案例及棘手问题解决 第二篇:MySQL 系列(二) 你不知道的数据库操作 第三篇:MySQL 系列(三)你不知道的 视图.触发器.存储过程.函数 ...

  4. Asp.Net Core 项目实战之权限管理系统(4) 依赖注入、仓储、服务的多项目分层实现

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

  5. 给缺少Python项目实战经验的人

    我们在学习过程中最容易犯的一个错误就是:看的多动手的少,特别是对于一些项目的开发学习就更少了! 没有一个完整的项目开发过程,是不会对整个开发流程以及理论知识有牢固的认知的,对于怎样将所学的理论知识应用 ...

  6. asp.net core 实战之 redis 负载均衡和"高可用"实现

    1.概述 分布式系统缓存已经变得不可或缺,本文主要阐述如何实现redis主从复制集群的负载均衡,以及 redis的"高可用"实现, 呵呵双引号的"高可用"并不是 ...

  7. Linux实战教学笔记08:Linux 文件的属性(上半部分)

    第八节 Linux 文件的属性(上半部分) 标签(空格分隔):Linux实战教学笔记 第1章 Linux中的文件 1.1 文件属性概述(ls -lhi) linux里一切皆文件 Linux系统中的文件 ...

  8. Linux实战教学笔记07:Linux系统目录结构介绍

    第七节 Linux系统目录结构介绍 标签(空格分隔):Linux实战教学笔记 第1章 前言 windows目录结构 C:\windows D:\Program Files E:\你懂的\精品 F:\你 ...

  9. Linux实战教学笔记06:Linux系统基础优化

    第六节 Linux系统基础优化 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 基础环境 第2章 使用网易163镜像做yum源 默认国外的yum源速度很慢,所以换成国内的. 第一步:先备份 ...

  10. Linux实战教学笔记05:远程SSH连接服务与基本排错(新手扫盲篇)

    第五节 远程SSH连接服务与基本排错 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 远程连接LInux系统管理 1.1 为什么要远程连接Linux系统 在实际的工作场景中,虚拟机界面或物理 ...

随机推荐

  1. Vineyard 加入 CNCF Sandbox,将继续瞄准云原生大数据分析领域

    简介: Vineyard 是一个专为云原生环境下大数据分析场景中端到端工作流提供内存数据共享的分布式引擎,我们很高兴宣布 Vineyard 在 2021 年 4 月 27 日被云原生基金会(CNCF) ...

  2. [FE] uni-app 动态改变 navigationBarTitleText 导航标题

    改导航文字: uni.setNavigationBarTitle({ title: 'xx' }); 改 tabBar 文字: uni.setTabBarItem({ index: 0, text: ...

  3. Ubuntu WSL 下编译并使用OpenJDK12

    一,安装Ubuntu WSL 1.Windows中设置WSL并安装Ubuntu wsl "控制面板"-->"程序"-->"启用或关闭Win ...

  4. 指定Task任务顺序执行

    经常听到说线程池这个东西,凭印象写了个这么简单的例子. CusTRun方法要不要await,取决于要不要作为后台任务.任务可指定数量,线程参数可共享全,顺序可控,可继续改进. using System ...

  5. DB2查找最耗时SQL

    两种方法:db2top和snapshot for dynamic sql 1. db2top -d <dbname>

  6. kali 的 vim 中不能粘贴复制

    kali 的 vim 中不能粘贴复制 进入 vim 命令行模式,输入 :set mouse=c 之后可以正常粘贴复制

  7. grads 同时读取多个ctl文件方法

    1.不同的文件进行不同的设置:'set dfile 2' 2.读取不同文件的变量:qv.2 实例如下:'reinit''open e:\tskt.CTL''open e:\uwnd.CTL''open ...

  8. .NET实现获取NTP服务器时间并同步(附带Windows系统启用NTP服务功能)

    对某个远程服务器启用和设置NTP服务(Windows系统) 打开注册表 HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\W32Time\Tim ...

  9. Excel功能学习

    字符串和单元格内容拼接函数CONCATENATE a@马踏星空:=CONCATENATE(D2,E2,F2)拼接指定单元格内字符串,无分隔符 a@马踏星空:=CONCATENATE(I4," ...

  10. 关于.net Core在华为云的鲲鹏服务器上部署的细节纪要

    由于鲲鹏使用的是ARM的cpu,,非x86的,我们公司买的是Centos,,由于需要在上面部署.net core 3.0/3.1的应用,,在按照官方的文章进行部署之后,会提示 FailFast: Co ...