Smiling & Weeping

                ---- 一生拥有自由和爱,是我全部的野心

1. 环境准备

%pip install diffusers
from huggingface_hub import notebook_login

# 登录huggingface
notebook_login()
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image

def show_images(x):
"""给定一批图像,创建一个网格并将其转换成PIL"""
x = x*0.5 + 0.5
grid = torchvision.utils.make_grid(x)
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1)*255
grad_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
return grad_im

def make_grid(images, size=64):
"""给定一个PIL图像列表,将他们叠加成一行以便查看"""
output_im = Image.new("RGB", (size*len(images), size))
for i, im in enumerate(images):
out_im.paste(im.resize((size, size)), (i*size, 0))
return output_im

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
from diffusers import DDPMPipeline, StableDiffusionPipeline

model_id = "sd-dreambooth-library/mr-potato-head"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
prompt = "a cute anime characters using 8K resolution"
image = pipe(prompt, num_inference_steps=50, guidance_scale=5.5).images[0]
image

Diffusers核心API:

  • 管线:从高层次设计的多种类函数,便于部署的方式实现,能够快速利用预训练的主流扩散模型来生成样本。
  • 模型:在训练新的扩散模型时需要用到的网络结构。
  • 调度器:在推理过程中使用多种不同的技巧来从噪声中生成图像,同时可以生成训练过程中所需的“带噪”图像。
import torchvision
from datasets import load_dataset
from torchvision import transforms
from diffusers import DDPMScheduler
from diffusers import DDPMPipeline, StableDiffusionPipeline dataset = load_dataset('lowres/anime', split="train") image_size = 256
batch_size = 8 preprocess = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.5], [0.5]),
]) def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images} dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  

xb = next(iter(train_dataloader))['images'].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8*256, 256), resample=Image.NEAREST)

# 定义调度器
from diffusers import DDPMScheduler noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.rand_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noise X Shape", noisy_xb.shape)
show_images(noisy_xb).resize((8*64, 64), resample=Image.NEAREST)

  

from diffusers import UNet2DModel

model = UNet2DModel(
sample_size=image_size, # 目标图像的分辨率
in_channels=3,
out_channels=3,
layers_per_block=2, # 每一个UNet块中的ResNet层数
block_out_channels=(64, 128, 128, 256),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # 带有空域维度的self-att的ResNet下采样模块
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # 带有空域维度的self-att的ResNet上采样模块
"UpBlock2D",
"UpBlock2D",
),
) model = model.to(device)
with torch.no_grad():
model_pred = model(noisy_xb, timesteps).sample model_pred.shape

 训练



# 设定噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2") # 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) losses = [] # 定义损失函数
loss_fn = torch.nn.MSELoss() for epoch in range(45):
for step, batch in enumerate(train_dataloader):
# 未添加噪声的数据(clean data)
clean_data = batch['images'].to(device) # 生成噪声
noise = torch.randn(clean_data.shape).to(device)
bs = clean_data.shape[0] # 为每张图片随机采样一个时间步
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs, ), device=device).long() # 噪声数据
# 根据每个时间步的噪声幅度(迭代次数),向清晰的图片中添加噪声
noisy_data = noise_scheduler.add_noise(clean_data, noise, timesteps) # 获得预测模型
pred_data = model(noisy_data, timesteps, return_dict=False)[0] # 计算损失
loss = loss_fn(pred_data, clean_data)
loss.backward()
losses.append(loss.item()) # 迭代模型参数
optimizer.step()
optimizer.zero_grad() if (epoch+1) % 5 == 0:
loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
print(f"Epoch: {epoch+1}, loss: {loss_last_epoch}")
torch.save(model.state_dict(), 'save.pt')

 绘制损失图线

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

Diffusers实战的更多相关文章

  1. SSH实战 · 唯唯乐购项目(上)

    前台需求分析 一:用户模块 注册 前台JS校验 使用AJAX完成对用户名(邮箱)的异步校验 后台Struts2校验 验证码 发送激活邮件 将用户信息存入到数据库 激活 点击激活邮件中的链接完成激活 根 ...

  2. GitHub实战系列汇总篇

    基础: 1.GitHub实战系列~1.环境部署+创建第一个文件 2015-12-9 http://www.cnblogs.com/dunitian/p/5034624.html 2.GitHub实战系 ...

  3. MySQL 系列(四)主从复制、备份恢复方案生产环境实战

    第一篇:MySQL 系列(一) 生产标准线上环境安装配置案例及棘手问题解决 第二篇:MySQL 系列(二) 你不知道的数据库操作 第三篇:MySQL 系列(三)你不知道的 视图.触发器.存储过程.函数 ...

  4. Asp.Net Core 项目实战之权限管理系统(4) 依赖注入、仓储、服务的多项目分层实现

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

  5. 给缺少Python项目实战经验的人

    我们在学习过程中最容易犯的一个错误就是:看的多动手的少,特别是对于一些项目的开发学习就更少了! 没有一个完整的项目开发过程,是不会对整个开发流程以及理论知识有牢固的认知的,对于怎样将所学的理论知识应用 ...

  6. asp.net core 实战之 redis 负载均衡和"高可用"实现

    1.概述 分布式系统缓存已经变得不可或缺,本文主要阐述如何实现redis主从复制集群的负载均衡,以及 redis的"高可用"实现, 呵呵双引号的"高可用"并不是 ...

  7. Linux实战教学笔记08:Linux 文件的属性(上半部分)

    第八节 Linux 文件的属性(上半部分) 标签(空格分隔):Linux实战教学笔记 第1章 Linux中的文件 1.1 文件属性概述(ls -lhi) linux里一切皆文件 Linux系统中的文件 ...

  8. Linux实战教学笔记07:Linux系统目录结构介绍

    第七节 Linux系统目录结构介绍 标签(空格分隔):Linux实战教学笔记 第1章 前言 windows目录结构 C:\windows D:\Program Files E:\你懂的\精品 F:\你 ...

  9. Linux实战教学笔记06:Linux系统基础优化

    第六节 Linux系统基础优化 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 基础环境 第2章 使用网易163镜像做yum源 默认国外的yum源速度很慢,所以换成国内的. 第一步:先备份 ...

  10. Linux实战教学笔记05:远程SSH连接服务与基本排错(新手扫盲篇)

    第五节 远程SSH连接服务与基本排错 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 远程连接LInux系统管理 1.1 为什么要远程连接Linux系统 在实际的工作场景中,虚拟机界面或物理 ...

随机推荐

  1. Gartner APM 魔力象限技术解读——全量存储? No! 按需存储?YES!

    简介: 在云原生时代,充分利用边缘节点的计算和存储能力,结合冷热数据分离实现高性价比的数据价值探索已经逐渐成为 APM 领域的主流. 作者:夏明(涯海) 调用链记录了完整的请求状态及流转信息,是一座巨 ...

  2. 深入解读 Flink SQL 1.13

    简介: Apache Flink 社区 5 月 22 日北京站 Meetup 分享内容整理,深入解读 Flink SQL 1.13 中 5 个 FLIP 的实用更新和重要改进. 本文由社区志愿者陈政羽 ...

  3. [FAQ] Goland 始终没有包代码的提示 ?

    表现:import 引入的包始终是红色的,表示没有找到引入的包. 注意,在这里开启Go Modules: 然后在 Exteneral Libraries 里看到 Go Modules 即可. Refe ...

  4. [FAQ] MetaMask ALERT: 交易出错. 合约代码执行异常.

    首先确认载入的合约地址是否是最新的,比如 web3 载入的 abi 格式的 json 文件名 正不正确. 其次需要检查合约逻辑是否都正确,以及是否是合约抛出的错误,这两点最好是通过写测试用例来保证. ...

  5. 关于Web的欢迎页面的开发设置

    关于Web的欢迎页面的开发设置 每博一文案 命运总是不如人愿.但往往是在无数的痛苦中,在重重的矛盾和艰辛中,才是人成熟起来. 一次邂逅,一次目光的交融,就是永远的合二为一,就是与上帝的契约:总是风暴雷 ...

  6. Linux系统命令-目录命令

    1.ls命令:主要作用是显示目录下的内容 基本格式 [root@localhost ~]# ls [选项] [参数是文件名或目录名] 常用选项 -a:显示所有文件 --color=when:支持颜色输 ...

  7. Ubuntu下MPICH的安装与配置

    原创直达链接 一.MPICH的下载与安装 MPI安装文件下载地址: 博客下载地址 或 官网地址 可以下载3.4.2版本的,本文就是3.4.2版本 1.解压: sudo tar - zxvf mpich ...

  8. mac更新nodejs

    查看本机node.js版本: node -v 清除node.js的cache:sudo npm cache clean -f 安装 n 工具:sudo npm install -g n 安装最新版本的 ...

  9. 使用 Python 旋转PDF页面、或调整PDF页面顺序

    在将纸质文档扫描成PDF电子文档时,有时可能会出现页面方向翻转或者页面顺序混乱的情况.为了确保更好地浏览和查看PDF文件,本文将分享一个使用Python来旋转PDF页面或者调整PDF页面顺序的解决方案 ...

  10. AIRIOT答疑第4期|如何使用数据分析引擎?

    灵活报表曲线,满足各类分析需求! AIRIOT物联网低代码平台的数据分析引擎满足各类型数据分类及分析需求,毫秒级数据反馈速度,快速响应客户分析条件变换查询需求.通过机器学习.融合各种计算模型.人工智能 ...