生成式对抗网络(GAN,generative adversarial network)由Goodfellow等人于2014年提出,它可以替代VAE来学习图像的潜在空间。它能够迫使生成图像与真实图像在统计上几乎无法区别,从而生成相当逼真的合成图像。

1.GAN是什么?

简单来说就是由两部分组成,生成器generator网络和判别器discriminator网络。一部分不断进化,使其对立部分也不断进化,实现共同进化的过程。

对GAN的一种直观理解是,想象我们想要试图生成一个二次元头像。一开始,我们并不擅长这项任务,就将自己的一些噪音二次元头像和真的二次元头像混在一起,并将其展示给discriminator。discriminator对每个头像进行真实性评估,并向我们给出反馈,告诉我们是什么让二次元头像看起来像真的二次元头像,我们回到自己的工作室,并准备一些新的二次元头像。随着时间的推移,我们变得越来越擅长模仿二次元头像的风格,discriminator也变得越来越擅长找出假的二次元头像。最后,我们手上拥有了一些优秀的二次元头像。

2.为什么?

【1】为什么我们有真的二次元头像和假的二次元头像,为什么不自己用监督学习生成新的二次元头像呢?

  generator无法自己独立学习的原因是,以vae为例,输出layer层输出的是各像素点,而他们在输出时是独立的,没有相互作用的,因此无法判断总体的效果进行自主学习。对于discriminator,其输入是生成的整张图像,因此可以从总体上进行判断。

  需要注意的是,discriminator对于输入的真实图像都应是高分,那么如果训练时只给它真实图像的话,他就无法实现正确的判断,会将所有输入都判为高分。所以需要一些差的图像送给discriminator进行训练,并且这些差的图像不应是简单的加些噪声之类的能让它轻易分辨的。因此,训练它的方法是,除真实图像外先给它一些随机生成的差的例子,然后对discriminator解argmaxD(x)做generation生成出一些他觉得好的图像,然后将原本极差的图像换为这些图像再进行训练,如此往复,discriminator会不断产生更好的图像,将这些作为negative examples给其学习,达到训练的目的。

【2】discriminator对真的二次元头像这么了解,为什么他不自己做,而是要来指导我们做呢?

那既然如此,为什么还需要generator呢?discriminator自己也可以生成图像啊?

这是因为discriminator生成图像需要解argmaxD(x), 难度较大,一般需要假设一些条件才会好解,比如网络假设为线性时,但这样会限制图像的生成效果。而generator生成非常快,因此将二者结合起来共同学习实现输出好的结果。二者优缺点如下所示:

总而言之,因为generator没有全局观,所以需要结合discriminator学习,对于discriminator,使用generator生成图像比自己解方程生成更简单高效,这二者的优缺点相互补充。

GAN的目的是为了生成,而VAE目的是为了压缩,目的不同效果自然不同。比如,由于二范数的原因,VAE的生成是模糊的。而GAN的生成是犀利的。

数据集为CIFAR10,包含50000张32*32的RGB图像,这些图像属于10个类别(每个类别5000张图像),这里我们只使用属于“frog”(青蛙)类别的图像

import keras
from keras import layers
import numpy as np

  

生成器网络:将一个向量(来自潜在空间,训练过程中对其随机采样)转换为一张候选图像

生成器从未直接见过训练集中的图像,它所知道的关于数据的信息都来自于判别器。

latent_dim = 32
height = 32
width = 32
channels = 3 generator_input = keras.Input(shape=(latent_dim,)) #将输入转换为大小为16*16的128个通道的特征图
x = layers.Dense(128*16*16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16,16,128))(x) x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x) #上采样为32*32
x = layers.Conv2DTranspose(256,4,strides=2,padding='same')(x)
x = layers.LeakyReLU()(x) x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256,5,padding='same')(x)
x = layers.LeakyReLU()(x) #生成一个大小为32*32的单通道特征图(即CIFAR10图像的形状)
x = layers.Conv2D(channels,7,activation='tanh',padding='same')(x)
#将生成器模型实例化,它将形状为(latent_dim,)的输入映射到形状为(32,32,3)的图像
generator = keras.models.Model(generator_input,x)
generator.summary()

  

判别器网络:它接收一张候选图像(真实的或合成的)作为输入,并将其划分到这两个类别之一:"生成图像"或"来自训练集的真实图像"

#GAN判别器网络
discriminator_input = layers.Input(shape=(height,width,channels))
x = layers.Conv2D(128,3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x) x = layers.Dropout(0.4)(x) x = layers.Dense(1,activation='sigmoid')(x)#分类层 #将判别器模型实例化,它将形状为(32,32,3)的输入转换为一个二进制分类决策(真/假)
discriminator = keras.models.Model(discriminator_input,x)
discriminator.summary()

  

discriminator_optimizer = keras.optimizers.RMSprop(
lr=0.0008,
clipvalue = 1.0, #在优化器中使用梯度裁剪(限制梯度值的范围)
decay = 1e-8,#为了稳定训练过程,使用学习率衰减
) discriminator.compile(optimizer=discriminator_optimizer,
loss='binary_crossentropy')

  

设置GAN,将生成器和判别器连接在一起 训练时,这个模型将让生成器向某个方向移动,从而提高它欺骗判别器的能力。这个模型将潜在空间的点转换为一个分类决策(即"真"或"假") 它训练的标签都是"真实图像"。因此,训练gan将会更新generator得到权重,使得discriminator在观测假图像时更有可能预测为"真"。

对抗网络

discriminator.trainable = True #将判别器权重设置为不可训练(仅应用于gan模型)

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input,gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')

  

注意:在训练过程中需要将判别器设置为冻结(即不可训练),这样在训练gan时它的权重才不会更新。 如果在此过程中可以对判别器的权重进行更新,那么我们就是在训练判别器始终预测"真",但这并不是我们想要的。

实现GAN的训练

import os
from keras.preprocessing import image (x_train,y_train),(_,_) = keras.datasets.cifar10.load_data() #cifar数据集
x_train = x_train[y_train.flatten() == 6]#选择青蛙的图像
x_train = x_train.reshape((x_train.shape[0],) + (height,width,channels)).astype('float32')/255.

  

iterations = 1000
batch_size = 2
save_dir = 'frog_dir' start = 0 for step in range(iterations):
random_latent_vectors = np.random.normal(size=(batch_size,latent_dim))
generated_images = generator.predict(random_latent_vectors)#点-->虚假图像 stop = start + batch_size #混淆真实图像和虚假图像
real_images = x_train[start:stop]
combined_images = np.concatenate([generated_images,
real_images])
labels = np.concatenate([np.ones((batch_size,1)),
np.zeros((batch_size,1))])
labels += 0.05 * np.random.random(labels.shape) #向标签中添加噪声 #训练判别器
d_loss = discriminator.train_on_batch(combined_images,labels) #在潜在空间中采样随机点
random_latent_vectors = np.random.normal(size=(batch_size,latent_dim))
#合并标签,全都是“真实图像”(这是在撒谎)
misleading_targets = np.zeros((batch_size,1)) #通过gan模型来训练生成器(此时冻结判别器模型)
a_loss = gan.train_on_batch(random_latent_vectors,misleading_targets) start += batch_size
if start > len(x_train) - batch_size:
start = 0
if step % 2 == 0:
gan.save_weights('gan.h5') print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss) img = image.array_to_img(generated_images[0] * 255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png')) img = image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))

  

判别器损失:d_loss=(生成的图像和真实图像->标签)

gan损失:a_loss=(随机采样的点->全是'真'的标签)

第一次  

最后一次 

4.keras实现-->生成式深度学习之用GAN生成图像的更多相关文章

  1. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  2. 4.keras实现-->生成式深度学习之DeepDream

    DeepDream是一种艺术性的图像修改技术,它用到了卷积神经网络学到的表示,DeepDream由Google于2015年发布.这个算法与卷积神经网络过滤器可视化技术几乎相同,都是反向运行一个卷积神经 ...

  3. 从零开始学会GAN 0:第一部分 介绍生成式深度学习(连载中)

    本书的前四章旨在介绍开始构建生成式深度学习模型所需的核心技术.在第1章中,我们将首先对生成式建模领域进行广泛的研究,并从概率的角度考虑我们试图解决的问题类型.然后,我们将探讨我们的基本概率生成模型的第 ...

  4. 深度学习新星:GAN的基本原理、应用和走向

    深度学习新星:GAN的基本原理.应用和走向 (本文转自雷锋网,转载已获取授权,未经允许禁止转载)原文链接:http://www.leiphone.com/news/201701/Kq6FvnjgbKK ...

  5. 深度学习之 rnn 台词生成

    深度学习之 rnn 台词生成 写一个台词生成的程序,用 pytorch 写的. import os def load_data(path): with open(path, 'r', encoding ...

  6. 转:TensorFlow和Caffe、MXNet、Keras等其他深度学习框架的对比

    http://geek.csdn.net/news/detail/138968 Google近日发布了TensorFlow 1.0候选版,这第一个稳定版将是深度学习框架发展中的里程碑的一步.自Tens ...

  7. 深度学习----现今主流GAN原理总结及对比

    原文地址:https://blog.csdn.net/Sakura55/article/details/81514828 1.GAN 先来看看公式:             GAN网络主要由两个网络构 ...

  8. 惊不惊喜, 用深度学习 把设计图 自动生成HTML代码 !

    如何用前端页面原型生成对应的代码一直是我们关注的问题,本文作者根据 pix2code 等论文构建了一个强大的前端代码生成模型,并详细解释了如何利用 LSTM 与 CNN 将设计原型编写为 HTML 和 ...

  9. 深度学习在gilt应用——用图像相似性搜索引擎来商品推荐和服务属性分类

    机器学习起源于神经网络,而深度学习是机器学习的一个快速发展的子领域.最近的一些算法的进步和GPU并行计算的使用,使得基于深度学习的算法可以在围棋和其他的一些实际应用里取得很好的成绩. 时尚产业是深度学 ...

随机推荐

  1. 【Java基础系列】Java IO系统

    前言 创建好的输入/输出系统不仅要考虑三种不同种类的IO系统(文件,控制台,网络连接)还需要通过大量不同的方式与他们通信(顺序,随机访问,二进制,字符,按行,按字等等). 一.输入和输出 Java的I ...

  2. VS 2008 头文件库文件设置

    在程序开发中,很多时候需要用到别人开发的工具包,如OpenCV和itk.一般而言,在vs2008中,很少使用源文件,大部分是使用对类进行声明的头文件和封装了类的链接库(静态lib或动态dll). 如果 ...

  3. 题目1460:Oil Deposit(递归遍历图)

    题目链接:http://ac.jobdu.com/problem.php?pid=1460 详解链接:https://github.com/zpfbuaa/JobduInCPlusPlus 参考代码: ...

  4. sencha touch 2.3.1 list emptyText不显示

    如图所示,有时候没有取到任何的数据. 那么我们就需要显示没有获取到内容这一类提示,显示内容通常通过emptyText这个属性来配置. 但是在sencha touch 2.3.1之中有可能会出问题,所以 ...

  5. JavaScript 异步进化史

    前言 JS 中最基础的异步调用方式是 callback,它将回调函数 callback 传给异步 API,由浏览器或 Node 在异步完成后,通知 JS 引擎调用 callback.对于简单的异步操作 ...

  6. 移动端前端框架UI库

    移动端前端框架UI库(Frozen UI.WeUI.SUI Mobile) Frozen UI 自述:简单易用,轻量快捷,为移动端服务的前端框架. 主页:http://frozenui.github. ...

  7. linux下的一些操作命令

    1.切换到root账号下: su root    输入密码: 2.修改root账号密码: sudo passwd root   输入密码: 3.cat用法: 查看文件内容   cat 文件名 创建文件 ...

  8. Linux(CentOS)安装JDK(.tar.gz)并配置

    本文思路转自http://blog.sina.com.cn/s/blog_81631744010137iy.html 点击(此处)折叠或打开 1.到 甲骨文(oracle)下载jdk不用多说 tar ...

  9. Python - 3MySQL 数据库连接

    Python3 MySQL 数据库连接 本文我们为大家介绍 Python3 使用 PyMySQL 连接数据库,并实现简单的增删改查. 什么是 PyMySQL? PyMySQL 是在 Python3.x ...

  10. Centos7 安装hive

    安装hive 配置hive 在hdfs中新建目录/user/hive/warehouse 首先启动hadoop任务 hdfs dfs -mkdir /tmp hdfs dfs -mkdir /user ...