Jittor实现Conditional GAN
Jittor实现Conditional GAN
Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型。然而,GAN对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。本文讲解了CGAN的网络结构、损失函数设计、使用CGAN生成一串数字、从头训练CGAN、以及在mnist手写数字数据集上的训练结果。
CGAN网络架构
通过在生成器generator和判别器discriminator中添加相同的额外信息y,GAN就可以扩展为一个conditional模型。y可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。可以通过将y作为额外输入层,添加到生成器和判别器来完成条件控制。
在生成器generator中,除了y之外,还额外输入随机一维噪声z,为结果生成提供更多灵活性。

损失函数
GAN的损失函数
在解释CGAN的损失函数之前,首先介绍GAN的损失函数。下面是GAN的损失函数设计。

对于判别器D,要训练最大化这个loss。如果D的输入是来自真实样本的数据x,则D的输出D(x)要尽可能地大,log(D(x))也会尽可能大。如果D的输入是来自G生成的假图片G(z),则D的输出D(G(z))应尽可能地小,从而log(1-D(G(z))会尽可能地大。这样可以达到max D的目的。
对于生成器G,要训练最小化这个loss。对于G生成的假图片G(z),希望尽可能地骗过D,让它觉得生成的图片就是真的图片,这样就达到了G“以假乱真”的目的。那么D的输出D(G(z))应尽可能地大,从而log(1-D(G(z))会尽可能地小。这样可以达到min G的目的。
D和G以这样的方式联合训练,最终达到G的生成能力越来越强,D的判别能力越来越强的目的。
CGAN的损失函数
下面是CGAN的损失函数设计。

很明显,CGAN的loss跟GAN的loss的区别就是多了条件限定y。D(x/y)代表在条件y下,x为真的概率。D(G(z/y))表示在条件y下,G生成的图片被D判别为真的概率。
Jittor代码数字生成
首先,导入需要的包,并且设置好所需的超参数:
import jittor as jt
from jittor import nn
import numpy as np
import pylab as pl
%matplotlib inline
# 隐空间向量长度
latent_dim = 100
# 类别数量
n_classes = 10
# 图片大小
img_size = 32
# 图片通道数量
channels = 1
# 图片张量的形状
img_shape = (channels, img_size, img_size)
第一步,定义生成器G。该生成器输入两个一维向量y和noise,生成一张图片。
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(n_classes, n_classes)
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2))
return layers
self.model = nn.Sequential(
*block((latent_dim + n_classes), 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh())
def execute(self, noise, labels):
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
img = self.model(gen_input)
img = img.view((img.shape[0], *img_shape))
return img
第二步,定义判别器D。D输入一张图片和对应的y,输出是真图片的概率。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear((n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 1))
def execute(self, img, labels):
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
validity = self.model(d_in)
return validity
第三步,使用CGAN生成一串数字。
代码如下。可以使用训练好的模型来生成图片,也可以使用提供的预训练参数: 模型预训练参数下载:https://cloud.tsinghua.edu.cn/d/fbe30ae0967942f6991c/。
# 下载提供的预训练参数
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
生成自定义的数字:
# 定义模型
generator = Generator()
discriminator = Discriminator()
generator.eval()
discriminator.eval()
# 加载参数
generator.load('./generator_last.pkl')
discriminator.load('./discriminator_last.pkl')
# 定义一串数字
number = "201962517"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
生成结果如下,测试的完整代码在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/test.py。

从头训练Condition GAN
从头训练 Condition GAN 的完整代码在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/cgan.py,下载下来看看!
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
!python3.7 ./cgan.py --help
# 选择合适的batch size,运行试试
# 运行命令: !python3.7 ./cgan.py --batch_size 8
下载下来的代码里面定义损失函数、数据集、优化器。损失函数采用MSELoss、数据集采用MNIST、优化器采用Adam 如下(此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py):
# 此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py
# Define Loss
adversarial_loss = nn.MSELoss()
# Define Model
generator = Generator()
discriminator = Discriminator()
# Define Dataloader
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(opt.img_size),
transform.Gray(),
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
模型训练的代码如下(此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py):
# 此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py
# valid表示真,fake表示假
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()
# 真实图像和对应的标签
real_imgs = jt.array(imgs)
labels = jt.array(labels)
#########################################################
# 训练生成器G
# - 希望生成的图片尽可能地让D觉得是valid
#########################################################
# 随机向量z和随机生成的标签
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
# 随机向量z和随机生成的标签经过生成器G生成的图片,希望判别器能够认为生成的图片和生成的标签是一致的,以此优化生成器G的生成能力。
gen_imgs = generator(z, gen_labels)
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)
#########################################################
# 训练判别器D
# - 尽可能识别real_imgs为valid
# - 尽可能识别gen_imgs为fake
#########################################################
# 真实的图片和标签经过判别器的结果,要尽可能接近valid。
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
# G生成的图片和对应的标签经过判别器的结果,要尽可能接近fake。
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
MNIST数据集训练结果
下面展示了Jittor版CGAN在MNIST数据集的训练结果。下面分别是训练0 epoch和90 epoches的结果。


Jittor实现Conditional GAN的更多相关文章
- AI Conditional GAN
Conditional GAN 参考链接: https://arxiv.org/pdf/1611.07004v1.pdf
- 论文笔记:Towards Diverse and Natural Image Descriptions via a Conditional GAN
论文笔记:Towards Diverse and Natural Image Descriptions via a Conditional GAN ICCV 2017 Paper: http://op ...
- pytorch conditional GAN 调试笔记
推荐的几个开源实现 znxlwm 使用InfoGAN的结构,卷积反卷积 eriklindernoren 把mnist转成1维,label用了embedding wiseodd 直接从tensorflo ...
- 用GAN生成二维样本的小例子
同步自我的知乎专栏:https://zhuanlan.zhihu.com/p/27343585 本文完整代码地址:Generative Adversarial Networks (GANs) with ...
- 提高驾驶技术:用GAN去除(爱情)动作片中的马赛克和衣服
同步自我的知乎专栏:https://zhuanlan.zhihu.com/p/27199954 作为一名久经片场的老司机,早就想写一些探讨驾驶技术的文章.这篇就介绍利用生成式对抗网络(GAN)的两个基 ...
- GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)
我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...
- [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述
Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...
- Improved GAN
https://www.bilibili.com/video/av9770302/?p=16 从之前讲的basic gan延伸到unified framework,到WGAN 再到通过WGAN进行Ge ...
- GAN (Generative Adversarial Network)
https://www.bilibili.com/video/av9770302/?p=15 前面说了auto-encoder,VAE可以用于生成 VAE的问题, AE的训练是让输入输出尽可能的接近, ...
随机推荐
- POJ2771最大独立集元素个数
题意: 女生和男生之间只要满足四个条件中的一个,那么两个人就不会在一起!然后给出一些男生和女生,问最多多少人一起做活动彼此不会产生暧昧关系. 思路: 这样的问题还是比较裸的问法 ...
- 插入排序——Python实现
插入排序Python实现 # -*- coding: utf-8 -*- # @Time : 2019/10/28 20:47 # @Author : yuzhou_1shu # @Email : y ...
- 【python】Leetcode每日一题-设计停车系统
[python]Leetcode每日一题-设计停车系统 [题目描述] 请你给一个停车场设计一个停车系统.停车场总共有三种不同大小的车位:大,中和小,每种尺寸分别有固定数目的车位. 请你实现 Parki ...
- 联想R720Y空间问题
由于之前Y空间在启动项中,所以将他关闭,这次想找到他却找不到 备注:因为在解决问题前,没有把图片保存下来,所以下面用一个颜色框挡住,表示之前的效果 第一个问题 在电脑上找到Y空间 百度上很多说在开始中 ...
- lombok,Invalid byte tag in constant pool: 19
今天偶到一个奇怪的问题: 三台生产服务器部署同样的代码,同样的tomcat ,jdk等环境. 其中有一台服务器启动时报lombok-1.18.6.jar! Invalid byte tag in ...
- 数据人必读!玩转数据可视化用这个就够了——高德LOCA API 2.0升级来袭!
引言 "一图胜千言",大数据时代来临,数据与人们生活密切相关.复杂难懂且体量庞大的数据给人的感觉总是冷冰冰的,让人难以获取到重点信息,也找不出规律和特征,数据价值发挥不出来.空间数 ...
- HashSet添加操作底层判读(Object类型)
Object类型添加操作判读 第一步:程序首先创建一个Object泛型的Set数组,这里用到了上转型: 第二步:执行object里面的add添加方法,传进的值为"JAVA": 首先 ...
- 【odoo】[经验分享]数据迁移注意事项
[odoo14]经典好书学习没有烂尾,主体已完成,可移步了解.https://www.cnblogs.com/xushuotec/p/14428210.html 背景 近期,有朋友打算上odoo系统. ...
- 精选Hive高频面试题11道,附答案详细解析(好文收藏)
1. hive内部表和外部表的区别 未被external修饰的是内部表,被external修饰的为外部表. 区别: 内部表数据由Hive自身管理,外部表数据由HDFS管理: 内部表数据存储的位置是hi ...
- QQ账号登录测试用例