GAN的定义

  GAN是一个评估和学习生成模型的框架。生成模型的目标是学习到输入样本的分布,用来生成样本。GAN和传统的生成模型不同,使用两个内置模型以“对抗”的方式来使学习分布不断接近输入样本分布。两个模型一个是生成模型(Generative model),用来生成样本;另一个是判别模型(Discriminative model),产生判断样本是真实而不是来自生成模型的概率。生成模型并不直接学习输入样本的分布,而是通过“欺骗”判别模型的方式提高输入分布的逼近程度;判别模型则是使用生成样本和真实样本来提高判别准确率。

  对于生成模型$G$和判别模型$D$,GAN的优化式的如下:

$\min\limits_{G}\max\limits_{D} V(D,G)$

$ V(D,G) = E_{x\sim p_{data}}[\log_{}D(x)] + E_{z\sim p_z}[\log_{}(1-D(G(z)))]$

  其中$p_{data}$是样本的真实分布。比如对于某个分辨率的图片来说,这个分布基于这个分辨率上的所有图片。注意!即使是乱码图片,它也是有概率密度的,只不过很小很小而已。$p_z$是随机数$z$的分布,通常用高斯分布(文章用的是均匀分布,这是最早的文章);$G(z)$就是生成器基于这个随机数生成的样本。$D(x)$是判别器判断样本$x$为真实样本的概率。

  使用梯度下降法进行优化的过程如下:

  每次分别随机拿到$m$个真实和生成样本用来对函数($\theta_d$、$\theta_g$分别包含在$D$和$G$中)

$\displaystyle f(\theta_d) = \frac{1}{m}\sum\limits_{i=1}^{m}[\log_{}D(x^{(i)})+\log_{}(1-D(G(z^{(i)})))]$

  梯度上升,也就是优化判别模型;再生成$m$个样本用来对函数

$\displaystyle g(\theta_g) = \frac{1}{m}\sum\limits_{i=1}^{m}[\log_{}(1-D(G(z^{(i)})))]$

  梯度下降也就是优化生成模型。最终二者都达到最优。

  以下是拟合的过程图:

  黑点线是样本$x$的真实分布,绿线是样本$x$的生成模型分布,蓝虚线是判别模型判断$x$属于真实的概率,下方的$z$是均匀分布随机数$z$到生成样本$x$的映射。

  a图是初始化时,判别模型$D$和生成模型$G$都很差。

  b图是取样本来更新$D$,$D$在此刻变为最优。也就是说,在当前的$G$下,对于每个$x$,都能正确得出它是真实样本的概率:

$\displaystyle D(x) = \frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$,

  证明在后面,不过想想也是这么一回事。比如看绿线和黑点线中间的交叉点,此时$x$的真实概率为0.5。

  c图是更新$G$,$G$在此刻$D$的基础上变得不错了。

  d图是一直迭代到最后,$G$和真实分布一模一样,而$D$的判断概率全是0.5。但是,一模一样也不是很好。因为样本集总是有限的,并不能完全契合样本全体的分布,所以如果生成分布和样本集分布一模一样的话可能会过拟合。

全局最优

  对任意给定的$G$,最优的$D$对每个样本$x$,都有:

$D_G^*(x) = \displaystyle  \frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$

  这是因为最优的$D$最大化关于$\theta_d$的函数:

$\displaystyle V(G,D) = \int_x p_{data}(x)\log_{}(D(x)) + p_g(x)\log_{}(1-D(x))dx$  

  也就是对于每个$x$,这个积分内部函数都取最大值。对于函数

$h(y) = a\log_{}(y)+b\log_{}(1-y),a\ge 0,b\ge 0$

  在$0< y < y^*$时,$h'(y)$大于零;$y^*< y < 1$,$h'(y)$小于零。所以$h(y)$在

$\displaystyle y^*=\frac{a}{a+b}$

  时最大。因此得证。

  假如$G$训练到了最优,也就是输出分布与输入样本分布相同,即$p_{data}=p_g$,而$D$也最优时,有:

$\displaystyle V(D,G) = E_{x\sim p_{data}}\left[\log_{}\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}\right] + E_{x\sim p_g}\left[\log_{}\frac{p_{g}(x)}{p_{data}(x)+p_g(x)}\right]$

$\displaystyle = E_{x\sim p_{data}}\left[\log_{}\frac{1}{2}\right] + E_{x\sim p_g}\left[\log_{}\frac{1}{2}\right]=-\log_{}4$

CGAN

  CGAN(Conditional GAN)是GAN的一种基本变通。相对于基本GAN的生成器和判别器,输入分别只有随机抽样和样本,CGAN的输入则可以附带条件。CGAN生成器的输入除了随机抽样外,还可以附加样本的一些特征,从而可以更加精确地生成我们期望的生成样本。判别器则是输入样本和对应的特征,联合这两者进行判断样本的“真实性”。

  比如用CGAN训练MNIST时,我们想要让生成器能生成我们期望的数字。生成器的输入就是随机抽样+对应数字的one-hot编码,而判别器的输入就是生成的样本或真实样本+对应数字的one-hot编码。所以CGAN的优化函数就在GAN的基础上改改:

$\max\limits_G\min\limits_D V(D,G) = E_{x\sim p_{data}}[\log_{}D(x|y)] + E_{z\sim p_z}[\log_{}(1-D(G(z|y)|y))]$

  其中$y$是$x$的标签。上面$D$中表示的好像是条件概率,我觉得也可以直接理解为联合概率。生成器和判别器只需将它们的两个输入concatenate,后面的层就和GAN类似了。另外,上式没有对样本和标签不匹配的情况进行限制,论文中也没有写。这样的话,模型就可能生成比较真实但与标签不符的样本。所以训练时判别器还应该惩罚真实但标签错误的输入。

  下面用MNIST训练CGAN来生成数字,模型结构是用CGAN论文中的。我原本是想用卷积网络来搭建,然而迭代了几万次都生成不出有点像数字的图,最终放弃。而仿照论文用全连接层搭建的模型,虽然也不是特别“真”,至少比我原来的模型效果好多了。下面是生成的数字图:

  一共迭代了1100次,每次迭代使用100个样本对生成器和判别器进行训练。随着迭代次数的增加,生成图片的效果逐渐变好,又逐渐崩坏,然后又逐渐变好,如此反复循环,所以要把握迭代停止的时机。理论上来讲,如果一直迭代下去,最终是会平稳下来的。但是我迭代到几千次甚至上万次,生成的图片效果依旧没有变得很好,具体原因不清楚,还有待发掘。

  以下是训练代码:

#%%生成器
from keras import layers,Input,Model,utils,activations
import numpy as np sample_num = 200
Input_sampling = Input(shape=[sample_num])
Input_label = Input(shape=[10]) x1 = layers.Dense(sample_num,activation='relu')(Input_sampling)
x2 = layers.Dense(1000,activation='relu')(Input_label)
x = layers.concatenate([x1,x2])
x = layers.Dropout(0.5)(x)
x = layers.Dense(28*28,activation='sigmoid')(x)
x = layers.Reshape([28,28,1])(x) generator = Model([Input_label,Input_sampling],x)
generator.summary()
utils.plot_model(generator)
#%%判别器
Input_img = Input(shape=[28,28,1]) x1 = layers.Reshape([28*28])(Input_img)
x1 = layers.MaxoutDense(240,5)(x1)
x2 = layers.MaxoutDense(50,5)(Input_label)
x = layers.concatenate([x1,x2])
x = layers.MaxoutDense(240,4)(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(1,activation='sigmoid')(x) discriminator = Model([Input_label,Input_img],x)
discriminator.summary()
utils.plot_model(discriminator)
#%%合并模型GAN
x = generator([Input_label,Input_sampling])
x = discriminator([Input_label,x])
gan = Model([Input_label,Input_sampling],x)
#%%数据预处理
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
(train_data,train_labels),(test_data,test_labels) = mnist.load_data()
def label_to_one_hot(labels):
l = np.zeros([len(labels),10])
for i in range(len(labels)):
l[i,labels[i]]=1
return l
train_data = train_data[:,:,:,np.newaxis].astype('float')/255
test_data = test_data[:,:,:,np.newaxis].astype('float')/255
train_labels = label_to_one_hot(train_labels)
test_labels = label_to_one_hot(test_labels)
plt.imshow(train_data[0,:,:,0])
#%%编译模型
from tensorflow.keras import optimizers,losses
import matplotlib.pyplot as plt generator.trainable = True
discriminator.trainable = False
gan_optimizer = optimizers.Adam()
gan.compile(
optimizer=gan_optimizer,
loss='binary_crossentropy')
discriminator.trainable = True
d_optimizer = optimizers.Adam()
discriminator.compile(
optimizer=d_optimizer,
loss='binary_crossentropy')
#%%训练
def get_samples():
return np.random.random([batch_size,sample_num])*2-1
def train_generator(batch_size,if_show_loss):
samples = get_samples()
labels = np.zeros([batch_size,10])
judges = np.ones(batch_size) - np.abs(np.random.normal(scale=0.05,loc = 0,size = batch_size))
for i in labels:
i[np.random.randint(10)] = 1.
gan.fit([labels,samples],judges,verbose=if_show_loss)
def train_discriminator(data,labels_true_right,batch_size,if_show_loss):
#生成器生成图像
samples = get_samples()
labels_fake = np.zeros([batch_size,10])
for i in labels_fake:
i[np.random.randint(10)] = 1.
fake_imgs = generator.predict([labels_fake,samples])
#获取错误标签真图像
s = np.linspace(0,9,10).astype('int')
lebals_true_wrong = np.zeros_like(labels_true_right)
for i in range(batch_size):
p = np.ones(10)/9
p[np.argmax(labels_true_right[i])] = 0
lebals_true_wrong[i,np.random.choice(s,1, p=p)] = 1
#将输入拼接
in_imgs = np.concatenate([fake_imgs,data,data],axis = 0)
in_labels = np.concatenate(
[labels_fake,lebals_true_wrong,labels_true_right],
axis = 0)
judges_wrong = np.zeros(batch_size*2) + np.random.normal(scale=0.05,loc = 0,size = batch_size*2)
judges_right = np.ones(batch_size) - np.random.normal(scale=0.05,loc = 0,size = batch_size)
train_judges = np.concatenate([judges_wrong,judges_right],axis=0) discriminator.fit([in_labels,in_imgs],train_judges,verbose=if_show_loss)
def save_img_and_model(num,i):
label = np.zeros([1,10])
label[0,num] = 1
img = generator.predict([label,get_samples()])
plt.imshow( img[0,:,:,0],cmap='bone')
plt.show( ) generator.save('generator.h5')
discriminator.save('discriminator.h5') epochs = 10000
batch_size = 500
train_size = 20000
now_train = 0
for i in range(epochs):
print(i)
if_show_loss = False
if i % 20 == 0:
if_show_loss = True
save_img_and_model(np.random.randint(10),i)
train_generator(batch_size,if_show_loss)
train_discriminator(
train_data[now_train:now_train+batch_size],
train_labels[now_train:now_train+batch_size],
batch_size,if_show_loss)
now_train = (now_train + batch_size)%train_size

参考文献

  Generative Adversarial Networks

  Conditional Generative Adversarial Nets

GAN和CGAN——生成式对抗网络和条件生成式对抗网络的更多相关文章

  1. 渐进结构—条件生成对抗网络(PSGAN)

    Full-body High-resolution Anime Generation with Progressive Structure-conditional Generative Adversa ...

  2. OpenStack网络指导手册 -基本网络概念

    转自:http://blog.csdn.net/zztflyer/article/details/50441200 目录(?)[-] 以太网Ethernet 虚拟局域网VLANs 子网和地址解析协议S ...

  3. 无废话Android之smartimageview使用、android多线程下载、显式意图激活另外一个activity,检查网络是否可用定位到网络的位置、隐式意图激活另外一个activity、隐式意图的配置,自定义隐式意图、在不同activity之间数据传递(5)

    1.smartimageview使用 <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android&q ...

  4. iOS网络-06-监听Iphone的网络状态

    使用系统的方法来监听网络状态 系统的方法是通过通知机制来实现网络状态的监听 实现网络状态监听的步骤 定义Reachability类型的成员变量来保存网络的状态 @property (nonatomic ...

  5. 连接SQLServer2005失败--[Microsoft][ODBC SQL Server Driver][DBNETLIB]一般性网络错误。请检查网络文档

    连接SQLServer2005失败,错误信息: 错误类型:Microsoft OLE DB Provider for ODBC Drivers (0x80004005)[Microsoft][ODBC ...

  6. Java 网络编程(一) 网络基础知识

    链接地址:http://www.cnblogs.com/mengdd/archive/2013/03/09/2951826.html 网络基础知识 网络编程的目的:直接或间接地通过网络协议与其他计算机 ...

  7. Android编程 获取网络连接状态 及调用网络配置界面

    获取网络连接状态 随着3G和Wifi的推广,越来越多的Android应用程序需要调用网络资源,检测网络连接状态也就成为网络应用程序所必备的功能. Android平台提供了ConnectivityMan ...

  8. Android编程获取网络连接状态及调用网络配置界面

    获取网络连接状态 随着3G和Wifi的推广,越来越多的Android应用程序需要调用网络资源,检测网络连接状态也就成为网络应用程序所必备的功能. Android平台提供了ConnectivityMan ...

  9. 网络知识--OSI七层网络与TCP/IP五层网络架构及二层/三层网络

    作为一个合格的运维人员,一定要熟悉掌握OSI七层网络和TCP/IP五层网络结构知识. 废话不多说!下面就逐一展开对这两个网络架构知识的说明:一.OSI七层网络协议OSI是Open System Int ...

  10. 深度残差网络(DRN)ResNet网络原理

    一说起“深度学习”,自然就联想到它非常显著的特点“深.深.深”(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很容易就想到:深的网络一般会比浅的网络效果 ...

随机推荐

  1. Chrome 浏览器远程调试 【转】

    Chrome 浏览器按F12,可以调试JS,分析HTTP包等.但是有时候需要远程调试. 比如,某个EXE它内部嵌套了浏览器的话,可以想办法打开它的远程调试功能,然后在外部连到这个地址,就能分析它的ht ...

  2. C#/.NET/.NET Core优质学习资料,干货收藏!

    前言 今天大姚给大家分享一些C#/.NET/.NET Core优质学习资料,希望可以帮助到有需要的小伙伴. 什么是 .NET? .NET 是一个免费的.跨平台的.开源开发人员平台,用于构建许多不同类型 ...

  3. 超轻量级、支持插件的 .NET 网络通信框架

    前言 给大家推荐一个轻量级的.支持插件的综合网络通信库:TouchSocket. TouchSocket 的基础通信功能包括 TCP.UDP.SSL.RPC 和 HTTP.其中,HTTP 服务器支持 ...

  4. fluent python-chap2

    1. 内置序列类型 容器序列: list tuple collections.deque 可以存放不同类型的数据. 存放的是它们所包含的任意类型的对象的引用. 扁平序列: str bytes byte ...

  5. servlet一些笔记、详解

    一.什么是servlet? 处理请求和发送响应的过程是由一种叫做Servlet的程序来完成的,并且Servlet是为了解决实现动态页面而衍生的东西.理解这个的前提是了解一些http协议的东西,并且知道 ...

  6. 全网最适合入门的面向对象编程教程:50 Python函数方法与接口-接口和抽象基类

    全网最适合入门的面向对象编程教程:50 Python 函数方法与接口-接口和抽象基类 摘要: 在 Python 中,接口和抽象基类(Abstract Base Classes, ABCs)都用于定义类 ...

  7. Windows系统无法打开‘’网络发现‘’功能

    Windows10无法开启网络发现 解决办法: 1. services.msc 2. 开启 SSDP Discovery ,设置 启动类型为 自动 ,服务状态为 启动 Windows7 无法开启网络发 ...

  8. WPF下使用FreeRedis操作RedisStream实现简单的消息队列

    Redis Stream简介 Redis Stream是随着5.0版本发布的一种新的Redis数据类型: 高效消费者组:允许多个消费者组从同一数据流的不同部分消费数据,每个消费者组都能独立地处理消息, ...

  9. 《Vue.js 设计与实现》读书笔记 - 第7章、渲染器的设计

    第7章.渲染器的设计 7.1 渲染器与响应系统的结合 渲染器需要有跨平台的能力. 在浏览器端会渲染为真实的 DOM 元素. const { effect, ref } = VueReactivity ...

  10. python安装sklearn

    安装sklearn这个包,首先要安装三个依赖包,如图划红线的部分. 要找这三个包,我们都可以登录:https://www.lfd.uci.edu/~gohlke/pythonlibs/#scipy 这 ...