Tensorflow2.0实战之GAN
本文主要带领读者了解生成对抗神经网络(GAN),并使用提供的face数据集训练网络
GAN 入门
自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。
就在三年前,Ian Goodfellow 在 reddit 上回答 GAN 是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。

“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。
GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。
因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。
例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。
但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。
由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”
但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目。
GAN 原理
GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。

为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。
判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。
ok,接下来我们就来实现一下
准备阶段
下载数据集
数据集,笔者这里已经为大家提供了,链接如下:
链接: https://pan.baidu.com/s/15wFZAANvr8gajiVY_1mI0A
提取码: c9vy
解压数据集
将下载好的数据集解压,放在工程目录下

加载数据集
加载数据集的代码,笔者这里直接提供给大家了,下面只是展示部分代码,文末会提供完整项目的代码链接
import multiprocessing
import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
@tf.function
def _map_fn(img):
img = tf.image.resize(img, [resize, resize])
img = tf.clip_by_value(img, 0, 255)
img = img / 127.5 - 1
return img
dataset = disk_image_batch_dataset(img_paths,
batch_size,
drop_remainder=drop_remainder,
map_fn=_map_fn,
shuffle=shuffle,
repeat=repeat)
img_shape = (resize, resize, 3)
len_dataset = len(img_paths) // batch_size
return dataset, img_shape, len_dataset
def batch_dataset(dataset,
batch_size,
drop_remainder=True,
n_prefetch_batch=1,
filter_fn=None,
map_fn=None,
n_map_threads=None,
filter_after_map=False,
shuffle=True,
shuffle_buffer_size=None,
repeat=None):
构建网络
搭建Generator,Generator包含两个部分,init部分和前向传播的call部分,代码如下
class Generator(keras.Model):
def __init__(self):
super(Generator, self).__init__()
# z:[b,100]-->[b,3*3*512]-->[b,3,3,512]-->[b,64,64,3]
self.fc=keras.layers.Dense(3*3*512)
self.conv1=keras.layers.Conv2DTranspose(256,3,3,'valid') # 反卷积
self.bn1=keras.layers.BatchNormalization()
self.conv2=keras.layers.Conv2DTranspose(128,5,2,'valid')
self.bn2=keras.layers.BatchNormalization()
self.conv3=keras.layers.Conv2DTranspose(3,4,3,'valid')
def call(self, inputs, training=None, mask=None):
# [z,100]-->[z,3*3*512]
x=self.fc(inputs)
x=tf.reshape(x,[-1,3,3,512])
x=tf.nn.leaky_relu(x)
x=tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
x=self.conv3(x)
x=tf.tanh(x)
return x
搭建Discriminator,同上
class Discriminator(keras.Model):
def __init__(self):
super(Discriminator, self).__init__()
# [b,64,64,3]-->[b,1]
self.conv1=keras.layers.Conv2D(64,5,3,'valid')
self.conv2=keras.layers.Conv2D(128,5,3,'valid')
self.bn2=keras.layers.BatchNormalization()
self.conv3=keras.layers.Conv2D(256,5,3,'valid')
self.bn3=keras.layers.BatchNormalization()
# [b,h,w,c]-->[b,-1]
self.flatten=keras.layers.Flatten()
# [b,-1]-->[b,1]
self.fc=keras.layers.Dense(1)
def call(self, inputs, training=None, mask=None):
x=tf.nn.leaky_relu(self.conv1(inputs))
x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
x=tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
x=self.flatten(x)
logits=self.fc(x)
return logits
训练GAN
定义相关数据,包括epoch,lr等等
这些数据可以自定义,笔者这里就不改动了
z_dim = 100
epochs = 50000
batch_size = 512
learning_rate = 0.0002
is_training = True
加载数据
img_path=glob.glob(r'E:\python_pro\TF2.0\GAN\faces\*.jpg')
dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
可以打印查看数据集信息:
(512, 64, 64, 3), (64, 64, 3)
(512, 64, 64, 3) ,1.0, -1.0
定义优化器,注意我们在开始训练时,需要新建训练GAN图片的文件,为查看数据提供持久化依据
for epoch in range(epochs):
batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
batch_x = next(db_iter)
# train D
with tf.GradientTape() as tape:
d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
grads = tape.gradient(d_loss, discriminator.trainable_variables)
d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
with tf.GradientTape() as tape:
g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
grads = tape.gradient(g_loss, generator.trainable_variables)
g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
if epoch % 100 == 0:
print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))
z = tf.random.uniform([100, z_dim])
fake_image = generator(z, training=False)
img_path = os.path.join('GAN_IMAGE', 'gan%d.png'%epoch)
save_result(fake_image.numpy(), 10, img_path, color_mode='P')
训练结果
接下来我们来看看,训练的效果图,注意,GAN的训练过程是非常非常非常慢的,大概训练十几个小时,才能有个比较好的效果,有的数据集甚至会训练几天之久,这个随数据集的大小和对最终效果的要求来定的。笔者这个数据集比较的简单,只是给大家做演示,好了,废话就不过多的说了,上图




上述分别是训练了100epoch、500、1500、4000的效果图,可以看到随着训练的次数增加,效果因为越来越好了
总结
大家在训练GAN时,还是需要一个好一些的GPU显卡才行,这样可以体验GPU给我们带来的加速效果。这样会使得训练的速度大大加快。
笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。
最后还是希望大家动手实践实践,共同进步。
最终的代码链接:https://github.com/huzixuan1/TF_2.0/tree/master/GAN
Tensorflow2.0实战之GAN的更多相关文章
- Google老师亲授 TensorFlow2.0实战: 入门到进阶
Google老师亲授 TensorFlow2.0 入门到进阶 课程以Tensorflow2.0框架为主体,以图像分类.房价预测.文本分类等项目为依托,讲解Tensorflow框架的使用方法,同时学习到 ...
- Google工程师亲授 Tensorflow2.0-入门到进阶
第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...
- 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程
0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. 本篇文章就 ...
- 一文上手Tensorflow2.0(四)
系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...
- 一文上手TensorFlow2.0(一)
目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU) Te ...
- 【转】WF4.0实战系列索引
转自:http://www.cnblogs.com/zhuqil/archive/2010/07/05/wf4-in-action-index.html 此系列的工作流文件案例比较多点,实用性好. W ...
- 基于tensorflow2.0 使用tf.keras实现Fashion MNIST
本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...
- vue.js2.0实战(1):搭建开发环境及构建项目
Vue.js学习系列: vue.js2.0实战(1):搭建开发环境及构建项目 https://my.oschina.net/brillantzhao/blog/1541638 vue.js2.0实战( ...
- ETL工具--DataX3.0实战
DataX是一个在异构的数据库/文件系统之间高速交换数据的工具,实现了在任意的数据处理系统(RDBMS/Hdfs/Local filesystem)之间的数据交换,由淘宝数据平台部门完成. DataX ...
- TensorFlow2.0(1):基本数据结构—张量
1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...
随机推荐
- 通过商品API接口获取到数据后的分析和应用
一.如果你想要分析商品API接口获取到的数据,可以按照如下的步骤进行: 了解API接口返回值的格式,如JSON格式.XML格式.CSV格式等,选择适合你的数据分析方式. 使用API请求工具(如Post ...
- java类序列化和反序列化
参考:https://zhuanlan.zhihu.com/p/144535172?utm_id=0 https://blog.csdn.net/qq_42617455/article/details ...
- 编译python为可执行文件遇到的问题:使用python-oracledb连接oracle数据库时出现错误:DPY-3010
错误原文: DPY-3010: connections to this database server version are not supported by python-oracledb in ...
- 从零用VitePress搭建博客教程(2) –VitePress默认首页和头部导航、左侧导航配置
2. 从零用VitePress搭建博客教程(2) –VitePress默认首页和头部导航.左侧导航配置 接上一节: 从零用VitePress搭建博客教程(1) – VitePress的安装和运行 四. ...
- mysqli操作
1.使用mysqli_connect()函数,语法如下: mysqli 对象名=mysqli_connect(数据库服务名,用户名,密码,数据库名) 例:$conn=mysqli_connect('l ...
- python加解密小工具
1.地址 https://github.com/Doneone/RSA_GUI 2.用法 python3 rsa_gui.py 创建密钥对 3.思考 为什么要写rsa小工具呐,因为单纯想实现一个简单的 ...
- 『STAOI』G - Round 1 半个游记
很刺激. 挂个链接
- 实战攻防演练-利用Everything搜索软件进行内网后渗透利用
前言 Everything是一款很出名的文件搜索工具,基于文件.文件夹名称的快速搜索的轻量级的软件,而早在几年前就有很多apt组织利用everything来进行文件查找等,前几年在T00ls上也有人发 ...
- openwrt使用tailscale实现内网穿透
问题 之前一直有电信公网ip,最近发现电信公网ip被撤下来了,打电话再去要发现给的是10开头的ip,电信客服还跟我说10开头就是公网ip,= =,根本就不是,无奈使用zerotier进行打洞,把zer ...
- Filter入门实例
一.介绍 Filter:Filter是Servlet的"加强版",它主要用于对用户请求进行预处理,也可对HttpServletResponse进行后处理,是个典型的"处理 ...