生成式对抗网络(GAN,generative adversarial network)由Goodfellow等人于2014年提出,它可以替代VAE来学习图像的潜在空间。它能够迫使生成图像与真实图像在统计上几乎无法区别,从而生成相当逼真的合成图像。

1.GAN是什么?

简单来说就是由两部分组成,生成器generator网络和判别器discriminator网络。一部分不断进化,使其对立部分也不断进化,实现共同进化的过程。

对GAN的一种直观理解是,想象我们想要试图生成一个二次元头像。一开始,我们并不擅长这项任务,就将自己的一些噪音二次元头像和真的二次元头像混在一起,并将其展示给discriminator。discriminator对每个头像进行真实性评估,并向我们给出反馈,告诉我们是什么让二次元头像看起来像真的二次元头像,我们回到自己的工作室,并准备一些新的二次元头像。随着时间的推移,我们变得越来越擅长模仿二次元头像的风格,discriminator也变得越来越擅长找出假的二次元头像。最后,我们手上拥有了一些优秀的二次元头像。

2.为什么?

【1】为什么我们有真的二次元头像和假的二次元头像,为什么不自己用监督学习生成新的二次元头像呢?

  generator无法自己独立学习的原因是,以vae为例,输出layer层输出的是各像素点,而他们在输出时是独立的,没有相互作用的,因此无法判断总体的效果进行自主学习。对于discriminator,其输入是生成的整张图像,因此可以从总体上进行判断。

  需要注意的是,discriminator对于输入的真实图像都应是高分,那么如果训练时只给它真实图像的话,他就无法实现正确的判断,会将所有输入都判为高分。所以需要一些差的图像送给discriminator进行训练,并且这些差的图像不应是简单的加些噪声之类的能让它轻易分辨的。因此,训练它的方法是,除真实图像外先给它一些随机生成的差的例子,然后对discriminator解argmaxD(x)做generation生成出一些他觉得好的图像,然后将原本极差的图像换为这些图像再进行训练,如此往复,discriminator会不断产生更好的图像,将这些作为negative examples给其学习,达到训练的目的。

【2】discriminator对真的二次元头像这么了解,为什么他不自己做,而是要来指导我们做呢?

那既然如此,为什么还需要generator呢?discriminator自己也可以生成图像啊?

这是因为discriminator生成图像需要解argmaxD(x), 难度较大,一般需要假设一些条件才会好解,比如网络假设为线性时,但这样会限制图像的生成效果。而generator生成非常快,因此将二者结合起来共同学习实现输出好的结果。二者优缺点如下所示:

总而言之,因为generator没有全局观,所以需要结合discriminator学习,对于discriminator,使用generator生成图像比自己解方程生成更简单高效,这二者的优缺点相互补充。

GAN的目的是为了生成,而VAE目的是为了压缩,目的不同效果自然不同。比如,由于二范数的原因,VAE的生成是模糊的。而GAN的生成是犀利的。

数据集为CIFAR10,包含50000张32*32的RGB图像,这些图像属于10个类别(每个类别5000张图像),这里我们只使用属于“frog”(青蛙)类别的图像

import keras
from keras import layers
import numpy as np

  

生成器网络:将一个向量(来自潜在空间,训练过程中对其随机采样)转换为一张候选图像

生成器从未直接见过训练集中的图像,它所知道的关于数据的信息都来自于判别器。

latent_dim = 32
height = 32
width = 32
channels = 3 generator_input = keras.Input(shape=(latent_dim,)) #将输入转换为大小为16*16的128个通道的特征图
x = layers.Dense(128*16*16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16,16,128))(x) x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x) #上采样为32*32
x = layers.Conv2DTranspose(256,4,strides=2,padding='same')(x)
x = layers.LeakyReLU()(x) x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x) #生成一个大小为32*32的单通道特征图(即CIFAR10图像的形状)
x = layers.Conv2D(channels,7,activation='tanh',padding='same')(x)
#将生成器模型实例化,它将形状为(latent_dim,)的输入映射到形状为(32,32,3)的图像
generator = keras.models.Model(generator_input,x)
generator.summary()

  

判别器网络:它接收一张候选图像(真实的或合成的)作为输入,并将其划分到这两个类别之一:"生成图像"或"来自训练集的真实图像"

#GAN判别器网络
discriminator_input = layers.Input(shape=(height,width,channels))
x = layers.Conv2D(128,3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x) x = layers.Dropout(0.4)(x) x = layers.Dense(1,activation='sigmoid')(x)#分类层 #将判别器模型实例化,它将形状为(32,32,3)的输入转换为一个二进制分类决策(真/假)
discriminator = keras.models.Model(discriminator_input,x)
discriminator.summary()

  

discriminator_optimizer = keras.optimizers.RMSprop(
lr=0.0008,
clipvalue = 1.0, #在优化器中使用梯度裁剪(限制梯度值的范围)
decay = 1e-8,#为了稳定训练过程,使用学习率衰减
) discriminator.compile(optimizer=discriminator_optimizer,
loss='binary_crossentropy')

  

设置GAN,将生成器和判别器连接在一起 训练时,这个模型将让生成器向某个方向移动,从而提高它欺骗判别器的能力。这个模型将潜在空间的点转换为一个分类决策(即"真"或"假") 它训练的标签都是"真实图像"。因此,训练gan将会更新generator得到权重,使得discriminator在观测假图像时更有可能预测为"真"。

对抗网络

discriminator.trainable = True #将判别器权重设置为不可训练(仅应用于gan模型)

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input,gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')

  

注意:在训练过程中需要将判别器设置为冻结(即不可训练),这样在训练gan时它的权重才不会更新。 如果在此过程中可以对判别器的权重进行更新,那么我们就是在训练判别器始终预测"真",但这并不是我们想要的。

实现GAN的训练

import os
from keras.preprocessing import image (x_train,y_train),(_,_) = keras.datasets.cifar10.load_data() #cifar数据集
x_train = x_train[y_train.flatten() == 6]#选择青蛙的图像
x_train = x_train.reshape((x_train.shape[0],) + (height,width,channels)).astype('float32')/255.

  

iterations = 1000
batch_size = 2
save_dir = 'frog_dir' start = 0 for step in range(iterations):
random_latent_vectors = np.random.normal(size=(batch_size,latent_dim))
generated_images = generator.predict(random_latent_vectors)#点-->虚假图像 stop = start + batch_size #混淆真实图像和虚假图像
real_images = x_train[start:stop]
combined_images = np.concatenate([generated_images,
real_images])
labels = np.concatenate([np.ones((batch_size,1)),
np.zeros((batch_size,1))])
labels += 0.05 * np.random.random(labels.shape) #向标签中添加噪声 #训练判别器
d_loss = discriminator.train_on_batch(combined_images,labels) #在潜在空间中采样随机点
random_latent_vectors = np.random.normal(size=(batch_size,latent_dim))
#合并标签,全都是“真实图像”(这是在撒谎)
misleading_targets = np.zeros((batch_size,1)) #通过gan模型来训练生成器(此时冻结判别器模型)
a_loss = gan.train_on_batch(random_latent_vectors,misleading_targets) start += batch_size
if start > len(x_train) - batch_size:
start = 0
if step % 2 == 0:
gan.save_weights('gan.h5') print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss) img = image.array_to_img(generated_images[0] * 255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png')) img = image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))

  

判别器损失:d_loss=(生成的图像和真实图像->标签)

gan损失:a_loss=(随机采样的点->全是'真'的标签)

第一次  

最后一次 

4.keras实现-->生成式深度学习之用GAN生成图像的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 4.keras实现-->生成式深度学习之DeepDream

    DeepDream是一种艺术性的图像修改技术,它用到了卷积神经网络学到的表示,DeepDream由Google于2015年发布.这个算法与卷积神经网络过滤器可视化技术几乎相同,都是反向运行一个卷积神经 ...

  3. 从零开始学会GAN 0:第一部分 介绍生成式深度学习(连载中)

    本书的前四章旨在介绍开始构建生成式深度学习模型所需的核心技术.在第1章中,我们将首先对生成式建模领域进行广泛的研究,并从概率的角度考虑我们试图解决的问题类型.然后,我们将探讨我们的基本概率生成模型的第 ...

  4. 深度学习新星:GAN的基本原理、应用和走向

    深度学习新星:GAN的基本原理.应用和走向 (本文转自雷锋网,转载已获取授权,未经允许禁止转载)原文链接:http://www.leiphone.com/news/201701/Kq6FvnjgbKK ...

  5. 深度学习之 rnn 台词生成

    深度学习之 rnn 台词生成 写一个台词生成的程序,用 pytorch 写的. import os def load_data(path): with open(path, 'r', encoding ...

  6. 转:TensorFlow和Caffe、MXNet、Keras等其他深度学习框架的对比

    http://geek.csdn.net/news/detail/138968 Google近日发布了TensorFlow 1.0候选版,这第一个稳定版将是深度学习框架发展中的里程碑的一步.自Tens ...

  7. 深度学习----现今主流GAN原理总结及对比

    原文地址:https://blog.csdn.net/Sakura55/article/details/81514828 1.GAN 先来看看公式:             GAN网络主要由两个网络构 ...

  8. 惊不惊喜, 用深度学习 把设计图 自动生成HTML代码 !

    如何用前端页面原型生成对应的代码一直是我们关注的问题,本文作者根据 pix2code 等论文构建了一个强大的前端代码生成模型,并详细解释了如何利用 LSTM 与 CNN 将设计原型编写为 HTML 和 ...

  9. 深度学习在gilt应用——用图像相似性搜索引擎来商品推荐和服务属性分类

    机器学习起源于神经网络,而深度学习是机器学习的一个快速发展的子领域.最近的一些算法的进步和GPU并行计算的使用,使得基于深度学习的算法可以在围棋和其他的一些实际应用里取得很好的成绩. 时尚产业是深度学 ...

随机推荐

  1. 分布式实时日志系统(四) 环境搭建之centos 6.4下hbase 1.0.1 分布式集群搭建

    一.hbase简介 HBase是一个开源的非关系型分布式数据库(NoSQL),它参考了谷歌的BigTable建模,实现的编程语言为 Java.它是Apache软件基金会的Hadoop项目的一部分,运行 ...

  2. 【WEB前端系列之CSS】CSS3动画之Tranition

    前言 css中的transition允许css的属性值在一定的时间区间内平滑的过渡.这种效果可以在鼠标点击.获得焦点.被点击或对元素任何改变中触发,并圆滑的以动画效果改变CSS的属性值.语法: tra ...

  3. Makefile Demo案例

    # Comments can be written like this. # File should be named Makefile and then can be run as `make &l ...

  4. 异构GoldenGate 12c 单向复制配置(支持DDL复制)

    1.开始配置OGG支持DDL复制(在source端操作) 1.1 赋予权限 SQL> conn /as sysdba 已连接. SQL> grant execute on utl_file ...

  5. 题目1458:汉诺塔III(不一样的汉诺塔递归算法)

    题目链接:http://ac.jobdu.com/problem.php?pid=1458 详解链接:https://github.com/zpfbuaa/JobduInCPlusPlus 参考代码: ...

  6. SharePoint 2013部署自定义HttpModule访问SPContext.Current的一个问题

    如果文档库post提交文档时,自定义HttpModule正好有代码访问SPContext.Current属性则会导致上传文档失败.

  7. redmine创建新闻,自动发邮件给项目组所有成员

    redmine创建新闻,自动发邮件给项目组所有成员: 1.添加用户至公共项目内 2.配置系统邮件推送配置 3.检查用户接受推送配置 3.使用管理员账户发布新闻(不能自己发送自己) 4.查看邮件接受邮件

  8. Django---项目如何创建

    首先是安装好Django,找到 Scripts 目录配置环境变量: 只要添加到环境变量,在任何目录执行 django-admin startproject mysite 就可以创建 Django 程序 ...

  9. hdu2586(LCA最近公共祖先)

    How far away ? Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) T ...

  10. php base64转图片

    1.解析base64数据成图片 The problem is that data:image/bmp;base64, is included in the encoded contents. This ...