Generative Adversarial Nets

这篇文章,引领了对抗学习的思想,更加可贵的是其中的理论证明,证明很少却直击要害.

目标

GAN,译名生成对抗网络,目的就是训练一个网络来拟合数据的分布,以前的方法,类似高斯核,Parzen窗等都可以用来估计(虽然不是很熟).

GAN有俩个网络,一个是G(z)生成网络,和D(x)判别网络, 其中\(z\)服从一个随机分布,而\(x\)是原始数据, \(z\)服从一个随机分布,是很重要的一点,假设\(\hat{x}=G(x)\), 则:

\[p(\hat{x})=\int p(z)I(G(z)=\hat{x})\mathrm{d}z
\]

其中\(I\)表示指示函数,这意味着,网络\(G\)也是一个分布,而我们所希望的,就是这个分布能够尽可能取拟合原始数据\(x\)的分布.

框架



GAN需要训练上面的俩个网络,D的输出是一个0~1的标量,其含义是输入的x是否为真实数据(真实为1), 故其损失函数为(V(D,G)部分):



在实际操作中,固定网络G更新网络D,再固定网络D更新网络G,反复迭代:

理论

至于为什么可以这么做,作者给出了精炼的证明.



上面的证明唯一令人困惑的点在于\(p_z \rightarrow p_g\)的变化,我一开始觉得这个是利用换元,但是从别的博客中看到,似乎是用了测度论的导数的知识,最后用到了变分的知识.



其中:



其证明思路是,当\(p_g=p_{data}\)的时候,\(C(G)=-\log 4\), 所以只需证明这个值为最小值,且仅再\(p_g=p_{data}\)的时候成立那么证明就结束了,为了证明这一点,作者凑了一个JSD, 而其正好满足我们要求(实际上只需KL散度即可Gibb不等式).

数值实验

在MNIST数据集上做实验(代码是仿别人的写的), 我们的目标自然是给一个z, G能够给出一些数字.

用不带卷积层的网络:



带卷积层的网络,不过不论\(z\)怎么变,结果都一样,感觉有点怪,但是实际上,如果\(G\)一直生成的都是比方说是1, 那也的确能够骗过\(D\), 这个问题算是什么呢?有悖啊...

代码

代码需要注意的一点是,用BCELoss, 但是更新G网络的时候,并不是传入fake_label, 而是real_label,因为G需要骗过D, 不知道该怎么说,应该明白的.


import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt class Generator(nn.Module):
def __init__(self, input_size):
super(Generator, self).__init__()
self.dense = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 784)
) def forward(self, x):
out = self.dense(x)
return out class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.dense = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x):
x = x.view(x.size(0), -1)
out = self.dense(x)
return out class Train:
def __init__(self, trainset, batch_size, z_size=100, criterion=nn.BCELoss(), lr=1e-3):
self.generator = Generator(z_size)
self.discriminator = Discriminator()
self.opt1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=0.9)
self.opt2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=0.9)
self.trainset = trainset
self.batch_size = batch_size
self.real_label = torch.ones(batch_size)
self.fake_label = torch.zeros(batch_size)
self.criterion = criterion
self.z_size = z_size def train(self, epoch_size, path):
running_loss1 = 0.0
running_loss2 = 0.0
for epoch in range(epoch_size):
for i, data in enumerate(self.trainset, 0):
try:
real_img, _ = data out1 = self.discriminator(real_img)
real_loss = self.criterion(out1, self.real_label) z = torch.randn(self.batch_size, self.z_size)
fake_img = self.generator(z)
out2 = self.discriminator(fake_img)
fake_loss = self.criterion(out2, self.fake_label) loss = real_loss + fake_loss
self.opt2.zero_grad()
loss.backward()
self.opt2.step() z = torch.randn(self.batch_size, self.z_size)
fake_img = self.generator(z)
out2 = self.discriminator(fake_img)
fake_loss = self.criterion(out2, self.real_label) #real_label!!!! self.opt1.zero_grad()
fake_loss.backward()
self.opt1.step() running_loss1 += fake_loss
running_loss2 += real_loss
if i % 10 == 9:
print("[epoch:{} loss1: {:.7f} loss2: {:.7f}]".format(
epoch,
running_loss1 / 10,
running_loss2 / 10
))
running_loss1 = 0.0
running_loss2 = 0.0
except ValueError as err:
print(err) #最后一批的数据可能不是batch_size
continue
torch.save(self.generator.state_dict(), path) def loading(self, path):
self.generator.load_state_dict(torch.load(path))
self.generator.eval()
"""
加了点卷积
"""
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt class Generator(nn.Module):
def __init__(self, input_size):
super(Generator, self).__init__()
self.dense = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 784)
) def forward(self, x):
out = self.dense(x)
return out class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 5, 3, 2), # 1x28x28 --> 32x10x10
nn.ReLU(),
nn.MaxPool2d(2, 2), # 32 x 10 x 10 --> 32x5x5
nn.Conv2d(32, 64, 3, 1, 1), # 32x5x5-->32x5x5
nn.ReLU()
)
self.dense = nn.Sequential(
nn.Linear(1600, 512),
nn.ReLU(),
nn.Linear(512, 1),
nn.Sigmoid()
) def forward(self, x):
x = x.view(x.size(0), 1, 28, 28)
x = self.conv(x)
x = x.view(x.size(0), -1)
out = self.dense(x)
return out class Train:
def __init__(self, trainset, batch_size, z_size=100, criterion=nn.BCELoss(), lr=1e-3):
self.generator = Generator(z_size)
self.discriminator = Discriminator()
self.opt1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=0.9)
self.opt2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=0.9)
self.trainset = trainset
self.batch_size = batch_size
self.real_label = torch.ones(batch_size)
self.fake_label = torch.zeros(batch_size)
self.criterion = criterion
self.z_size = z_size def train(self, epoch_size, path):
running_loss1 = 0.0
running_loss2 = 0.0
for epoch in range(epoch_size):
for i, data in enumerate(self.trainset, 0):
try:
real_img, _ = data out1 = self.discriminator(real_img)
real_loss = self.criterion(out1, self.real_label) z = torch.randn(self.batch_size, self.z_size)
fake_img = self.generator(z)
out2 = self.discriminator(fake_img)
fake_loss = self.criterion(out2, self.fake_label) loss = real_loss + fake_loss
self.opt2.zero_grad()
loss.backward()
self.opt2.step() z = torch.randn(self.batch_size, self.z_size)
fake_img = self.generator(z)
out2 = self.discriminator(fake_img)
fake_loss = self.criterion(out2, self.real_label) #real_label!!!! self.opt1.zero_grad()
fake_loss.backward()
self.opt1.step() running_loss1 += fake_loss
running_loss2 += real_loss
if i % 10 == 9:
print("[epoch:{} loss1: {:.7f} loss2: {:.7f}]".format(
epoch,
running_loss1 / 10,
running_loss2 / 10
))
running_loss1 = 0.0
running_loss2 = 0.0
except ValueError as err:
print(err) #最后一批的数据可能不是batch_size
continue
torch.save(self.generator.state_dict(), path) def loading(self, path):
self.generator.load_state_dict(torch.load(path))
self.generator.eval()

Generative Adversarial Nets (GAN)的更多相关文章

  1. 一文读懂对抗生成学习(Generative Adversarial Nets)[GAN]

    一文读懂对抗生成学习(Generative Adversarial Nets)[GAN] 0x00 推荐论文 https://arxiv.org/pdf/1406.2661.pdf 0x01什么是ga ...

  2. Generative Adversarial Nets(GAN Tensorflow)

    Generative Adversarial Nets(简称GAN)是一种非常流行的神经网络. 它最初是由Ian Goodfellow等人在NIPS 2014论文中介绍的. 这篇论文引发了很多关于神经 ...

  3. Generative Adversarial Nets[Wasserstein GAN]

    本文来自<Wasserstein GAN>,时间线为2017年1月,本文可以算得上是GAN发展的一个里程碑文献了,其解决了以往GAN训练困难,结果不稳定等问题. 1 引言 本文主要思考的是 ...

  4. Generative Adversarial Nets(原生GAN学习)

    学习总结于国立台湾大学 :李宏毅老师 Author: Ian Goodfellow • Paper: https://arxiv.org/abs/1701.00160 • Video: https:/ ...

  5. GAN(Generative Adversarial Nets)的发展

    GAN(Generative Adversarial Nets),产生式对抗网络 存在问题: 1.无法表示数据分布 2.速度慢 3.resolution太小,大了无语义信息 4.无reference ...

  6. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  7. Generative Adversarial Nets[BEGAN]

    本文来自<BEGAN: Boundary Equilibrium Generative Adversarial Networks>,时间线为2017年3月.是google的工作. 作者提出 ...

  8. Generative Adversarial Nets[content]

    0. Introduction 基于纳什平衡,零和游戏,最大最小策略等角度来作为GAN的引言 1. GAN GAN开山之作 图1.1 GAN的判别器和生成器的结构图及loss 2. Condition ...

  9. Generative Adversarial Nets[CycleGAN]

    本文来自<Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks>,时间线为2017 ...

随机推荐

  1. 学习java 7.12

    学习内容: File是文件和目录路径名的抽象表示,File封装的不是一个真正存在的文件,仅仅是一个路径名 File类的方法 绝对目录和相对目录的区别 字节流 使用字节输出流写数据的步骤 : 创建字节输 ...

  2. 大数据学习day23-----spark06--------1. Spark执行流程(知识补充:RDD的依赖关系)2. Repartition和coalesce算子的区别 3.触发多次actions时,速度不一样 4. RDD的深入理解(错误例子,RDD数据是如何获取的)5 购物的相关计算

    1. Spark执行流程 知识补充:RDD的依赖关系 RDD的依赖关系分为两类:窄依赖(Narrow Dependency)和宽依赖(Shuffle Dependency) (1)窄依赖 窄依赖指的是 ...

  3. Linux下删除的文件如何恢复

    Linux下删除的文件如何恢复 参考自: [1]linux下误操作删除文件如何恢复 [2]Linux实现删除撤回的方法 以/home/test.txt为例 1.df -T 文件夹 找到当前文件所在磁盘 ...

  4. Output of C++ Program | Set 5

    Difficulty Level: Rookie Predict the output of below C++ programs. Question 1 1 #include<iostream ...

  5. Xcode中匹配的配置包的存放目录

    /Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/DeviceSupport

  6. OC中的结构体

    一.结构体 结构体只能在定义的时候进行初始化 给结构体属性赋值    + 强制转换: 系统并不清楚是数组还是结构体,需要在值前面加上(结构体名称)    +定义一个新的结构体,进行直接赋值    + ...

  7. JVM堆空间结构及常用的jvm内存分析命令和工具

    jdk8之前的运行时数据区域 程序计数器 是一块较小的内存空间,它可以看做是当前线程所执行的字节码的行号指示器.每个线程都有一个独立的程序计数器,这类内存区域为"线程私有",此内存 ...

  8. 初步接触Linux命令

    目录 虚拟机快照 1.首先将已经运行的系统关机 2.找到快照 拍摄快照 3.找到克隆 下一步 有几个快照会显示几个 4.克隆完成后 要修改一下IP 不然无法同时运行两个虚拟机系统 系统介绍 1.pin ...

  9. 【Matlab】imagesc的使用

    imagesc(A) 将矩阵A中的元素数值按大小转化为不同颜色,并在坐标轴对应位置处以这种颜色染色 imagesc(x,y,A) x,y决定坐标范围 x,y应是两个二维向量,即x=[x1 x2],y= ...

  10. Markdown 语法粗学

    Markdown 语法粗学 Typora下载 Typora官网 下拉点击右上角 选择下载即可 里面选择自己想要的系统下载即可 如果下载缓慢,推荐使用各自的下载工具或者使用软件管家等 亲测迅雷速度尚可 ...