第一个GAN模型—生成手写数字

一、GAN的基础:对抗训练

形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数。这两个网络是利用判别器的损失记性反向传播训练。判别器努力使真实样本输入和伪样本输入带来的损失最小化,而生成器努力使它生成的为样本造成的判别器损失最大化。

训练数据集决定了生成器要学习模拟的样本类型,例如,目标是生成猫的逼真图像,我们就会给GAN提供一组猫的图像。

用更专业的术语来说,生成器的目标是生成符合训练数据集数据分布的样本。对计算机来说,图像只是矩阵:灰度图是二维的,彩色图是三维的。当在屏幕上呈现时,这些矩阵中的像素值将显示为图像线条、边缘、轮廓等的所有视觉元素。这些值在数据集中的每个图像上遵循复杂的分布,如果没有分布规律,图像将不过是些随机噪声。目标识别模型学习图像中的模式以识别图像的内容,生成器所做的可以认为是相反的过程:它学习合成这些模式,而不是识别这些模式。

1. 代价函数

遵循标准的表示形式,用\(\text{J}^{(G)}\)表示生成器的代价函数,用\(\text{J}^{(D)}\)表示判别器的代价函数。两个网络的训练参数(权重和偏置)用希腊字母表示:\(\theta^{(G)}\)表示生成器,\(\theta^{(D)}\)表示判别器。

GAN在两个关键方面不同于传统的神经网络。第一,代价函数\(J\),传统神经网络的代价函数仅根据其自身可训练的参数定义,数学表示为\(\text{J}^{(\theta)}\)。相比之下,GAN由两个网络组成,其代价函数依赖于两个网络的参数。也就是说,生成器的代价函数是\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\),而判别器的代价函数是\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)​。

第二,在训练过程中,传统的神经网络可以调整它的所有参数\(\theta\)。在GAN中,每个网络只能调整自己的权重和偏置。也就是说,在训练过程中,生成器只能调整\({\theta}^{(G)}\),判别器只能调整\({\theta}^{(D)}\)​​。因此,每个网络只控制了决定损失的部分参量。

为了使上述内容不那么抽象,考虑下面这个类比。想象一下我们正在选择下班开车回家的路线,如果交通不堵塞,最快的选择是高速公路,但在交通高峰期,优选是走一条小路。尽管小路更长更曲折,但当高速公路上交通堵塞时,走小路可能会更快地回家。

让我们把它当作一道数学题——\(J\)作为代价函数,并定义为回家所需的时间。我们的目标是尽量减小\(J\)。为简单起见,假设离开办公室的时间是固定的,既不能提前离开,也不能为了避开高峰时间而晚走。所以唯一能改变的参数是路线\(\theta\)。

如果我们所拥有的是路上唯一的车,代价将类似于一个常规的神经网络:它将只取决于路线,且优化\(\text{J}{(\theta)}\)​完全在我们的能力范围内。然而,一旦将其他驾驶员引入方程式,情况就会变得更加复杂。突然之间,我们回家的时间不仅取决于自己的决定,还取决于其他驾驶员的行路方案,即\(\text{J}({\theta}^{\text{我们}}, {\theta}^{\text{其他驾驶员}})\)。就像生成器网络和判别器网络一样,“代价函数”将取决于各种因素的相互作用,其中一些因素在我们的掌控之下,而另一些因素则不在。

2. 训练过程

上面所描述的两个差异对GAN的训练过程有着深远的影响。传统神经网络的训练是个优化问题,通过寻找一组参数来最小化代价函数,移动到参数空间中的任何相邻点都会增加代价。这可能是参数空间中的局部或全局最小值,由寻求最小化使用的代价函数所决定。

最小化代价函数的优化过程如下图所示:

碗形网格表示参数空间\(\theta_1\)和\(\theta_2\)中的损失\(J\)。黑色点线表示通过优化使参数空间中的损失最小化

因为生成器和判别器智能调整自己的参数而不能相互调整对方的参数,所以GAN训练可以用一个博弈过程来更好的描述,而非优化。该博弈中的对手是GAN所包含的两个网络。

当两个网络达到纳什均衡时GAN训练结束,在纳什均衡点上,双方都不能通过改变策略来改善自己的情况。从数学角度来说,发上在这样的情况下——生成器的可训练参数\(\theta^{(G)}\)对应的生成器的代价函数\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\)最小化;同时,对应该网络参数\(\theta^{(D)}\)下的判别器的代价函数\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)也得到最小化。下图说明了二者零和博弈的建立和达到纳什均衡的过程。

玩家1(左)试图通过调整\(\theta_1\)来最小化V。玩家2(中间)试图通过调整\(\theta_2\)来最大化V(最小化-V)。鞍形网络(右)显示了参数空间\(V({\theta_1}, {\theta_2})\)中的组合损失。虚线表示在鞍形网络中心收敛达到纳什均衡

回到我们的类比,对于我们和可能在路上遇到的所有其他驾驶员来说,当每一条回家的路线所花费的时间都完全相同时,纳什均衡将会发生。任何更快的路线都会被交通拥堵量的成比例增长所抵消,从而减缓了每个人的速度。而这种状态在现实生活中几乎是无法实现的,即使使用像谷歌地图这样提供实时流量更新的工具,也不可能完美地评估出回家的最佳路径。

这同样适用于训练GAN网络时的高维、非凸情况。即使是像 MNIST数据集中的那些小到只有28×28像素的灰度图像,也有28×28=784维。如果它们被着色(RGB),它们的维数将增加到3倍变成2352。在训练数据集中的所有图像上捕获这种分布非常困难,特别是当最好的学习方法是从对手(判别器)那里学习时。

成功地训练GAN需要反复试验,尽管有最优方法,但它是一门科学的同时也是一门艺术。

二、生成器和判别器

现在通过引入更多的表示概括所学的内容。生成器(G)接收随机噪声向量z并生成为样本\(x^*\)。数学上来说,\(G(z) = x^*\)。判别器(D)的输入要么是真实样本x,要么是伪样本x*;对于每个输入,它输出一个介于0和1之间的值,表示输入是真实样本的概率。下图用刚才介绍的术语和符号描述了GAN架构。

生成器网络G将随机向量z转换为伪样本\(x^*: G(z) = x^*\)。判别器网络D对输入样本是各真实进行分类并输出。对于真实样本x,判别器力求输出尽可能接近1的值;对于伪样本x,判别器力求输出尽可能接近0的值。相反,生成器希望D(x*)尽可能接近1,这表明判别器被欺骗,将伪样本分类为真实样本

1. 对抗的目标

判别器的目标是尽可能精确。对于真实样本x,D(x)力求尽可能接近1(正的标签);对于伪样本x*,\(D\left( x^{\ast }\right)\)​力求尽可能接近0(负的标签)。

生成器的目标正好相反,它试图通过生成与训练数据集中的真实数据别无二致的伪样本x*来欺骗判别器。从数学角度来讲,即生成器试图生成假样本\(D\left( x^{\ast }\right)\),使得\(D\left( x^{\ast }\right)\)尽可能接近1。

2. 混淆矩阵

判别器的分类可以使用混淆矩阵来表示,混淆矩阵是二元分类中所有可能结果的表格表示。如下表:

判别器的分类结果如下:

  1. 真阳性(true positive)——真实样本正确分类为真D(x)≈1;
  2. 假阴性( false negative)——真实样本错误分类为假D(x)≈0;
  3. 真阴性( true negative)——伪样本正确分类为假D(x*)≈0;
  4. 假阳性( false positive)——伪样本错误分类为真D(x*)≈1。

使用混淆矩阵的术语,判别器试图最大化真阳性和真阴性分类,这等同于最小化假阳性和假阴性分类。相反,生成器的目标是最大化判别器的假阳性分类,这样生成器才能成功地欺骗判别器,使其相信伪样本是真的。生成器不关心判别器对真实样本的分类效果如何,只关心对伪样本的分类。

三、GAN训练算法

这里介绍的算法使用小批量(mini-batch)而不是一次使用一个样本。

GAN训练算法

对于每次训练迭代,执行

(1)训练判别器

a. 取随机的小批量的真实样本x

b. 取随机的小批量的随机噪声z,并生成一个小批量伪样本:G(z) = x*

c. 计算D(x)和D(x*)的分类损失,并反向传播总误差以更新\({\theta}^{(D)}\)来最小化分类损失

​ (2)训练生成器

​ a. 取随机的小批量的随机噪声z生成一小批量伪样本:G(z) = x*

​ b. 用判别器网络对x*进行分类

​ c. 计算D(x*)的分类损失,并反向传播总误差以更新\({\theta}^{(G)}\)来最大化分类损失。

结束

在步骤1中训练判别器时,生成器的参数保持不变;同样,在步骤2中,在训练生成器时保持判别器的参数不变。之所以只允许更新被训练网络的权重和偏置,是因为要将所有更改隔离到仅受该网络控制的参数中。

四、生成手写数字

本节将实现一个GAN,它将学习生成外观逼真的手写数字,用的是带有TensorFlow后端的Python神经网络库Keras。

1. 导入模块并指定模型输入维度

#import statements
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os os.environ['CUDA_VISIBLE_DEVICES'] = '0' from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU #模型输入维度
img_rows = 28
img_cols = 28
channels = 1 img_shape = (img_rows, img_cols, channels)#输入图片的维度 z_dim = 100#噪声向量的大小用作生成器的输入
img_shape

2. 构造生成器

简而言之,生成器是一个只有一个隐藏层的神经网络。生成器以z为输入,生成28×28×1的图像。在隐藏层中使用 LeakyReLU激活函数,与将任何负输入映射

到0的常ReLU函数不同, LeakyReLU函数允许存在一个小的正梯度,这样可以防止梯度在训练过程中消失,从而产生更好的训练效果。

在输出层使用tanh激活函数,它将输出值缩放到范围[-1, 1]。之所以使用tanh(与sigmoid同, sigmoid会输出更为典型的0到1范围内的值),是因为它有助于生成更清断的图像。

def build_generator(img_shape, z_dim):
model = Sequential([
Dense(128, input_dim=z_dim),#全连接层
LeakyReLU(alpha=0.01),
Dense(28*28*1, activation='tanh'),
Reshape(img_shape)#生成器的输出改变为图像尺寸
])
return model build_generator(img_shape, z_dim).summary()

3. 构造判别器

判别器接收28×28×1的图像,并输出表示输入是否被视为真而不是假的概率。判别器由一个两层神经网络表示,其隐藏层有128个隐藏单元及激活函数为 LeakyReLU。

为简单起见,我们构造的判别器网络看起来与生成器几乎相同,但并非必须如此。实际上,在大多数GAN的实现中,生成器和判别器网络体系结构的大小和复杂性都相差很大。

注意,与生成器不同的是,判别器的输出层应用了 sigmoid激活函数。这确保了输出值将介于0和1之间,可以将其解释为生成器将输入认定为真的概率。

def build_discrimination(img_shape):
model = Sequential([
Flatten(input_shape=img_shape),#输入图像展平
Dense(128),
LeakyReLU(alpha=0.01),
Dense(1, activation='sigmoid')
])
return model build_discrimination(img_shape).summary()

4. 搭建整个模型

构建并编译先前实现的生成器模型和判别器模型。注意:在用于训练生成器的组合模型中,通过将discriminator.trainable设置为False来固定判别器参数。还要注意的是,组合模型(其中判别器设置为不可训练)仅用于训练生成器。判别器将用单独编译的模型训练。(当回顾训练循环时,这一点会变得很明显。)

使用二元交又熵作为在训练中寻求最小化的损失函数。二元交叉熵( binary cross-entropy)用于度量二分类预测计算的概率和实际概率之间的差异;交叉损失越大,预测离真值就越远。

优化每个网络使用的是Adam优化算法。该算法名字源于adaptive moment estimation。这是一种先进的基于梯度下降的优化算法,Adam凭借其通常优异的性能已经成为大多数GAN的首选优化器。

def build_gan(generator, discriminator):
model = Sequential() #生成器模型和判别器模型结合到一起
model.add(generator)
model.add(discriminator) return model discriminator = build_discrimination(img_shape)#构建并编译判别器
discriminator.compile (
loss='binary_crossentropy',
optimizer=Adam(),
metrics=['accuracy']
)
generator = build_generator(img_shape, z_dim)#构建生成器
discriminator.trainable = False#训练生成器时保持判别器的参数固定 #构建并编译判别器固定的GAN模型,以生成训练器
gan = build_gan(generator, discriminator)
gan.compile(
loss='binary_crossentropy',
optimizer=Adam()
)

5. 训练

首先,取随机小批量的MNIST图像为真实样本,从随机噪声向量z中生成小批量伪样本,然后在保持生成器参数不变的情况下,利用这些伪样本训练判别器网络。其次,生成一小批伪样本,使用这些图像训练生成器网络,同时保持判别器的参数不变。算法在每次送代中都重复这个过程。

我们使用独热编码(one-hot-encoded)标签:1代表真实图像,0代表伪图像。z从标准正态分布(平均值为0、标准差为1的钟形曲线)中取样得到。训练判别器使得假标签分配给伪图像,真标签分配给真图像。对生成器进行训练时,生成器要使判别器能将真实的标签分配给它生成的伪样本。

注意:训练数据集中的真实图像被重新缩放到了-1到1。如前例所示,生成器在输出层使用tanh激活函数,因此伪样本同样将在范围(-1,1)内。相应地,就得将判别器的所有输入重新缩放到同一范围。

losses = []
accuracies = []
iteration_checkpoints = [] def train(iterations, batch_size, sample_interval):
(x_train, _), (_, _) = mnist.load_data()#加载mnist数据集
x_train = x_train/127.5 - 1.0#灰度像素值[0,255]缩放到[-1,1]
x_train = np.expand_dims(x_train, axis=3)
real = np.ones((batch_size, 1))#真实图像的标签都是1
fake = np.zeros((batch_size, 1))#伪图像的标签都是0
for iteration in range(iterations):
idx = np.random.randint(0, x_train.shape[0], batch_size)#随机噪声采样
imgs = x_train[idx] z = np.random.normal(0, 1, (batch_size, 100))#获取随机的一批真实图像
gen_imgs = generator.predict(z) #图像像素缩放到[0,1]
d_loss_real = discriminator.train_on_batch(imgs, real)
d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake) z = np.random.normal(0, 1, (batch_size, 100))#生成一批伪图像
gen_imgs = generator.predict(z) g_loss = gan.train_on_batch(z, real)#训练判别器 if(iteration + 1) % sample_interval == 0:
losses.append((d_loss, g_loss))
accuracies.append(100.0 * accuracy)
iteration_checkpoints.append(iteration + 1)
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (iteration + 1, d_loss, 100.0 * accuracy, g_loss))#输出训练过程 sample_images(generator)#输出生成图像的采样

6. 输出样本图像

在生成器训练代码中,你可能注意到调用了 sample_images()函数。该函数在每次sample_ interval选代中调用,并输出由生成器在给定迭代中合成的含有4x4幅合成图像的网格。运行模型后,你可以使用这些图像检查临时和最终的输出情况。

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))#样本随机噪声 gen_imgs = generator.predict(z)#从随机噪声生成图像 gen_imgs = 0.5 * gen_imgs + 0.5#将图像像素重置缩放至[0, 1]内 #设置图像网格
fig, axs = plt.subplots(
image_grid_rows,
image_grid_columns,
figsize=(4, 4),
sharex=True,
sharey=True
) cnt = 0
for i in range(image_grid_rows):
for j in range(image_grid_columns):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')#输出一个图像网格
axs[i, j].axis('off')
cnt += 1

7. 运行模型

这是最后一步,设置训练超参数——迭代次数和批量大小,然后训练模型。目前没有一种行之有效的方法来确定正确的迭代次数或正确的批量大小,只能观察训

练进度,通过反复试验来确定。

也就是说,对这些数有一些重要的实际限制:每个小批量必须足够小,以适合内存器处理(典型使用的批量大小是2的幂:32、64、128、256和512)。迭代次数也有一个实际的限制:拥有的迭代次数越多,训练过程花费的时间就越长。像GAN这样复杂的深度学习模型,即使有了强大的计算能力,训练时长也很容易变得难以控制。

为了确定合适的迭代次数,你需要监控训练损失,并在损失达到平稳状态(这意味着我们从进一步的训练中得到的改进增量很少,甚至没有)的次数附近设置送代次数。(因为这是个生成模型,像有监督的学习算法一样,也需要担心过拟合问题。)

#设置训练超参数
iterations = 20000
batch_size = 128
sample_interval = 1000 train(iterations, batch_size, sample_interval)

8. 检查结果

经过训练迭代后由生成器生成的样本图像,按照时间先后排列,如下图所示。可以看到,生成器起初只会产生随机噪声。在训练迭代的过程中,它越来越擅长模拟训练数据的特性,每次判别器判断生成的图像为假或判断生成的图像为真时,生成器都会稍有改进。

生成器经过充分训练后可以合成的图像样本,如下图所示。虽远非完美,但简单的双层生成器生成了逼真的数字,如数字1和7。

为了进行比较,我们给出从MNIST数据集中随机选择的真实图像样本。如下图所示。

五、小结

  1. GAN是由两个网络组成的:生成器(G)和判别器(D)。它们各自有自己的损失函数:\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\)和\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)
  2. 在训练过程中,生成器和判别器只能调整自己的参数,即\(\theta^{(G)}\)和\(\theta^{(D)}\)。
  3. 两个网路通过一个类似博弈的动态过程同时训练:生成器试图最大化判别器的假阳性分类(将生成的图像分类为真图像),而判别器试图最小化它的假阳性和假阴性分类。

GAN实战笔记——第三章第一个GAN模型:生成手写数字的更多相关文章

  1. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  2. GAN实战笔记——第七章半监督生成对抗网络(SGAN)

    半监督生成对抗网络 一.SGAN简介 半监督学习(semi-supervised learning)是GAN在实际应用中最有前途的领域之一,与监督学习(数据集中的每个样本有一个标签)和无监督学习(不使 ...

  3. GAN——生成手写数字

    <Generative Adversarial Nets>是 GAN 系列的鼻祖.在这里通过 PyTorch 实现 GAN ,并且用于手写数字生成. 摘要: 我们提出了一个新的框架,通过对 ...

  4. 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验

    实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...

  5. GAN实战笔记——第五章训练与普遍挑战:为成功而GAN

    训练与普遍挑战:为成功而GAN 一.评估 回顾一下第1章中伪造达・芬奇画作的类比.假设一个伪造者(生成器)正在试图模仿达・芬奇,想使这幅伪造的画被展览接收.伪造者要与艺术评论家(判别器)竞争,后者试图 ...

  6. GAN实战笔记——第六章渐进式增长生成对抗网络(PGGAN)

    渐进式增长生成对抗网络(PGGAN) 使用 TensorFlow和 TensorFlow Hub( TFHUB)构建渐进式增长生成对抗网络( Progressive GAN, PGGAN或 PROGA ...

  7. SPRING IN ACTION 第4版笔记-第三章ADVANCING WIRING-003-@Conditional根据条件生成bean及处理profile

    一.用@Conditional根据条件决定是否要注入bean 1. package com.habuma.restfun; public class MagicBean { } 2. package ...

  8. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  9. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

随机推荐

  1. POJ2891Strange Way to Express Integers

    http://poj.org/problem?id=2891 实际上就是一个一元线性同余方程组.按照合并的方式来解即可. 有一个注意点,调用函数是会慢的. #include<iostream&g ...

  2. 我的2021年度总结-回忆录|附旅行Vlog

    今天是农历腊月初十,还有20天就是2022年了.这一年,些许遗憾,些许期盼.时间久了,很多事已经慢慢模糊了,只记得,这最后几个月的闲碎小事. 不止多久,很久没有码字了.有些事,记不清,忆不得.时至今年 ...

  3. JavaScript 中BOM的常用操作

    JavaScript BOM操作 1.获取浏览器窗口尺寸 var width=window,innerWidth //获取可视窗口宽度 var height=window.innerHeight // ...

  4. 《手把手教你》系列技巧篇(六十)-java+ selenium自动化测试 - 截图三剑客 -中篇(详细教程)

    1.简介 前面我们介绍了Selenium中TakeScreenshot类来截图,得到的图片是浏览器窗口内的截图.有时候,只截浏览器窗口内的图是不够的,而且TakeScreenshot截图只针对浏览器的 ...

  5. Servlet-概念及实现一个Servlet程序

    Servlet技术 一,Servlet概念 1,Servlet是JavaEE规范之一.规范就是接口 2,Servlet就是JavaWeb三大组件之一.三大组件分别是:Servlet程序.Filter过 ...

  6. IDEA设置Maven

    1,在idea中设置maven,让idea和maven结合使用 idea中内置了maven,一般不使用内置,因为用内置修改maven的设置不方便 使用自己安装的maven,需要覆盖idea中默认的设置 ...

  7. Servlet中的Filter 过滤器的简单使用!

    package com.aaa.filter; import java.io.IOException; import javax.servlet.Filter; import javax.servle ...

  8. ApacheCN 数据库译文集 20211112 更新

    创建你的 Mysql 数据库 零.前言 一.介绍 MySQL 设计 二.数据采集 三.数据命名 四.数据分组 五.数据结构调整 六.补充案例研究 Redis 学习手册 零.序言 一.NoSQL 简介 ...

  9. 【源码】Redis exists命令bug分析

    本文基于社区版Redis 4.0.8 1.复现条件 版本:社区版Redis 4.0.10以下版本 使用场景:开启读写分离的主从架构或者集群架构(master只负责写流量,slave负责读流量) 案例: ...

  10. Linux安装MySQL详细步骤(CentOS6、CentOS7)

    1.查看mysql的依赖(centos7 要把mysql改成mariadb) rpm -qa | grep mysql 2.删除mysql的依赖,可以两个都执行(centos7 要把mysql改成ma ...