这一篇博客以代码为主,主要是来介绍如果使用keras构建一个DCGAN,然后基于DCGAN,做一个自动生成动漫头像。训练过程如下(50轮的训练过程)“

关于DCGAN或者GAN的相关知识,可以参考GAN网络入门教程。建议先了解相关知识,再来看这一篇博客。

项目地址:GitHub

使用前准备

首先的首先,我们肯定是需要数据集的,这里使用的数据集来自kaggle——Anime Faces。里面有21551张动漫头像的图片。大家可以到kaggle上面去下载数据集,或者说到我的github上去下载数据集(求个 不过分吧)。部分数据如下:

如果自己电脑计算机资源不是很强的话,比如我,一个mx250小水管(玩玩lol还是可以的,训练这个模型可能要等到下辈子),推荐大家去注册一个kaggle或者colab账号去白嫖GPU资源(1080,2080的玩家请随意)。不过个人更加的推荐kaggle,因为感觉它的资源分配是可见的,且可以后台运行。

数据集

数据集是动漫图片,我们可以将图片的像素点的值变成\([-1,1]\)之间,具体代码如下:

# 数据集的位置
avatar_img_path = "./data" import imageio
import os
import numpy as np
def load_data():
"""
加载数据集
:return: 返回numpy数组
"""
all_images = []
for image_name in os.listdir(avatar_img_path):
# 加载图片
image = imageio.imread(os.path.join(avatar_img_path,image_name))
all_images.append(image)
all_images = np.array(all_images)
# 将图片数值变成[-1,1]
all_images = (all_images - 127.5) / 127.5
# 将数据随机排序
np.random.shuffle(all_images)
return all_images
img_dataset = load_data()

然后定义展示图片的方法:


import matplotlib.pyplot as plt
def show_images(images,index = -1):
"""
展示并保存图片
:param images: 需要show的图片
:param index: 图片名
:return:
"""
plt.figure()
for i, image in enumerate(images):
ax = plt.subplot(5, 5, i+1)
plt.axis('off')
plt.imshow(image)
plt.savefig("data_%d.png"%index)
plt.show()
  • 展示数据集中的部分图片:
show_images(img_dataset[0: 25])

定义参数

这里我们只定义两个参数,图片的shape代表生成的图片是\(64 \times 64\)的RGB图片,以及noise的大小是100:

# noise的维度
noise_dim = 100
# 图片的shape
image_shape = (64,64,3)

构建网络

首先导入tensorflow中的keras库,如下:

from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import UpSampling2D, Conv2D, Dense, BatchNormalization, LeakyReLU, Input,Reshape, MaxPooling2D, Flatten, AveragePooling2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

下图中的网络结构参照了kaggle中的Anime face generation with DCGAN (beginner)

构建G网络

生成器网络,我们按照如下的结构进行构建:

原理是我们通过全连接层将nosise的向量放大,然后在再使用反卷积等操作将其逐渐变成shape为\((64,64,3)\)的图片。

def build_G():
"""
构建生成器
:return:
"""
model = Sequential()
# 全连接层 100 -> 2048
model.add(Dense(2048,input_dim = noise_dim))
# 激活函数
model.add(LeakyReLU(0.2))
# 全连接层 2048 -> 8 * 8 * 256
model.add(Dense(8 * 8 * 256))
# DN层
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
# 8 * 8 * 256 -> (8,8,256)
model.add(Reshape((8, 8, 256)))
# 卷积层 (8,8,256) -> (8,8,128)
model.add(Conv2D(128, kernel_size=5, padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
# 反卷积层 (8,8,128) -> (16,16,128)
model.add(Conv2DTranspose(128, kernel_size=5, strides=2, padding='same'))
model.add(LeakyReLU(0.2))
# 反卷积层 (16,16,128) -> (32,32,64)
model.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding='same'))
model.add(LeakyReLU(0.2))
# 反卷积层 (32,32,64) -> (64,64,3) = 图片
model.add(Conv2DTranspose(3, kernel_size=5, strides=2, padding='same', activation='tanh'))
return model
G = build_G()

可以发现,\(G\)网络并没有compile这一步,这是因为\(G\)网络的权重优化并不是直接优化的,而是通过GAN网络进行间接优化的。

构建D网络

D网络的结构示意图如下:

判别器网络就是一个寻常的CNN网络:


def build_D():
"""
构建判别器
:return:
"""
model = Sequential()
# 卷积层
model.add(Conv2D(64, kernel_size=5, padding='valid',input_shape = image_shape))
# BN层
model.add(BatchNormalization())
# 激活层
model.add(LeakyReLU(0.2))
# 平均池化层
model.add(AveragePooling2D(pool_size=2))
# 卷积层
model.add(Conv2D(128, kernel_size=3, padding='valid'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(AveragePooling2D(pool_size=2))
model.add(Conv2D(256, kernel_size=3, padding='valid'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(AveragePooling2D(pool_size=2))
# 将输入展平
model.add(Flatten())
# 全连接层
model.add(Dense(1024))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
# 最终输出1(true img) 0(fake img)的概率大小
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
return model
D = build_D()

构建GAN网络

由前面的博客,我们知道,GAN网络由G网络和D网络组成,GAN网络的input为nosie,输出为图片真假的概率。因此它的网络结构示意图如下所示:


def build_gan():
"""
构建GAN网络
:return:
"""
# 冷冻判别器,也就是在训练的时候只优化G的网络权重,而对D保持不变
D.trainable = False
# GAN网络的输入
gan_input = Input(shape=(noise_dim,))
# GAN网络的输出
gan_out = D(G(gan_input))
# 构建网络
gan = Model(gan_input,gan_out)
# 编译GAN网络,使用Adam优化器,以及加上交叉熵损失函数(一般用于二分类)
gan.compile(loss='binary_crossentropy',optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
return gan
GAN = build_gan()

关于GAN的小trick

我们会将真实的图片的lable标记为1,fake图片的lable标记为0,但是我们训练的时候可以使lable的值在一定的范围内浮动。关于更多的trick,可以参考这篇 GANs training tricks


def sample_noise(batch_size):
"""
随机产生正态分布(0,1)的noise
:param batch_size:
:return: 返回的shape为(batch_size,noise)
"""
return np.random.normal(size=(batch_size, noise_dim)) def smooth_pos_labels(y):
"""
使得true label的值的范围为[0.7,1.2]
:param y:
:return:
"""
return y - 0.3 + (np.random.random(y.shape) * 0.5) def smooth_neg_labels(y):
"""
使得fake label的值的范围为[0.0,0.3]
:param y:
:return:
"""
return y + np.random.random(y.shape) * 0.3

训练

开始训练之前,我们还介绍一个函数,load_batch,因为我们训练图片不可能说一次将图片全部进行训练而是分批次进行训练(full batch需要大量的内存空间),而load_batch函数就行按批次加载图片。

def load_batch(data, batch_size,index):
"""
按批次加载图片
:param data: 图片数据集
:param batch_size: 批次大小
:param index: 批次序号
:return:
"""
return data[index*batch_size: (index+1)*batch_size]

然后我们就需要定义\(train\)函数了:


def train(epochs=100, batch_size=64):
"""
训练函数
:param epochs: 训练的次数
:param batch_size: 批尺寸
:return:
"""
# 判别器损失
discriminator_loss = 0
# 生成器损失
generator_loss = 0
# img_dataset.shape[0] / batch_size 代表这个数据可以分为几个批次进行训练
n_batches = int(img_dataset.shape[0] / batch_size) for i in range(epochs):
for index in range(n_batches):
# 按批次加载数据
x = load_batch(img_dataset, batch_size,index)
# 产生noise
noise = sample_noise(batch_size)
# G网络产生图片
generated_images = G.predict(noise)
# 产生为1的标签
y_real = np.ones(batch_size)
# 将1标签的范围变成[0.7 , 1.2]
y_real = smooth_pos_labels(y_real)
# 产生为0的标签
y_fake = np.zeros(batch_size)
# 将0标签的范围变成[0.0 , 0.3]
y_fake = smooth_neg_labels(y_fake)
# 训练真图片loss
d_loss_real = D.train_on_batch(x, y_real)
# 训练假图片loss
d_loss_fake = D.train_on_batch(generated_images, y_fake) discriminator_loss = d_loss_real + d_loss_fake
# 产生为1的标签
y_real = np.ones(batch_size)
# 训练GAN网络,input = fake_img ,label = 1
generator_loss = GAN.train_on_batch(noise, y_real) print('[Epoch {0}]. Discriminator loss : {1}. Generator_loss: {2}.'.format(i, discriminator_loss, generator_loss))
# 随机产生(25,100)的noise
test_noise = sample_noise(25)
# 使用G网络生成25张图偏
test_images = G.predict(test_noise)
# show 预测 img
show_images(test_images,i)

开始训练:

train(epochs=500, batch_size=32)

最后就进入到了漫长的等待结果的时间了。

总结

项目地址:GitHub

参考

GAN网络之入门教程(四)之基于DCGAN动漫头像生成的更多相关文章

  1. GAN网络之入门教程(五)之基于条件cGAN动漫头像生成

    目录 Prepare 在上篇博客(AN网络之入门教程(四)之基于DCGAN动漫头像生成)中,介绍了基于DCGAN的动漫头像生成,时隔几月,序属三秋,在这篇博客中,将介绍如何使用条件GAN网络(cond ...

  2. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

  3. GAN网络从入门教程(二)之GAN原理

    在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...

  4. GAN网络从入门教程(三)之DCGAN原理

    目录 DCGAN简介 DCGAN的特点 几个重要概念 下采样(subsampled) 上采样(upsampling) 反卷积(Deconvolution) 批标准化(Batch Normalizati ...

  5. 【Zigbee技术入门教程-号外】基于Z-Stack协议栈的抢答系统

    [Zigbee技术入门教程-号外]基于Z-Stack协议栈的抢答系统 广东职业技术学院  欧浩源 一.引言    2017年全国职业院校技能大赛"物联网技术应用"赛项中任务三题2的 ...

  6. 无废话ExtJs 入门教程四[表单:FormPanel]

    无废话ExtJs 入门教程四[表单:FormPanel] extjs技术交流,欢迎加群(201926085) 继上一节内容,我们在窗体里加了个表单.如下所示代码区的第28行位置,items:form. ...

  7. PySide——Python图形化界面入门教程(四)

    PySide——Python图形化界面入门教程(四) ——创建自己的信号槽 ——Creating Your Own Signals and Slots 翻译自:http://pythoncentral ...

  8. Elasticsearch入门教程(四):Elasticsearch文档CURD

    原文:Elasticsearch入门教程(四):Elasticsearch文档CURD 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接: ...

  9. RabbitMQ入门教程(四):工作队列(Work Queues)

    原文:RabbitMQ入门教程(四):工作队列(Work Queues) 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https:/ ...

随机推荐

  1. Photon PUN 一 介绍

    有句话说的好 , 官网永远是最好的学习地方 . 虽然国内的资料不多 , 但是官网的资料还是很充足 , 这就带着英汉词典就着作阅读理解的劲头去官网学习吧 https://doc.photonengine ...

  2. War3模型导出

    Wiki https://en.wikipedia.org/wiki/.MDX 教程 http://www.ou99.com/article-25996-1.html http://blog.csdn ...

  3. 万级K8s集群背后etcd稳定性及性能优化实践

    背景与挑战 随着腾讯自研上云及公有云用户的迅速增长,一方面,腾讯云容器服务TKE服务数量和核数大幅增长, 另一方面我们提供的容器服务类型(TKE托管及独立集群.EKS弹性集群.edge边缘计算集群.m ...

  4. istio部署

    Istio的部署介绍 目录 Istio的部署介绍 部署模型 集群模式 单集群 多集群 网络模型 单网络 多网络 控制面模型 身份和信任模型 网格中的信任 网格之间的信任 网格模型 单网格 多网格 租户 ...

  5. nginx的gzip压缩

    随着nginx的发展,越来越多的网站使用nginx,因此nginx的优化变得越来越重要,今天我们来看看nginx的gzip压缩到底是怎么压缩的呢? gzip(GNU-ZIP)是一种压缩技术.经过gzi ...

  6. 05_进程间通信 IPC

    1.进程间的通信方式 1.磁盘交互: 速度慢,不安全 2.socket套接字 3.管道通信(Pipe) 4.消息队列(Queue, Manager().Queue, JoinableQueue) 5. ...

  7. 10_Python的函数function

    1.函数的概述 1.函数是可以重复执行的语句块且可以重复调用,函数封装了可重复执行的语句提高了语句的可重复性     2.函数的参数和返回值的作用流程图: https://www.processon. ...

  8. Zabbix Agent报“listener failed: zbx_tcp_listen() fatal error: unable to serve on any address”

    一台服务器的Zabbix Agent升级后,在Zabbix Server发现Zabbix Agent无法访问.检查Zabbix Agent发现服务停止了,启动Zabbix Agent后,发现服务马上又 ...

  9. Thymeleaf中model设一个值 页面显示此值 JS取此值

    model设值: m.addAttribute("pageNo", pageNo); 页面显示值: 当前为第:<span th:text="${pageNo}&qu ...

  10. IFile、File与实体转换

    /** * 根据物理实体文件在开发工程中创建实体文件 */ @Override public void getEntityFilesByErFile(IFile erfile, IFolder ent ...