基于Jittor框架实现LSGAN图像生成对抗网络
基于Jittor框架实现LSGAN图像生成对抗网络
生成对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN模型由生成器(Generator)和判别器(Discriminator)两个部分组成。在训练过程中,生成器的目标就是尽量生成真实的图片去欺骗判别器。而判别器的目标就是尽量把生成器生成的图片和真实的图片分别开来。这样,生成器和判别器构成了一个动态的“博弈过程”。许多相关的研究工作表明GAN能够产生效果非常真实的生成效果。
使用Jittor框架实现了一种经典GAN模型LSGAN。LSGAN将GAN的目标函数由交叉熵损失替换成最小二乘损失,以此拒绝了标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷。通过LSGAN的实现介绍了Jittor数据加载、模型定义、模型训练的使用方法。
LSGAN论文:https://arxiv.org/abs/1611.04076
1.数据集准备
使用两种数据集进行LSGAN的训练,分别是Jittor自带的数据集MNIST,和用户构建的数据集CelebA。您可以通过以下链接下载CelebA数据集。
- CelebA 数据集: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
使用Jittor自带的MNIST数据加载器方法如下。使用jittor.transform可以进行数据归一化及数据增强,这里通过transform将图片归一化到[0,1]区间,并resize到标准大小112*112。。通过set_attrs函数可以修改数据集的相关参数,如batch_size、shuffle及transform等。
fromjittor.dataset.mnistimportMNIST
importjittor.transformastransform
transform=transform.Compose([
transform.Resize(size=img_size),
transform.ImageNormalize(mean=[0.5],std=[0.5])
])
train_loader=MNIST(train=True,transform=transform)
.set_attrs(batch_size=batch_size,shuffle=True)
val_loader=MNIST(train=False,transform=transform)
.set_attrs(batch_size=1,shuffle=True)
使用用户构建的CelebA数据集方法如下,通过通用数据加载器jittor.dataset.dataset.ImageFolder,输入数据集路径即可构建用户数据集。
fromjittor.dataset.datasetimportImageFolder
importjittor.transformastransform
transform=transform.Compose([
transform.Resize(size=img_size),
transform.ImageNormalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])
])
train_dir='./data/celebA_train'
train_loader=ImageFolder(train_dir)
.set_attrs(transform=transform,batch_size=batch_size,shuffle=True)
val_dir='./data/celebA_eval'
val_loader=ImageFolder(val_dir)
.set_attrs(transform=transform,batch_size=1,shuffle=True)
2.模型定义
2.1.网络结构
使用LSGAN进行图像生成,下图为LSGAN论文给出的网络架构图,其中(a)为生成器,(b)为判别器。生成器网络输入一个1024维的向量,生成分辨率为112*112的图像;判别器网络输入112*112的图像,输出一个数字表示输入图像为真实图像的可信程度。
受到VGG模型的启发,生成器在与DCGAN的结构基础上在前两个反卷积层之后增加了两个步长=1的反卷积层。除使用最小二乘损失函数外判别器的结构与DCGAN中的结构相同。与DCGAN相同,生成器和判别器分别使用了ReLU激活函数和LeakyReLU激活函数。

下面将介绍如何使用Jittor定义一个网络模型。定义模型需要继承基类jittor.Module,并实现__init__和execute函数。__init__函数在模型声明时会被调用,用于进行模型内部op或其他模型的声明及参数的初始化。该模型初始化时输入参数dim表示训练图像的通道数,对于MNIST数据集dim为1,对于CelebA数据集dim为3。
execute函数在网络前向传播时会被调用,用于定义前向传播的计算图,通过autograd机制在训练时Jittor会自动构建反向计算图。
importjittorasjt
fromjittorimportnn,Module
classgenerator(Module):
def__init__(self,dim=3):
super(generator,self).__init__()
self.fc=nn.Linear(1024,7*7*256)
self.fc_bn=nn.BatchNorm(256)
self.deconv1=nn.ConvTranspose(256,256,3,2,1,1)
self.deconv1_bn=nn.BatchNorm(256)
self.deconv2=nn.ConvTranspose(256,256,3,1,1)
self.deconv2_bn=nn.BatchNorm(256)
self.deconv3=nn.ConvTranspose(256,256,3,2,1,1)
self.deconv3_bn=nn.BatchNorm(256)
self.deconv4=nn.ConvTranspose(256,256,3,1,1)
self.deconv4_bn=nn.BatchNorm(256)
self.deconv5=nn.ConvTranspose(256,128,3,2,1,1)
self.deconv5_bn=nn.BatchNorm(128)
self.deconv6=nn.ConvTranspose(128,64,3,2,1,1)
self.deconv6_bn=nn.BatchNorm(64)
self.deconv7=nn.ConvTranspose(64,dim,3,1,1)
self.relu=nn.ReLU()
self.tanh=nn.Tanh()
defexecute(self,input):
x=self.fc_bn(self.fc(input).reshape((input.shape[0],256,7,7)))
x=self.relu(self.deconv1_bn(self.deconv1(x)))
x=self.relu(self.deconv2_bn(self.deconv2(x)))
x=self.relu(self.deconv3_bn(self.deconv3(x)))
x=self.relu(self.deconv4_bn(self.deconv4(x)))
x=self.relu(self.deconv5_bn(self.deconv5(x)))
x=self.relu(self.deconv6_bn(self.deconv6(x)))
x=self.tanh(self.deconv7(x))
returnx
classdiscriminator(nn.Module):
def__init__(self,dim=3):
super(discriminator,self).__init__()
self.conv1=nn.Conv(dim,64,5,2,2)
self.conv2=nn.Conv(64,128,5,2,2)
self.conv2_bn=nn.BatchNorm(128)
self.conv3=nn.Conv(128,256,5,2,2)
self.conv3_bn=nn.BatchNorm(256)
self.conv4=nn.Conv(256,512,5,2,2)
self.conv4_bn=nn.BatchNorm(512)
self.fc=nn.Linear(512*7*7,1)
self.leaky_relu=nn.Leaky_relu()
defexecute(self,input):
x=self.leaky_relu(self.conv1(input),0.2)
x=self.leaky_relu(self.conv2_bn(self.conv2(x)),0.2)
x=self.leaky_relu(self.conv3_bn(self.conv3(x)),0.2)
x=self.leaky_relu(self.conv4_bn(self.conv4(x)),0.2)
x=x.reshape((x.shape[0],512*7*7))
x=self.fc(x)
returnx
2.2.损失函数
损失函数采用最小二乘损失函数,其中判别器损失函数如下。其中x为真实图像,z为服从正态分布的1024维向量,a取值为1,b取值为0。

生成器损失函数如下。其中z为服从正态分布的1024维向量,c取值为1。

具体实现如下,x为生成器的输出值,b表示该图像是否希望被判别为真。
defls_loss(x,b):
mini_batch=x.shape[0]
y_real_=jt.ones((mini_batch,))
y_fake_=jt.zeros((mini_batch,))
ifb:
return(x-y_real_).sqr().mean()
else:
return(x-y_fake_).sqr().mean()
3.模型训练
3.1.参数设定
参数设定如下。
# 通过use_cuda设置在GPU上进行训练
jt.flags.use_cuda=1
# 批大小
batch_size=128
# 学习率
lr=0.0002
# 训练轮数
train_epoch=50
# 训练图像标准大小
img_size=112
# Adam优化器参数
betas=(0.5,0.999)
# 数据集图像通道数,MNIST为1,CelebA为3
dim=1iftask=="MNIST"else3
3.2.模型、优化器声明
分别声明生成器和判别器,并使用Adam作为优化器。
# 生成器
G=generator(dim)
# 判别器
D=discriminator(dim)
# 生成器优化器
G_optim=nn.Adam(G.parameters(),lr,betas=betas)
# 判别器优化器
D_optim=nn.Adam(D.parameters(),lr,betas=betas)
3.3.训练
forepochinrange(train_epoch):
forbatch_idx,(x_,target)inenumerate(train_loader):
mini_batch=x_.shape[0]
# 判别器训练
D_result=D(sx)
D_real_loss=ls_loss(D_result,True)
z_=init.gauss((mini_batch,1024),'float')
G_result=G(z_)
D_result_=D(G_result)
D_fake_loss=ls_loss(D_result_,False)
D_train_loss=D_real_loss+D_fake_loss
D_train_loss.sync()
D_optim.step(D_train_loss)
# 生成器训练
z_=init.gauss((mini_batch,1024),'float')
G_result=G(z_)
D_result=D(G_result)
G_train_loss=ls_loss(D_result,True)
G_train_loss.sync()
G_optim.step(G_train_loss)
if(batch_idx%100==0):
print('D training loss =',D_train_loss.data.mean())
print('G training loss =',G_train_loss.data.mean())
4.结果与测试
4.1.生成结果
分别使用MNIST和CelebA数据集进行了50个epoch的训练。训练完成后各随机采样了25张图像,结果如下。


4.2.速度对比
使用Jittor与主流的深度学习框架PyTorch进行了训练速度的对比,下表为PyTorch(是/否打开benchmark)及Jittor在两种数据集上进行1次训练迭带的使用时间。得益于Jittor特有的元算子融合技术,其训练速度比PyTorch快了40%~55%。

基于Jittor框架实现LSGAN图像生成对抗网络的更多相关文章
- 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...
- AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华
注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...
- [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述
Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...
- 渐进结构—条件生成对抗网络(PSGAN)
Full-body High-resolution Anime Generation with Progressive Structure-conditional Generative Adversa ...
- 知物由学 | AI网络安全实战:生成对抗网络
本文由 网易云发布. “知物由学”是网易云易盾打造的一个品牌栏目,词语出自汉·王充<论衡·实知>.人,能力有高下之分,学习才知道事物的道理,而后才有智慧,不去求问就不会知道.“知物由学” ...
- 生成对抗网络(Generative Adversarial Networks,GAN)初探
1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...
- 【超分辨率】—(ESRGAN)增强型超分辨率生成对抗网络-解读与实现
一.文献解读 我们知道GAN 在图像修复时更容易得到符合视觉上效果更好的图像,今天要介绍的这篇文章——ESRGAN: Enhanced Super-Resolution Generative Adve ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
随机推荐
- sed高级指令
N命令 n命令 n命令简单来说就是提前读取下一行,覆盖模型空间前一行,然后执行后续命令.然后再读取新行,对新读取的内容重头执行sed //从test文件中取出偶数行 [root@localhost ~ ...
- POJ2553 强连通出度为0的应用
题意: 给你一个有向图,然后问你有多少个满足要求的点,要求是: 这个点能走到的所有点都能走回这个点,找到所有的这样的点,然后排序输出. 思路: 可以直接一遍强连通缩点,所点之后 ...
- Windows中动态磁盘管理
目录 动态磁盘 基本磁盘和动态磁盘的转换 简单卷 跨区卷 带区卷 镜像卷 RAID-5卷 相关文章:硬盘分区形式(MBR.GPT).系统引导.文件系统.Inode和Block 动态磁盘 Windows ...
- 子域名探测工具Aquatone的使用
目录 Aquatone Aquatone的安装 Aquatone的使用 子域名爆破 端口扫描
- 神经网络与机器学习 笔记—卷积神经网络(CNN)
卷积神经网络 之前的一些都是考虑多层感知器算法设计相关的问题,这次是说一个多层感知器结构布局相关的问题.来总结卷积神经网络.对于模式分类非常合适.网络的提出所隐含的思想收到了神经生物学的启发. 第一个 ...
- IDEA 这样设置,好看到爆炸!!!
Hello,大家好,我是楼下小黑哥. 今天这篇文章是次条视频的文案,这里推荐大家直接看视频学习. IDEA 这样设置,好看到爆炸!!!#01 今天这期我们来分享几个美化 IDEA 设置技巧,让你的 I ...
- Linux安装Redis报错`cc:命令未找到`
缺少gcc和gcc-c++的编译环境,安装即可. 可以联网情况下使用命令 yum install gcc yum install gcc-c++ 然后清理原来的残余文件 make distclean ...
- 2020 ICPC EC Final西安现场赛游记
也不知道从何说起,也不知道会说些什么,最想表达的就是很累很累. 从第一天去的时候满怀希望,没什么感觉甚至还有一些兴奋.到后来一直在赶路,感觉很疲惫,热身赛的时候觉得马马虎虎,导致热身赛被咕.然后教练就 ...
- 03.28,周六,12:00-17:00,ICPC训练联盟周赛,选用试题:UCF Local Programming Contest 2016正式赛。
A. Majestic 10 题意:三个数均大于10则输出"triple-double",如果两个数大于10则输出"double-double",如果一个大于1 ...
- Linux的三剑客
首先,需要介绍一下管道和正则表达式,因为它经常和Linux三剑客一起使用. 一.管道Linux 提供管道符"|",将两个命令隔开,管道符左边命令的输出作为管道符右边命令的输入. c ...