Conditional Generative Adversarial Nets
Mirza M, Osindero S. Conditional Generative Adversarial Nets.[J]. arXiv: Learning, 2014.
@article{mirza2014conditional,
title={Conditional Generative Adversarial Nets.},
author={Mirza, Mehdi and Osindero, Simon},
journal={arXiv: Learning},
year={2014}}
引
GAN (Generative Adversarial Nets) 能够通过隐变量\(z\)来生成一些数据, 但是我们没有办法去控制, 因为隐变量\(z\)是完全随机的. 这篇文章便很自然地提出了条件GAN,增加一个输入\(y\)(比如类别标签)去控制输出. 比如在MNIST数据集上, 我们随机采样一个\(z\), 并给定
\]
结果应当是数字2.
主要内容
文章的优化函数如下:

网络"结构"如下:



代码
"""
这个几乎就是照搬别人的代码
lr=0.0001,
epochs=50
但是10轮就差不多收敛了
"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
class Generator(nn.Module):
"""
生成器
"""
def __init__(self, input_size=(100, 10), output_size=784):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 256),
nn.BatchNorm1d(256),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 256),
nn.BatchNorm1d(256),
nn.ReLU()
)
self.dense = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, output_size),
nn.Tanh()
)
def forward(self, z, y):
"""
:param z: 随机隐变量
:param y: 条件隐变量
:return:
"""
z = self.fc1(z)
y = self.fc2(y)
out = self.dense(
torch.cat((z, y), 1)
)
return out
class Discriminator(nn.Module):
def __init__(self, input_size=(784, 10)):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 1024),
nn.LeakyReLU(0.2)
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 1024),
nn.LeakyReLU(0.2)
)
self.dense = nn.Sequential(
nn.Linear(2048, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x, y):
x = self.fc1(x)
y = self.fc2(y)
out = self.dense(
torch.cat((x, y), 1)
)
return out
class Train:
def __init__(self, z_size=100, y_size=10, x_size=784,
criterion=nn.BCELoss(), lr=1e-4):
self.generator = Generator(input_size=(z_size, y_size), output_size=x_size)
self.discriminator = Discriminator(input_size=(x_size, y_size))
self.criterion = criterion
self.opti1 = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
self.opti2 = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
self.z_size = z_size
self.y_size = y_size
self.x_size = x_size
self.lr = lr
cpath = os.path.abspath('.')
self.gen_path = os.path.join(cpath, 'generator3.pt')
self.dis_path = os.path.join(cpath, 'discriminator3.pt')
self.imgspath = lambda i: os.path.join(cpath, 'image3', 'fig{0}'.format(i))
#self.loading()
def transform_y(self, labels):
return torch.eye(self.y_size)[labels]
def sampling_z(self, size):
return torch.randn(size)
def showimgs(self, imgs, order):
n = imgs.size(0)
imgs = imgs.data.view(n, 28, 28)
fig, axs = plt.subplots(10, 10)
for i in range(10):
for j in range(10):
axs[i, j].get_xaxis().set_visible(False)
axs[i, j].get_yaxis().set_visible(False)
for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy(), cmap='gray')
fig.savefig(self.imgspath(order))
for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy() / 2 + 0.5, cmap='gray')
fig.savefig(self.imgspath(order+1))
#plt.show()
#plt.cla()
def train(self, trainloader, epochs=50, classes=10):
order = 2
for epoch in range(epochs):
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 5 is 0.:
self.opti1.param_groups[0]['lr'] /= 10
self.opti2.param_groups[0]['lr'] /= 10
print("learning rate change!")
if (epoch + 1) % order is 0.:
self.showimgs(fake_imgs, order=order)
self.showimgs(real_imgs, order=order+2)
order += 4
for i, data in enumerate(trainloader):
real_imgs, labels = data
real_imgs = real_imgs.view(real_imgs.size(0), -1)
y = self.transform_y(labels)
d_out = self.discriminator(real_imgs, y).squeeze()
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze()
# 训练判别器
loss1 = self.criterion(d_out, torch.ones_like(d_out))
loss2 = self.criterion(g_out, torch.zeros_like(g_out))
d_loss = loss1 + loss2
self.opti2.zero_grad()
d_loss.backward()
self.opti2.step()
# 训练生成器
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze()
g_loss = self.criterion(g_out, torch.ones_like(g_out))
self.opti1.zero_grad()
g_loss.backward()
self.opti1.step()
running_loss_d += d_loss
running_loss_g += g_loss
if i % 10 is 0 and i != 0:
print("[epoch {0:<d}: d_loss: {1:<5f} g_loss: {2:<5f}]".format(
epoch, running_loss_d / 10, running_loss_g / 10
))
running_loss_d = 0.
running_loss_g = 0.
torch.save(self.generator.state_dict(), self.gen_path)
torch.save(self.discriminator.state_dict(), self.dis_path)
def loading(self):
self.generator.load_state_dict(torch.load(self.gen_path))
self.generator.eval()
self.discriminator.load_state_dict(torch.load(self.dis_path))
self.discriminator.eval()
结果

此时判别器对这些图片进行判别, 但部分都是0.5以下, 也就是说这些基本上都被认为是伪造的图片.
"""
lr=0.001,
SGD,
网络结构简化了
"""
class Generator(nn.Module):
"""
生成器
"""
def __init__(self, input_size=(100, 10), output_size=784):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 128),
nn.BatchNorm1d(128),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 128),
nn.BatchNorm1d(128),
nn.ReLU()
)
self.dense = nn.Sequential(
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, output_size),
nn.BatchNorm1d(output_size),
nn.Tanh()
)
def forward(self, z, y):
"""
:param z: 随机隐变量
:param y: 条件隐变量
:return:
"""
z = self.fc1(z)
y = self.fc2(y)
out = self.dense(
torch.cat((z, y), 1)
)
return out
class Discriminator(nn.Module):
def __init__(self, input_size=(784, 10)):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2)
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2)
)
self.dense = nn.Sequential(
nn.Linear(2048, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1),
nn.Sigmoid()
)
def forward(self, x, y):
x = self.fc1(x)
y = self.fc2(y)
out = self.dense(
torch.cat((x, y), 1)
)
return out
class Train:
def __init__(self, z_size=100, y_size=10, x_size=784,
criterion=nn.BCELoss(), lr=1e-3, momentum=0.9):
self.generator = Generator(input_size=(z_size, y_size), output_size=x_size)
self.discriminator = Discriminator(input_size=(x_size, y_size))
self.criterion = criterion
self.opti1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=momentum)
self.opti2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=momentum)
self.z_size = z_size
self.y_size = y_size
self.x_size = x_size
self.lr = lr
cpath = os.path.abspath('.')
self.gen_path = os.path.join(cpath, 'generator2.pt')
self.dis_path = os.path.join(cpath, 'discriminator2.pt')
self.imgspath = lambda i: os.path.join(cpath, 'image', 'fig{0}'.format(i))
#self.loading()
def transform_y(self, labels):
return torch.eye(self.y_size)[labels]
def sampling_z(self, size):
return torch.randn(size)
def showimgs(self, imgs, order):
n = imgs.size(0)
imgs = imgs.data.view(n, 28, 28)
fig, axs = plt.subplots(10, 10)
for i in range(10):
for j in range(10):
axs[i, j].get_xaxis().set_visible(False)
axs[i, j].get_yaxis().set_visible(False)
for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy(), cmap='gray')
fig.savefig(self.imgspath(order))
def train(self, trainloader, epochs=5, classes=10):
order = 0
for epoch in range(epochs):
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 5 is 0.:
self.opti1.param_groups[0]['lr'] /= 10
self.opti2.param_groups[0]['lr'] /= 10
print("learning rate change!")
for i, data in enumerate(trainloader):
real_imgs, labels = data
real_imgs = real_imgs.view(real_imgs.size(0), -1)
y = self.transform_y(labels)
d_out = self.discriminator(real_imgs, y).squeeze()
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs.detach(), fake_y).squeeze()
# 训练判别器
loss1 = self.criterion(d_out, torch.ones_like(d_out))
loss2 = self.criterion(g_out, torch.zeros_like(g_out))
d_loss = loss1 + loss2
self.opti2.zero_grad()
d_loss.backward()
self.opti2.step()
# 训练生成器
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze()
g_loss = self.criterion(g_out, torch.ones_like(g_out))
self.opti1.zero_grad()
g_loss.backward()
self.opti1.step()
running_loss_d += d_loss
running_loss_g += g_loss
if i % 10 is 0 and i != 0:
print("[epoch {0:<d}: d_loss: {1:<5f} g_loss: {2:<5f}]".format(
epoch, running_loss_d / 10, running_loss_g / 10
))
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 2 is 0:
self.showimgs(fake_imgs, order=order)
order += 1
torch.save(self.generator.state_dict(), self.gen_path)
torch.save(self.discriminator.state_dict(), self.dis_path)
def loading(self):
self.generator.load_state_dict(torch.load(self.gen_path))
self.generator.eval()
self.discriminator.load_state_dict(torch.load(self.dis_path))
self.discriminator.eval()
结果, 不是特别好

SGD改成Adam之后的结果(50个epochs都训练完了, 结果居然有点好).

Conditional Generative Adversarial Nets的更多相关文章
- 论文笔记之:Conditional Generative Adversarial Nets
Conditional Generative Adversarial Nets arXiv 2014 本文是 GANs 的拓展,在产生 和 判别时,考虑到额外的条件 y,以进行更加"激烈 ...
- Generative Adversarial Nets[content]
0. Introduction 基于纳什平衡,零和游戏,最大最小策略等角度来作为GAN的引言 1. GAN GAN开山之作 图1.1 GAN的判别器和生成器的结构图及loss 2. Condition ...
- Generative Adversarial Nets[CAAE]
本文来自<Age Progression/Regression by Conditional Adversarial Autoencoder>,时间线为2017年2月. 该文很有意思,是如 ...
- Generative Adversarial Nets[pix2pix]
本文来自<Image-to-Image Translation with Conditional Adversarial Networks>,是Phillip Isola与朱俊彦等人的作品 ...
- GAN(Generative Adversarial Nets)的发展
GAN(Generative Adversarial Nets),产生式对抗网络 存在问题: 1.无法表示数据分布 2.速度慢 3.resolution太小,大了无语义信息 4.无reference ...
- (转)Deep Learning Research Review Week 1: Generative Adversarial Nets
Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...
- 论文笔记之:Generative Adversarial Nets
Generative Adversarial Nets NIPS 2014 摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...
- Generative Adversarial Nets[BEGAN]
本文来自<BEGAN: Boundary Equilibrium Generative Adversarial Networks>,时间线为2017年3月.是google的工作. 作者提出 ...
- Generative Adversarial Nets[CycleGAN]
本文来自<Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks>,时间线为2017 ...
随机推荐
- SQLITE_BTREE_H
sqlite3.c, 237436行 = 全部源文件,找东西比多文件查找方便多了:-),字符串查找一点都不慢. 不要太害怕,SQLite说它的代码里有非常多是用来做数据完整性检查和测试的.但愿B树,虚 ...
- 【Go语言学习笔记】包
包其实是每个大型工程都会使用的模块化工具. 将相关的代码封装成一个包,给其他项目调用,提供不同的功能. GO的设计是将一个文件夹看成一个包,虽然不一定非要用文件夹的名字,但是比较建议. 同一个文件夹下 ...
- windows磁盘扩容
要邻近的磁盘,才可以扩展.所以必须要先删除恢复分区. 删除恢复分区,参考如下: https://jingyan.baidu.com/article/574c5219598d5e6c8c9dc15e.h ...
- echarts饼图样式
1.中间标题字体大小不一致(可分为一个title一个graphic) 2.labelLine与饼图分离(两个饼图,其中一个显示一个隐藏) function setmyChartJsgxzq(arr,d ...
- 事务(@Transactional注解)的用法和实例
参数 @Transactional可以配制那些参数及以其所代表的意义: 参数 意义 isolation 事务隔离级别 propagation 事务传播机制 readOnly 事务读写性 noRollb ...
- 【Linux】【Services】【SaaS】Docker+kubernetes(5. 安装和配置ETCD集群)
1. 简介: 1.1. ETCD是kubernetes和openstack都用到的组件,需要首先装好 1.2. 官方网站:https://coreos.com/etcd/ 1.3. ETCD的作用: ...
- AOP中环绕通知的书写和配置
package com.hope.utils;import org.aspectj.lang.ProceedingJoinPoint;import org.aspectj.lang.annotatio ...
- 使用递归方法,遍历输出以.java结尾的文件
package cn.itcast.demo01;import java.io.File;/** * @author newcityman * @date 2019/7/27 - 19:17 * 题目 ...
- 带你揭开WebSerivce的面纱
最近在工作中遇到这样的一个项目(暂且定为项目A),项目A本身是用PHP开发的,但是其数据是来自于另一个使用java开发的项目(暂且定为项目B),项目A不能操作项目B的数据库,它有其自己的一套数据库,只 ...
- C++STL标准库学习笔记(一)sort
前言: 近来在学习STL标准库,做一份笔记并整理好,方便自己梳理知识.以后查找,也方便他人学习,两全其美,快哉快哉! 这里我会以中国大学慕课上北京大学郭炜老师的<程序设计与算法(一)C语言程序设 ...