摘要:CycleGAN图像翻译模型,由两个生成网络和两个判别网络组成,通过非成对的图片将某一类图片转换成另外一类图片,可用于风格迁移

本文分享自华为云社区《基于MindSpore的CycleGAN介绍和实现》,作者: Tianyi_Li 。

前言

我们这次介绍下著名的CycleGAN,同时提供了基于MindSpore的代码,方便大家运行验证。

CycleGAN的介绍

CycleGAN图像翻译模型,由两个生成网络和两个判别网络组成,通过非成对的图片将某一类图片转换成另外一类图片,可用于风格迁移,效果演示如下图所示:

CycleGAN是GAN的一种,那什么是GAN呢?

生成对抗网络(Generative Adversarial Network, 简称GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。 生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。

好了,我们已经对GAN有了大体的了解,下面说回CycleGAN。

CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。生成网络中编码部分的网络结构都是采用convolution-norm-ReLU作为基础结构,解码部分的网络结构由transpose convolution-norm-ReLU组成,判别网络基本是由convolution-norm-leaky_ReLU作为基础结构,详细的网络结构可以查看network/CycleGAN_network.py文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。生成网络损失函数由LSGAN的损失函数,重构损失和自身损失组成,判别网络的损失函数由LSGAN的损失函数组成。

CycleGAN最经典的地方是设计和提出了循环一致性损失。以黑白图片上色为例,循环一致性就是:黑白图(真实)—>网络—>彩色图—>网络—>黑白图(造假)。为了保证上色后的彩色图片中具有原始黑白图片的所有内容信息,文章中将生成的彩色图像还原回去,生成造假的黑白图,通过损失函数来约束真实白图和造假黑白图一致,达到图像上色的目的。除此之外,CycleGAN不像Pix2Pix一样,需要使用配对数据进行训练,CycleGAN直接使用两个域图像进行训练,而不用建立每个样本和对方域之间的配对关系,这就厉害了,一下子让风格迁移任务变得简单很多。

看一下CycleGAN的网络结构图:

如果想了解更多详情,可以阅读CycleGAN的原论文,推荐读一读,会有更深刻和更清楚的理解,下面给出链接:

https://arxiv.org/abs/1703.10593

CycleGAN的实现

代码和数据集

这里我提供了一个包含代码和数据集的仓库链接:https://git.openi.org.cn/tjulitianyi/CycleGAN_MindSpore,但是更建议使用最新版本代码,见下方特别说明。

特别说明:我们将在华为云ModelArts的NoteBook,基于MindSpore-GPU 1.8.1 运行CycleGAN的代码,因为云环境的更新不确定性,所以运行可能会报错,这时可以参考如下最新代码:https://gitee.com/mindspore/models/tree/master/research/cv/CycleGAN。

需要提醒大家的是,必须需要使用MindSpore 1.8.0以及以上的版本,之前版本会报错,因为某些API不支持。而最新的1.8.1版本有时也会报错,报错信息如下,怀疑可能是代码的设置有些问题:

目前ModelArts最高支持到MindSpore 1.7,我们需要自行安装最新的MindSpore 1.8.1版本。

先来看看我使用的NoteBook环境:

这里特别提醒大家,NoteBook是要花钱的,我选择的单卡Tesla V100大约每小时28元,也有更便宜的,大概每小时8元的单卡Tesla P100,请大家根据自身情况选择,千万注意使用情况,别欠费了。

准备环境

下面进入NoteBook,打开一个Terminal:

先来看看我们的显卡信息和CUDA Version:

我们看到CUDA Version是10.2,下面到MindSpore官网看看安装教程,我们需要安装MindSpore 1.8.1,但是没有CUDA 10.2对应的版本,这里就选择就近的CUDA 10.1版本了。

在Terminal执行如下命令:

pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.8.1/MindSpore/gpu/x86_64/cuda-10.1/mindspore_gpu-1.8.1-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

下载速度很快,安装速度也是非常快:

最后运行显示如下信息,表示安装成功了:

获取代码

接下来下载代码,执行如下命令(由于要下载整个仓库,时间有点长):

git clone https://gitee.com/mindspore/models.git

命令运行截图:

下面我们将感兴趣的CycleGAN代码拷贝到当前目录下,执行如下命令:

cp -r models/research/cv/CycleGAN/ ./

准备数据集

下面进入CycleGAN目录:

cd CycleGAN

我们这里使用的是monet2photo数据集,由于直接在ModelArts的NoteBook下载速度很慢,所以建议大家下载到本地,再上传到NoteBook的CycleGAN/data目录下,下载链接为:https://s3.openi.org.cn/opendata/attachment/7/b/7beb4534-6e79-463e-a7c6-032510bab215?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=1fa9e58b6899afd26dd3%2F20220814%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220814T085624Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B filename%3D"monet2photo.zip"&X-Amz-Signature=20fbfd9c798701efcbf21d811f3dfdd6b8d5744f388c799bc38715f7fe78c783

上传完成后,解压数据集即可。我的运行截图如下图所示:

启动训练

注意,请在CycleGAN的目录下启动训练,如下图所示:

我是在GPU下的单卡训练,所以启动训练的命令为:

python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/monet2photo/ --outputs_dir ./outputs

运行截图为:

可以看到已经成功启动训练,打印出loss,此时我是用的Tesla V100显卡大约占了4GB显存,利用率接近100%,此时来看不适合用Tesla V100来跑,未能发挥其大显存的优势,而其计算能力其实一般。CycleGAN模型训练比较费时间,请注意花费,预计完成全部200epoch的训练需要72小时以上。

评估模型

python eval.py --platform GPU --device_id 0 --model ResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt

注意,这里的.ckpt模型名称,请根据实际训练生成的具体轮数的模型名称太难写,比如目前只保存了20epoch的模型,那上述命令的200就应该改成20。

更多命令或适配其他硬件平台和多卡情况,可参考scripts文件夹下脚本。

结语

我们简单介绍了著名的CycleGAN,给出了基于MindSpor的完整代码,并带着大家跑了一遍,目前有些问题,后续会更新。作为经典的GAN的一种,CycleGAN有很多值得我们学习的地方,还需要深入分析挖掘,以鉴今事。

关于代码运行的问题,可以到官仓提交issue求助,下为链接:https://gitee.com/mindspore/models/issues

点击关注,第一时间了解华为云新鲜技术~

带你徒手完成基于MindSpore的CycleGAN实现的更多相关文章

  1. 带你手写基于 Spring 的可插拔式 RPC 框架(一)介绍

    概述 首先这篇文章是要带大家来实现一个框架,听到框架大家可能会觉得非常高大上,其实这和我们平时写业务员代码没什么区别,但是框架是要给别人使用的,所以我们要换位思考,怎么才能让别人用着舒服,怎么样才能让 ...

  2. 技术干货 | 基于MindSpore更好的理解Focal Loss

    [本期推荐专题]物联网从业人员必读:华为云专家为你详细解读LiteOS各模块开发及其实现原理. 摘要:Focal Loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失 ...

  3. 基于MIndSpore框架的道路场景语义分割方法研究

    基于MIndSpore框架的道路场景语义分割方法研究 概述 本文以华为最新国产深度学习框架Mindspore为基础,将城市道路下的实况图片解析作为任务背景,以复杂城市道路进行高精度的语义分割为任务目标 ...

  4. Vivado设计二:zynq的PS访问PL中的自带IP核(基于zybo)

    1.建立工程 首先和Vivado设计一中一样,先建立工程(这部分就忽略了) 2.create block design 同样,Add IP 同样,也添加配置文件,这些都和设计一是一样的,没什么区别. ...

  5. 如何基于MindSpore实现万亿级参数模型算法?

    摘要:近来,增大模型规模成为了提升模型性能的主要手段.特别是NLP领域的自监督预训练语言模型,规模越来越大,从GPT3的1750亿参数,到Switch Transformer的16000亿参数,又是一 ...

  6. 徒手打造基于Spark的数据工厂(Data Factory):从设计到实现

    在大数据处理和人工智能时代,数据工厂(Data Factory)无疑是一个非常重要的大数据处理平台.市面上也有成熟的相关产品,比如Azure Data Factory,不仅功能强大,而且依托微软的云计 ...

  7. 带你手写基于 Spring 的可插拔式 RPC 框架(三)通信协议模块

    在写代码之前我们先要想清楚几个问题. 我们的框架到底要实现什么功能? 我们要实现一个远程调用的 RPC 协议. 最终实现效果是什么样的? 我们能像调用本地服务一样调用远程的服务. 怎样实现上面的效果? ...

  8. 带你手写基于 Spring 的可插拔式 RPC 框架(二)整体结构

    前言 上一篇文章中我们已经知道了什么是 RPC 框架和为什么要做一个 RPC 框架了,这一章我们来从宏观上分析,怎么来实现一个 RPC 框架,这个框架都有那些模块以及这些模块的作用. 总体设计 在我们 ...

  9. MindInsight:一款基于MindSpore框架的训练可视化插件

    技术背景 在深度学习或者其他参数优化领域中,对于结果的可视化以及中间网络结构的可视化,也是一个非常重要的工作.一个好的可视化工具,可以更加直观的展示计算结果,可以帮助人们更快的发掘大量的数据中最有用的 ...

  10. django drf json格式化日期时间带T的问题 基于python的解决方法

    # models.py update_time = models.DateTimeField(verbose_name=u'更新时间', default=timezone.now) 问题:天 与 小时 ...

随机推荐

  1. SQL笔记(1)索引/触发器

    --创建聚集索引 create clustered index ix_tbl_test_DocDate on tbl_test(DocDate) GO --创建非聚集索引 create nonclus ...

  2. 给Mac下的iTerm2增加配色

    iterm2就不说了,Mac下非常好用的终端,这里就先谈谈如何给其增加配色,效果如下图 可以来这下载theme : http://iterm2colorschemes.com/ 1.先编辑你的prof ...

  3. POJ 3264 Balanced Lineup(RMQ)

    点我看题目 题意 :N头奶牛,Q次询问,然后给你每一头奶牛的身高,每一次询问都给你两个数,x y,代表着从x位置上的奶牛到y位置上的奶牛身高最高的和最矮的相差多少. 思路 : 刚好符合RMQ的那个求区 ...

  4. [Python 3.x 官方文档翻译]Whetting Your Appetite 欢迎您的使用

    If you do much work on computers, eventually you find that there’s some task you’d like to automate. ...

  5. Mac CLion下OpenGL环境配置

    1. 配置glew和glfw 终端下运行下面两句,安装完后在/usr/local/Cellar/下可以找到对应的目录. brew install glew brew install glfw3 效果如 ...

  6. BZOJ4012[HNOI2015]开店——树链剖分+可持久化线段树/动态点分治+vector

    题目描述 风见幽香有一个好朋友叫八云紫,她们经常一起看星星看月亮从诗词歌赋谈到 人生哲学.最近她们灵机一动,打算在幻想乡开一家小店来做生意赚点钱.这样的 想法当然非常好啦,但是她们也发现她们面临着一个 ...

  7. vue-router学习

    JS push goTo(){ , postId: ' }}) } router.js // 动态路径参数 以冒号开头 { path: '/user/:id', component: User } { ...

  8. 使用jquery.qrcode生成二维码及常见问题解决方案

    转载文章  使用jquery.qrcode生成二维码及常见问题解决方案 一.jquery.qrcode.js介 jquery.qrcode.js 是一个纯浏览器 生成 QRcode 的 jQuery ...

  9. PHP 数据加密

    <?php /** * * 加密 * */ function lock_url($txt, $key = "aiteng") { $chars = "ABCDEFG ...

  10. asp.net正则表达式删除指定的HTML标签的代码

    抓取某网页的数据后(比如描述),如果照原样显示的话,可能会因为它里面包含没有闭合的HTML标签而打乱了格式,也可能它里面用了比较让人 费解 的HTML标签,把预订的格式搅乱. 如果全盘删除里面的 HT ...