GAN由论文《Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)》提出。

GAN与VAEs的区别

GANs require differentiation through the visible units, and thus cannot model discrete data,

while VAEs require differentiation through the hidden units, and thus cannot have discrete latent variables.

即GAN不能处理离散数据,VAEs不能处理离散隐空间变量。

训练过程

常见模型是最小化一个loss,GAN里的生成器和鉴别器则是一个minmax操作,即

同时,生成器更新一次后,鉴别器应该更新多次,这样保证鉴别器可以维持在最优解附近。

如果生成器连续多次更新,而鉴别器不更新,则生成器倾向于生成那些“为难”鉴别器的同一批样本,这样生成器就缺乏多样性。

论文中给出的算法流程(简单的一次生成器更新对应多次鉴别器更新):

一些细节:

生成器使用relu和sigmoid激活函数,鉴别器使用maxout激活函数,Dropout只添加于鉴别器。


本文代码使用的一些trick:

  • 生成器最后的激活函数使用tanh代替sigmoid
  • 隐空间中使用正态分布去采样
  • 添加随机性因素。GAN是非常难以训练的,添加一些噪音可以让训练不会轻易卡主。除了Dropout外,此处对鉴别器判断的标签也添加随机噪音。
  • 稀疏梯度(Sparse gradients)在一些网络中通常是渴求的目标。但在GAN中,它会妨碍训练过程。所以将maxpool替换为带stride的卷积层,并使用leakyRELU代替relu激活函数。
  • 为了避免产生的图像如棋盘状(即一个个正方形像素块,而非连续流畅的像素),设定卷积窗口大小为步长的整数倍。
  • 优化器使用的是RMSprop,并使用梯度裁剪和梯度衰减。

训练过程为:

数据集为cifar10

定义生成器网络,输入为隐空间中一个矢量,输出为一个图片。

定义鉴别器网络,输入为生成器网络采样所得的图片和真实图片(以及标签),输出为sigmoid激活函数的标量值,即判断图片为真实还是伪造。

定义生成对抗网络,为D(G(x))即生成网络与鉴别网络的嵌套形式。输入为生成网络的输入,输出为鉴别器网络的输出。

训练时,使用高斯分布从隐空间中采样,经过生成网络得到生成的图片,与真实图片混合后(以及标签)作为鉴别器网络的输入。

先训练鉴别器。然后重新采样生成图片,此时需将这些图片的标签置为真实图片的标签(固定标签后,训练生成器,即让其参数调整到鉴别器都以为确实是真实图片)。再训练GAN(此时冻结鉴别器参数,训练的只是生成器)

可以看到,定义了3个模型,只是因为生成器网络的训练要基于鉴别器网络进行。


代码如下

import numpy as np
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import Input,Dense,LeakyReLU,Reshape,Conv2D,Conv2DTranspose,Flatten,Dropout
from keras.optimizers import RMSprop
from keras.preprocessing import image
import os latent_dim=32
# Cifar10图片尺寸
height,width=(32,32)
channels=3

3个网络定义

# 生成网络:将隐空间中矢量生成图片,使用Conv2DTranspose
generator_input=Input((latent_dim,))
x=Dense(128*16*16)(generator_input)
# 只添加了一个alpha参数,其他地方跟书上一致,alpha默认0.3
x=LeakyReLU(alpha=0.1)(x)
x=Reshape((16,16,128))(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
# 结果为32*32*256,为避免生成图片呈现棋盘的点阵格式,凡是使用strides的地方,窗口大小为strides的整数倍
x=Conv2DTranspose(256,4,strides=2,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) # 结果为32*32*3,即一个图片正确格式。使用tanh代替sigmoid
x=Conv2D(channels,7,activation='tanh',padding='same')(x)
generator=Model(generator_input,x)#它在包含在GAN里训练的,所以这里不用编译
# generator.summary() # 鉴别网络
discriminator_input=Input((height,width,channels))
x=Conv2D(128,3)(discriminator_input)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
# 2*2*128
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Flatten()(x)
# Dropout和给标签添加噪声,可以避免GAN卡住
x=Dropout(0.4)(x)
x=Dense(1,activation='sigmoid')(x) discriminator=Model(discriminator_input,x)
# discriminator.summary() # clipvalue,梯度超过这个值就截断,decay,衰减,使得训练稳定
discriminator_optimizer=RMSprop(lr=0.0003,clipvalue=1.0,decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy') # 最后的生成对抗网络,由生成网络与对抗网络组合而成,此时冻结鉴别网络,训练的只是生成网络
discriminator.trainable=False
# 组成整个生成对抗网络
gan_input=Input((latent_dim,))
# 最终网络形式为鉴别网络作用于生成网络,故生成器也不用compile
gan_output=discriminator(generator(gan_input))
gan_optimizer=RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan=Model(gan_input,gan_output)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')

训练过程,此处并未使用多次鉴别器更新一次生成器更新,你可以自己调整(即循环里面开头添加个循环,训练鉴别器)。

(x_train,y_train),(x_test,y_test)=cifar10.load_data()
# 选择frog类别,总共10个类
x_train=x_train[y_train.flatten()==6]
# reshape到输入格式 nums*height*width*channels,像素归一化
x_train=x_train.reshape((x_train.shape[0],)+(height,width,channels)).astype('float32')/255.
iters=10000
batch_size=20
save_dir='frog' start=0
for step in range(iters):
# 选取潜空间中随机矢量(正态分布)
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 生成网络产生图片
generated_images=generator.predict(random_latent_vec)
stop=start+batch_size
# 真实原始图片
real_images=x_train[start:stop]
# mix生成和真实图片
combined_images=np.concatenate([generated_images,real_images])
# mix labels
labels=np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))])
# trick:标签添加随机噪声
labels+=0.05*np.random.random(labels.shape)
# 鉴别loss,可能为负,因为使用的是LeakyReLU
d_loss=discriminator.train_on_batch(combined_images,labels)
# 重新生成随机矢量
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 故意设置标签为真实
misleading_targets=np.zeros((batch_size,1))
a_loss=gan.train_on_batch(random_latent_vec,misleading_targets)
start+=batch_size
if start>len(x_train)-batch_size:
start=0
if step%100==0:
# gan.save_weights('gan.h5')
print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss)
# 保存一个batch里的第一个图片,之前像素归一化了,这里乘以255还原
img=image.array_to_img(generated_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png'))
# 保存一个对比图片
img=image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))

loss变化趋势,可以看到是不稳定的

看真实图和生成图片对比,上下2行图片只是同一批保存的,没有相关性。这是训练4000步,也即80000个训练样本后的结果。看起来比较丑陋吧。

GAN(生成对抗网络)之keras实践的更多相关文章

  1. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】

    本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...

  2. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  3. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  4. GAN实战笔记——第六章渐进式增长生成对抗网络(PGGAN)

    渐进式增长生成对抗网络(PGGAN) 使用 TensorFlow和 TensorFlow Hub( TFHUB)构建渐进式增长生成对抗网络( Progressive GAN, PGGAN或 PROGA ...

  5. [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述

    Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...

  6. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  7. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  8. AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华

    注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...

  9. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

随机推荐

  1. C语言例题

    1.连接两个字符串 将两个字符串连接,不要用stract函数 2.求矩阵外围元素之和 求3行3列矩阵的外围元素之和. 3.求矩阵主对角线和副对角线元素之和 求5行5列矩阵的主对角线和副对角线元素之和. ...

  2. 【python爬虫】 爬云音乐我和xxx共同听过的歌曲

    闲聊的时候,觉得,想写个爬虫,爬下2个人共同听过的歌曲有哪些,然后一鼓作气,花了一个多小时,写了一个.支持最近一周和所有时间,需要用户没有关闭听歌排行显示 How to start 使用到的工具是Se ...

  3. codeforces#562 Div2 C---Increasing by modulo【二分】

    题目:http://codeforces.com/contest/1169/problem/C 题意: 有n个数,每次可以选择k个,将他们+1并对m取模.问最少进行多少次操作,使得序列是非递减的. 思 ...

  4. Codeforces Round #587 (Div. 3) D. Swords

    链接: https://codeforces.com/contest/1216/problem/D 题意: There were n types of swords in the theater ba ...

  5. 在maven项目中如何引入另外一个项目(转)

    原文链接:https://blog.csdn.net/jianfpeng241241/article/details/52654352 1  在Myeclipse中准备两个maven demo. , ...

  6. HDU 6048 - Puzzle | 2017 Multi-University Training Contest 2

    /* HDU 6048 - Puzzle [ 思维,结论 ] | 2017 Multi-University Training Contest 2 题意: 类似华容道的问题, N*M 的矩阵中N*M- ...

  7. scrapy+selenium 爬取淘宝商城商品数据存入到mongo中

    1.配置信息 # 设置mongo参数 MONGO_URI = 'localhost' MONGO_DB = 'taobao' # 设置搜索关键字 KEYWORDS=['小米手机','华为手机'] # ...

  8. 多线程 submit与execute区别

    (1)可以接受的任务类型 submit: execute: 可以看出: execute只能接受Runnable类型的任务 submit不管是Runnable还是Callable类型的任务都可以接受,但 ...

  9. java实现上传文件夹

    我们平时经常做的是上传文件,上传文件夹与上传文件类似,但也有一些不同之处,这次做了上传文件夹就记录下以备后用. 首先我们需要了解的是上传文件三要素: 1.表单提交方式:post (get方式提交有大小 ...

  10. 透彻tarjan

    tarjan 求强连通分量: #include<cstdio> #include<iostream> #include<cstdlib> #define N 100 ...