前言

深度学习作为人工智能的重要手段,迎来了爆发,在NLP、CV、物联网、无人机等多个领域都发挥了非常重要的作用。最近几年,各种深度学习算法层出不穷, Generative Adverarial Network(GAN)自2014年提出以来,引起广泛关注,身为深度学习三巨头之一的Yan Lecun对GAN的评价颇高,认为GAN是近年来在深度学习上最大的突破,是近十年来机器学习上最有意思的工作。围绕GAN的论文数量也迅速增多,各种版本的GAN出现,主要在CV领域带来了一些贡献,如下图所示。

我们可以利用GAN生成一些我们需要的图像或者文本,比如二次元头像。

GAN简介

GAN主要的应用是自动生成一些东西,包括图像和文本等,比如随机给一个向量作为输入,通过GAN的Generator生成一张图片,或者生成一串语句。Conditional GAN的应用更多一些,比如数据集是一段文字和图像的数据对,通过训练,GAN可以通过给定一段文字生成对应的图像。

GAN主要可以分为Generator(生成器)和Discriminator(判别器)两个部分,其中Generator其实就是一个神经网络,输入一个向量,可以输出一张图像(即一个高维的向量表示),如下图示。

​Discriminator也是一个神经网络,输入为一张图像,输出为一个数值,输出的数值用于判断输入的图像是否是真的,数值越大,说明图像是真的,数值越小,说明图像为假的,如下图示。

​Generator负责生成图像,Discriminator负责对Generator生成的图像和真实图像去进行对比,区别出真假,Generator需要不断优化来欺骗Discriminator,以假乱真;而Discriminator也不断优化,来提高识别能力,能够识别出Generator的把戏。二者的这种关系可以形象地通过下图展示。

Generator和Discriminator连接起来,形成一个比较大的深层网络,即为GAN网络。

场景描述

深度学习的各种算法在PAI上可以通过PAI-DSW进行实现,在PAI-DSW上进行训练数据,利用GAN自动生成二次元头像。

数据准备

首先需要准备真实的二次元头像作为数据集,这里从网上找到一些共享的资源,存储在了钉钉钉盘中,钉盘地址 ,提取密码: c2pz,数据集如下图示,约5万多张:

算法实践

利用PAI-DSW进行GAN算法实践,首先需要安装准备好环境。

首先进入到Notebook建模,创建新实例,之后打开实例,进入Terminal,在Terminal下用户可以像在自己本地一样安装相应的依赖包,进行操作。

准备好环境之后,我们可以通过如下图示方法,将基于Tensorflow的DCGAN代码和数据集上传上去。 ​

用于训练的DCGAN代码地址:https://github.com/carpedm20/DCGAN-tensorflow,关于DCGAN的网络框架图如下,详细介绍可以参考论文:https://arxiv.org/abs/1511.06434,这里我们不做详述。

数据集和代码上传成功,如下图示。

其中,data目录下的faces即为数据集,该文件夹下为对应的5万多张真实二次元头像。DCGAN-tensorflow为整个代码路径,其中最主要的两个代码文件是main.py和model.py,其中最主要的核心代码如下。

def main(_):
pp.pprint(flags.FLAGS.__flags) if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir) #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
run_config = tf.ConfigProto()
run_config.gpu_options.allow_growth=True with tf.Session(config=run_config) as sess:
if FLAGS.dataset == 'mnist':
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)
else:
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir) show_all_variables() if FLAGS.train:
dcgan.train(FLAGS)

        else:
# Update D network
_, summary_str = self.sess.run([d_optim, self.d_sum],
feed_dict={ self.inputs: batch_images, self.z: batch_z })
self.writer.add_summary(summary_str, counter) # Update G network
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z })
self.writer.add_summary(summary_str, counter) # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z })
self.writer.add_summary(summary_str, counter) errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
errG = self.g_loss.eval({self.z: batch_z})

一切就绪之后,我们执行命令进行训练,调用命令如下:

​python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,参数dateset指定数据集的目录,epoch指定循环迭代的次数,input_height、input_width用于指定输入文件的大小,输出文件的大小同样也需要参数设定,代码执行过程如下图示:​

我们来看下执行结果,分别看一下epoch为1,30,100的时候生成的二次元头像效果图。

epoch=1

epoch=30

epoch=100​

我们发现,随着不断迭代,生成的二次元头像也越来越逼真。

总结

通过上面的实践,我们领略到了GAN的魅力,GAN的变种有很多,除此之外我们还可以利用GAN做非常多的有意思的事情,比如通过文字生成图像,通过简单文字生成宣传海报等。PAI-DSW像是一个练武场,为我们准备好了深度学习所需要的环境和条件,让我们可以尽情享受大数据和深度学习的乐趣,除了GAN,像比较火热的Bert等模型,我们也都可以试一试。


本文作者:不等_赵振才

原文链接

本文为云栖社区原创内容,未经允许不得转载。

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像的更多相关文章

  1. 【机器学习PAI实战】—— 玩转人工智能之综述

    摘要: 基于人工智能火热的大背景下,通过阿里云的机器学习平台PAI在真实场景中的应用,详细阐述相关算法及使用方法,力求能够让读者读后能够马上动手利用PAI搭建属于自己的机器学习实用方案,真正利用PAI ...

  2. 【机器学习PAI实战】—— 玩转人工智能之你最喜欢哪个男生?

    摘要: 分类问题是生活中最常遇到的问题之一.普通人在做出选择之前,可能会犹豫不决,但对机器而言,则是唯一必选的问题.我们可以通过算法生成模型去帮助我们快速的做出选择,而且保证误差最小.充足的样本,合适 ...

  3. 【机器学习PAI实战】—— 玩转人工智能之商品价格预测

    摘要: 我们经常思考机器学习,深度学习,以至于人工智能给我们带来什么?在数据相对充足,足够真实的情况下,好的学习模型可以发现事件本身的内在规则,内在联系.我们去除冗余的信息,可以通过最少的特征构建最简 ...

  4. 一分钟带你学会利用mybatis-generator自动生成代码!

    目录 一.MyBatis Generator简介 二.使用方式 三.实战 之前的文章<SpringBoot系列-整合Mybatis(XML配置方式)>介绍了XML配置方式整合的过程,本文介 ...

  5. [06] 利用mybatis-generator自动生成代码

    1.mybatis-generator 概述 MyBatis官方提供了逆向工程 mybatis-generator,可以针对数据库表自动生成MyBatis执行所需要的代码(如Mapper.java.M ...

  6. 利用python自动生成verilog模块例化模板

    一.前言 初入职场,一直忙着熟悉工作,就没什么时间更新博客.今天受“利奇马”的影响,只好宅在家中,写写技术文章.芯片设计规模日益庞大,编写脚本成了芯片开发人员必要的软技能.模块端口动不动就几十上百个, ...

  7. 利用mybatis-generator自动生成代码(转)

    利用mybatis-generator自动生成代码 mybatis-generator有三种用法:命令行.eclipse插件.maven插件.个人觉得maven插件最方便,可以在eclipse/int ...

  8. springboot整合mybatis,利用mybatis-genetor自动生成文件

    springboot整合mybatis,利用mybatis-genetor自动生成文件 项目结构: xx 实现思路: 1.添加依赖 <?xml version="1.0" e ...

  9. Asp.Net Core 轻松学-利用 Swagger 自动生成接口文档

    前言     目前市场上主流的开发模式,几乎清一色的前后端分离方式,作为服务端开发人员,我们有义务提供给各个客户端良好的开发文档,以方便对接,减少沟通时间,提高开发效率:对于开发人员来说,编写接口文档 ...

随机推荐

  1. 牛客网暑期ACM多校训练营(第一场)菜鸟补题QAQ

    签到题 J Different Integers(树状数组) 题目大意:给一个长为n的数组,每一个询问给两个数字i, j ,询问1~i, j~n这两个区间中有多少不同的数字,真的像是莫队裸题,但是两个 ...

  2. [转]C# 委托、事件,lamda表达式

    1. 委托Delegate C#中的Delegate对应于C中的指针,但是又有所不同C中的指针既可以指向方法,又可以指向变量,并且可以进行类型转换, C中的指针实际上就是内存地址变量,他是可以直接操作 ...

  3. CF集萃3

    CF1118F2 - Tree Cutting 题意:给你一棵树,每个点被染成了k种颜色之一或者没有颜色.你要切断恰k - 1条边使得不存在两个异色点在同一连通块内.求方案数. 解:对每颜色构建最小斯 ...

  4. PHP+Ajax点击加载更多内容 -这个效果好,速度快,只能点击更多加载,不能滚动自动加载

    这个效果好,速度快,只能点击更多加载,不能滚动自动加载 一.HTML部分 <div id="more"> <div class="single_item ...

  5. PAT甲级——A1025 PAT Ranking

    Programming Ability Test (PAT) is organized by the College of Computer Science and Technology of Zhe ...

  6. IO流13 --- 转换流实现文件复制 --- 技术搬运工(尚硅谷)

    InputStreamReader 将字节输入流转换为字符输入流 OutputStreamWriter 将字符输出流转换为字节输出流 @Test public void test2() { //转换流 ...

  7. C# 全局Hook在xp上不回调

    最近做了个捕捉全局鼠标,获取目标窗体内的控件文本信息,点击的按钮信息.用的全局钩子.在win10上运行正常,部署到xp系统上就没有反应.查了些资料,解决了此问题. 原本安装钩子的写法如下: Nativ ...

  8. numpy.flatnonzero():

    numpy.flatnonzero(): 该函数输入一个矩阵,返回扁平化后矩阵中非零元素的位置(index) 这是官方文档给出的用法,非常正规,输入一个矩阵,返回了其中非零元素的位置. 1 >& ...

  9. [转]IE userData

    IE浏览器实现了它专属的客户端存储机制——“userData”.userData可以实现一定量的字符串数据存储,可以将其用做是Web存储的替代方案.本文将详细介绍IE userData 概述 在IE5 ...

  10. light oj 1427(ac自动机)

    #include <bits/stdc++.h> using namespace std; *; ; map<string,int>Map; struct Trie { int ...