gan_loss族的函数有:

1.gan_loss:

函数原型:

def gan_loss(
# GANModel.
model,
# Loss functions.
generator_loss_fn=tuple_losses.wasserstein_generator_loss,
discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
tensor_pool_fn=None,
# Options.
reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=True)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数,可用函数见其他说明。

discriminator_loss_fn:判别器使用的损失函数,可用函数见其他说明。

gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。

gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。

gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。

gradient_penalty_one_sided:(暂不明白什么意思)。

mutual_information_penalty_weight:交叉信息惩罚权值。如果不是None,必须提供一个非负数或Tensor。

aux_cond_generator_weight:如果不是None,则添加生成器分类损失。

aux_cond_discriminator_weight:如果不是None,则添加判别器分类损失。

tensor_pool_fn:tensor pool函数。此函数传入tuple类型:(generated_data, generator_inputs),函数将它们放在内部pool中,并且返回上一个pool中的值。如,可以传入tfgan.features.tensor_pool。

reduction:传入tf.losses.Reduction类的函数。

add_summaries:是否添加总结到Tensorboard日志。

返回值:

返回“GANLoss 命名元组”。

函数内部实现:

# Create standard losses with optional kwargs, if the loss functions accept
# them.
def _optional_kwargs(fn, possible_kwargs):
"""Returns a kwargs dictionary of valid kwargs for a given function."""
if inspect.getargspec(fn).keywords is not None:
return possible_kwargs
actual_args = inspect.getargspec(fn).args
actual_kwargs = {}
for k, v in possible_kwargs.items():
if k in actual_args:
actual_kwargs[k] = v
return actual_kwargs
possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
gen_loss = generator_loss_fn(
model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
dis_loss = discriminator_loss_fn(
pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))

其他说明:

  • tfgan内置损失函数:
__all__ = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'least_squares_discriminator_loss',
'least_squares_generator_loss',
'modified_discriminator_loss',
'modified_generator_loss',
'minimax_discriminator_loss',
'minimax_generator_loss',
'wasserstein_discriminator_loss',
'wasserstein_hinge_discriminator_loss',
'wasserstein_hinge_generator_loss',
'wasserstein_generator_loss',
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
'cycle_consistency_loss',
'stargan_generator_loss_wrapper',
'stargan_discriminator_loss_wrapper',
'stargan_gradient_penalty_wrapper'
]

2.cyclegan_loss:

函数原型:

def cyclegan_loss(
model,
# Loss functions.
generator_loss_fn=tuple_losses.least_squares_generator_loss,
discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss,
# Auxiliary losses.
cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss,
cycle_consistency_loss_weight=10.0,
# Options
**kwargs)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数。

discriminator_loss_fn:判别器使用的损失函数。

cycle_consistency_loss_fn:循环一致性损失函数。

cycle_consistency_loss_weight:循环一致性损失的权值。

**kwargs:这里的参数将直接传递给cyclegan_loss函数内部调用的gan_loss函数。

返回值:

返回“CycleGANLoss 命名元组”。

函数内部实现:

循环一致性损失函数与权值的定义:

  # Defines cycle consistency loss.
cycle_consistency_loss = cycle_consistency_loss_fn(
model, add_summaries=kwargs.get('add_summaries', True))
cycle_consistency_loss_weight = _validate_aux_loss_weight(
cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

**kwargs的实现:

  # Defines losses for each partial model.
def _partial_loss(partial_model):
partial_loss = gan_loss(
partial_model,
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
return partial_loss._replace(generator_loss=partial_loss.generator_loss +
aux_loss) with tf.compat.v1.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
with tf.compat.v1.name_scope('cyclegan_loss_y2x'):
loss_y2x = _partial_loss(model.model_y2x)

其他说明:

  • cycle-gan实际上是由两个普通gan组合而成的,其loss是普通gan的loss加上循环一致性损失。
  • 循环一致性损失权值越大,则X->Y->X循环的相似性方面学习的越快。

3.stargan_loss:

函数原型:

def stargan_loss(
model,
generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
losses_wargs.wasserstein_generator_loss),
discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
losses_wargs.wasserstein_discriminator_loss),
gradient_penalty_weight=10.0,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
reconstruction_loss_weight=10.0,
classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
classification_loss_weight=1.0,
classification_one_hot=True,
add_summaries=True)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数。

discriminator_loss_fn:判别器使用的损失函数。

gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。

gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。

gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。

gradient_penalty_one_sided:(暂不明白什么意思)。

reconstruction_loss_fn:重建损失函数。

reconstruction_loss_weight:重建损失的权重。

classification_loss_fn:分类损失函数。

classification_loss_weight:分类损失的权重。

classification_one_hot:分类的one_hot_label。

add_summaries:是否向tensorboard添加总结。

返回值:

返回“StarGANLoss 命名元组”。

函数内部实现:

梯度惩罚函数与权值的定义:

  # Gradient Penalty.
if _use_aux_loss(gradient_penalty_weight):
gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
losses_wargs.wasserstein_gradient_penalty)
discriminator_loss += gradient_penalty_fn(
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries) * gradient_penalty_weight

重建损失函数与权值的定义:

  # Reconstruction Loss.
reconstruction_loss = reconstruction_loss_fn(model.input_data,
model.reconstructed_data)
generator_loss += reconstruction_loss * reconstruction_loss_weight
if add_summaries:
tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)

分类损失函数与权值定义:

  # Classification Loss.
generator_loss += _classification_loss_helper(
true_labels=model.generated_data_domain_target,
predict_logits=model.discriminator_generated_data_domain_predication,
scope_name='generator_classification_loss') * classification_loss_weight
discriminator_loss += _classification_loss_helper(
true_labels=model.input_data_domain_label,
predict_logits=model.discriminator_input_data_domain_predication,
scope_name='discriminator_classification_loss'
) * classification_loss_weight

其他说明:

tfgan折腾笔记(三):核心函数详述——gan_loss族的更多相关文章

  1. tfgan折腾笔记(二):核心函数详述——gan_model族

    定义model的函数有: 1.gan_model 函数原型: def gan_model( # Lambdas defining models. generator_fn, discriminator ...

  2. Typescript 学习笔记三:函数

    中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...

  3. ES6学习笔记<三> 生成器函数与yield

    为什么要把这个内容拿出来单独做一篇学习笔记? 生成器函数比较重要,相对不是很容易理解,单独做一篇笔记详细聊一聊生成器函数. 标题为什么是生成器函数与yield? 生成器函数类似其他服务器端语音中的接口 ...

  4. tfgan折腾笔记(一):核心功能简要概述

    tfgan是什么? tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow.在tensorflow1.x版 ...

  5. python学习笔记三:函数及变量作用域

    一.定义 def functionName([arg1,arg2,...]): code 二.示例 #!/usr/bin/python #coding:utf8 #coding=utf8 #encod ...

  6. python 学习笔记三 (函数)

    1.把函数视为对象 def factorial(n): '''return n!''' return 1 if n < 2 else n*factorial(n-1) print(factori ...

  7. MySql学习笔记(三) —— 聚集函数的使用

    1.AVG() 求平均数 select avg(prod_price) as avg_price from products; --返回商品价格的平均值 ; --返回生产商id为1003的商品价格平均 ...

  8. wr720n v4 折腾笔记(三):网络配置与扩充USB

    0x01 前言 网络配置比较简单,但是USB拓展就麻烦许多了,这里由于overlay的内存分配问题导致软件安装失败,这里找到了一种方法就是直接从uboot刷入南浦月大神的wr720n的openwrt固 ...

  9. Python 学习笔记三

    笔记三:函数 笔记二已取消置顶链接地址:http://www.cnblogs.com/dzzy/p/5289186.html 函数的作用: 給代码段命名,就像变量給数字命名一样 可以接收参数,像arg ...

随机推荐

  1. Eclipse 热部署方式

    1.tomcat 热部署 1.1方法一:更改 server.xml,更改为 <Context docBase="dreamlive" path="/ROOT&quo ...

  2. 使用命令安装laravel 项目

    cp .env.example .env   拷贝.env 文件 php artisan key:generate  生成秘钥 php artisan migrate   生成数据表 composer ...

  3. chkconfig原理

    ll  /etc/rc.d    里面有运行级别对应的脚本 chkconfig --list  sshd ll /etc/rc.d/rc3.d/   | grep sshd     (查看3启动 里面 ...

  4. MOOC(11)- 获取cookie后存到json中

    获取cookie后转成字典格式 把字典格式cookie存到json数据中 需要在表格中写好关键字,判断是否写cookie.是否读cookie 在需要用cookie的时候根据键去json中取值 # 1. ...

  5. scala编程(四)——类和对象

     类,字段和方法 在scala里定义一个典型的类,代码如下: class ChecksumAccumulator { private var sum = 0 def add(b: Byte): Uni ...

  6. leetcode第38题:报数

    这是一道简单题,但是我做了很久,主要难度在读题和理解题上. 思路:给定一个数字,返回这个数字报数数列.我们可以通过从1开始,不断扩展到n的数列.数列的值为前一个数列的count+num,所以我们不断叠 ...

  7. 【数据结构】B树与B+树

    定义 B 树可以看作是对2-3查找树的一种扩展,即他允许每个节点有M-1个子节点. 根节点至少有两个子节点 每个节点有M-1个key,并且以升序排列 位于M-1和M key的子节点的值位于M-1 和M ...

  8. 树的DFS

    Depth-first search (DFS) is an algorithm for traversing or searching tree or graph data structures. ...

  9. python django ORM

    1.在models.py中创创建类 # -*- coding: utf-8 -*- from __future__ import unicode_literals from django.db imp ...

  10. 编程原理—如何用javascript代码解决一些问题

    关于编程,我最喜欢的就是解决问题.我不相信有谁天生具有解决问题的能力.这是一种通过反复锻炼而建立并维持的能力.像任何练习一样,有一套指导方针可以帮助你更有效地提高解决问题的能力.我将介绍5个最重要的软 ...