tfgan是什么?

tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow。在tensorflow1.x版本中,tfgan存在于tensorflow.contrib中,作为一个小模块供使用者调用。在更新到tensorflow2.0版本后,tfgan成为一个独立的库。可使用:

pip install tensorflow-gan

进行下载安装,并在python中使用以下语句导入这个包:

import tensorflow_gan as tfgan

可以使用tfgan对目前流行的GAN模型进行训练。并且,tfgan维护团队也会不断更新tfgan,使得其可以对论文中最新提出的GAN模型进行训练。

tfgan项目托管在github中,点击这里可以查看tfgan在github中托管的源代码及其官方教程与示例。

tfgan核心功能

tfgan的中函数的功能主要集中在基于tensorflow的LOSS函数、优化器、训练迭代的封装,以及对GAN模型的评估。其它的如数据集的输入、生成器和判别器模型的结构以及推断过程则需要通过调用tensorflow函数自己编写。即使这样,tfgan也极大的简化了GAN的训练与实现。接下来就针对tfgan中的几个核心功能对应的函数进行一个预览,以便对tfgan有一个初步印象。具体的用法将在后续文章中详细说明。注意:以下的代码中的函数均为调用,而不是函数原型

tfgan核心函数示例

·初始化模型

以Original-GAN为例进行说明,其它的例如C-GAN, info-GAN, Cycle-GAN等的情况与此处略有不同,在后续文章中会有具体说明。

gan_model = tfgan.gan_model(
generator_fn=generator,
discriminator_fn=discriminator,
real_data=images,
generator_inputs=tf.random.normal(
[batch_size, noise_dims]
)
)

在tfgan中,调用gan_model函数以创建Original-GAN网络模型,其主要参数包含4个,以下详细说明:

generator_fn:需要先自定义一个生成器函数,函数中定义判别器网络模型,并将函数名称作为参数传入。定义的生成器函数的接口应当符合如下格式:

def generator(noise, weight_decay=2.5e-5, is_training=True):
'''GAN Generator. Args:
noise: A 2D Tensor of shape [batch size, noise dim].
weight_decay: The value of the l2 weight decay.
is_training: If `True`, batch norm uses batch statistics. If `False`, batch
norm uses the exponential moving average collected from population
statistics. Returns:
A generated image.
'''

discriminator_fn:同样,需要首先自定义一个判别器函数,函数中定义判别器网络模型。并将函数名称作为参数传入。定义的判别器函数的接口应当符合如下格式:

def discriminator(img, unused_conditioning, weight_decay=2.5e-5):
'''GAN discriminator. Args:
img: Real or generated MNIST digits. Should be in the range [-1, 1].
unuseed_conditioning: The TFGAN API can help with conditional GANs, which
would require extra `condition` information to both the generator and the
discriminator. Since this example is not conditional, we do not use this
argument.
weight_decay: The L2 weight decay. Returns:
Logits for the probability that the image is real.
'''

real_data:真实图像。一个batch的Tensor格式。

generator_inputs:输入GAN的随机噪声,一般通过tf.random.normal()函数获得。

·指定损失函数

使用gan_loss函数指定训练GAN时所需要的损失函数,若调用形式如下所示,使用默认的损失函数:

gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)

gan_model:上一步初始化模型时的返回值。

add_summaries:是否添加损失的总结。tfgan在训练时,会自动生成tensorboard的日志信息(日志的位置将在最后一步“gan_train”函数中指定,tensorboard是一个适配于tensorflow的训练过程可视化工具),若为True,将添加loss的信息到日志中。

或者使用tfgan中内置的其它loss函数,下面的函数调用时就使用了带权重惩罚的W距离。或者可以自己自定义loss函数,此处不再详述。

gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.modified_generator_loss,
discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
mutual_information_penalty_weight=1.0,
add_summaries=True

·指定优化器

train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-3, 0.5),
discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-4, 0.5),
summarize_gradients=True
)

优化器一般需要传递4个参数:

gan_model:第一步调用tfgan.gan_model的返回值

gan_loss:第二步调用tfgan.gan_loss的返回值

generator_optimizer:指定生成器的优化器

discriminator_optimizer:指定判别器的优化器

summarize_gradients:添加梯度的总结

·开始训练

tfgan.gan_train(
train_ops,
hooks=[
tf.estimator.StopAtStepHook(num_steps=max_number_of_steps),
tf.estimator.LoggingTensorHook([status_message], every_n_iter=20)
],
logdir=train_log_dir,
get_hooks_fn=tfgan.get_joint_train_hooks(),
save_checkpoint_secs=60
)

参数解释:

train_ops:上一步函数的返回值

hooks:tf.train.SessionRunHook类型的回调函数,用列表形式封装。此处的函数将在每次训练迭代时调用

logdir:tfgan自动将建立好的网络模型以及训练过程的参数变化存储下来,此参数即为存储的位置

get_hooks_fn:G和D的训练方式,get_joint_train_hooks()意为进行一次G+D的参数更新,然后再单独进行一次D的参数更新。以此为一个迭代周期。

save_checkpoint_secs:训练过程中参数存储周期,此处设置为60s存储一次网络参数。

调用gan_train函数后,训练开始进行。

使用tfgan进行GAN网络训练步骤:

1.定义Generator与Discriminator网络模型;

2.加载训练集数据为batch形式;

3.调用gan_model函数以初始化网络模型;

4.调用gan_loss函数以指定损失函数;

5.调用gan_train_ops函数以指定优化器;

6.调用gan_train函数开始训练;

7.训练完毕后,tfgan自动将网络模型及参数以及训练过程的总结(summarise)存储在硬盘中。

使用tfgan进行推断的步骤:

1.从tfgan保存的日志中加载网络模型及参数;

2.加载测试数据;

3.将数据传入(feed)网络,得到结果。

tfgan折腾笔记(一):核心功能简要概述的更多相关文章

  1. tfgan折腾笔记(三):核心函数详述——gan_loss族

    gan_loss族的函数有: 1.gan_loss: 函数原型: def gan_loss( # GANModel. model, # Loss functions. generator_loss_f ...

  2. tfgan折腾笔记(二):核心函数详述——gan_model族

    定义model的函数有: 1.gan_model 函数原型: def gan_model( # Lambdas defining models. generator_fn, discriminator ...

  3. 简要分析武汉一起好P2P平台的核心功能

    写作背景 加入武汉一起好,正式工作40天了,对公司的核心业务有了更多的了解,想梳理下自己对于P2P平台的认识. 武汉一起好,自己运营的yiqihao.com,是用PHP实现的,同时也帮助若干P2P平台 ...

  4. APPCAN学习笔记004---AppCan与Hybrid,appcan概述

    APPCAN学习笔记004---AppCan与Hybrid,appcan概述 技术qq交流群:JavaDream:251572072 本节讲了appcan的开发流程,和开发工具 笔记不做具体介绍了,以 ...

  5. VSTO学习笔记(一)VSTO概述

    原文:VSTO学习笔记(一)VSTO概述 接触VSTO纯属偶然,前段时间因为忙于一个项目,在客户端Excel中制作一个插件,从远程服务器端(SharePoint Excel Services)上下载E ...

  6. Chrome扩展开发之四——核心功能的实现思路

    目录: 0.Chrome扩展开发(Gmail附件管理助手)系列之〇——概述 1.Chrome扩展开发之一——Chrome扩展的文件结构 2.Chrome扩展开发之二——Chrome扩展中脚本的运行机制 ...

  7. [编程笔记]第一章 C语言概述

    //C语言学习笔记 第一讲 C语言概述 第二讲 基本编程知识 第三讲 运算符和表达式 第四讲 流程控制 第五讲 函数 第六讲 数组 第七讲 指针 第八讲 变量的作用域和存储方式 第九讲 拓展类型 第十 ...

  8. 笔记-scrapy-辅助功能

    笔记-scrapy-辅助功能 1.      scrapy爬虫管理 爬虫主体写完了,要部署运行,还有一些工程性问题: 限频 爬取深度限制 按条件停止,例如爬取次数,错误次数: 资源使用限制,例如内存限 ...

  9. 自己实现spring核心功能 三

    前言 前两篇已经基本实现了spring的核心功能,下面讲到的参数绑定是属于springMvc的范畴了.本篇主要将请求到servlet后怎么去做映射和处理.首先来看一看dispatherServlet的 ...

随机推荐

  1. [原]调试实战——使用windbg调试崩溃在ComFriendlyWaitMtaThreadProc

    原调试debugwindbgcrash崩溃COM 前言 这是几年前在项目中遇到的一个崩溃问题,崩溃在了ComFriendlyWaitMtaThreadProc()里,没有源码.耗费了我很大精力,最终通 ...

  2. Java任务调度框架之分布式调度框架XXL-Job介绍

    ​ Java任务调度框架之分布式调度框架XXL-Job介绍及快速入门 调度器使用场景: Java开发中经常会使用到定时任务:比如每月1号凌晨生成上个月的账单.比如每天凌晨1点对上一天的数据进行对账操作 ...

  3. Kubernetes系列三:二进制安装Kubernetes环境

    安装环境: # 三个节点信息 192.168.31.11 主机名:env11 角色:部署Master节点/Node节点/ETCD节点 192.168.31.12 主机名:env12 角色:部署Node ...

  4. Java——Collection集合、迭代器、泛型

    集合 ——集合就是java提供的一种容器,可以用来存储多个数据. 集合和数组的区别 数组的长度是固定的.集合的长度是可变的. 数组中存储的是同一类型的元素,可以存储基本数据类型值. 集合存储的都是对象 ...

  5. 【ccf- csp201509-4】高速公路

    #include<iostream> using namespace std; void DFS(int**mat, int *mark,int *sp, int n, int p) { ...

  6. 天大IPv6使用指南(老校区)

    天津大学是CERNET地区网络中心和地区主结点之一,提供良好的IPv6服务,在老校区最大接入宽带达到100Mbps,下载资源非常方便. 但是,在天大使用IPv6时,同学们是不是经常出现时断时续的现象呢 ...

  7. fatal: remote origin already exists.

    解决方法: 先删除, 再添加 1. git remote rm origin 2. git remote add origin https://github.com/zjulanjian/eshop. ...

  8. Python基础——类new方法与单例模式

    介绍: new方法是类中魔术方法之一,他的作用是给类实例化开辟一个内存地址,并返回一个实例化,再由__init__对这个实例进行初始化,故它的执行肯定就是在初始化方法__init__之前了.new方法 ...

  9. nginx应用geoip模块,实现不同地区访问不同页面的需求(实践版)

    https://www.52os.net/articles/configure-nginx-using-geoip-allow-whitelist.html       搞了几天没有搞定,这篇文章一下 ...

  10. python-django电商项目-需求分析架构设计数据库设计_20191115

    python-django电商项目需求分析 1.用户模块 1)注册页 注册时校验用户名是否已被注册. 完成用户信息的注册. 给用户的注册邮箱发送邮件,用户点击邮件中的激活链接完成用户账户的激活. 2) ...