tfgan折腾笔记(三):核心函数详述——gan_loss族
gan_loss族的函数有:
1.gan_loss:
函数原型:
def gan_loss(
# GANModel.
model,
# Loss functions.
generator_loss_fn=tuple_losses.wasserstein_generator_loss,
discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
tensor_pool_fn=None,
# Options.
reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数,可用函数见其他说明。
discriminator_loss_fn:判别器使用的损失函数,可用函数见其他说明。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
mutual_information_penalty_weight:交叉信息惩罚权值。如果不是None,必须提供一个非负数或Tensor。
aux_cond_generator_weight:如果不是None,则添加生成器分类损失。
aux_cond_discriminator_weight:如果不是None,则添加判别器分类损失。
tensor_pool_fn:tensor pool函数。此函数传入tuple类型:(generated_data, generator_inputs),函数将它们放在内部pool中,并且返回上一个pool中的值。如,可以传入tfgan.features.tensor_pool。
reduction:传入tf.losses.Reduction类的函数。
add_summaries:是否添加总结到Tensorboard日志。
返回值:
返回“GANLoss 命名元组”。
函数内部实现:
# Create standard losses with optional kwargs, if the loss functions accept
# them.
def _optional_kwargs(fn, possible_kwargs):
"""Returns a kwargs dictionary of valid kwargs for a given function."""
if inspect.getargspec(fn).keywords is not None:
return possible_kwargs
actual_args = inspect.getargspec(fn).args
actual_kwargs = {}
for k, v in possible_kwargs.items():
if k in actual_args:
actual_kwargs[k] = v
return actual_kwargs
possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
gen_loss = generator_loss_fn(
model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
dis_loss = discriminator_loss_fn(
pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))
其他说明:
- tfgan内置损失函数:
__all__ = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'least_squares_discriminator_loss',
'least_squares_generator_loss',
'modified_discriminator_loss',
'modified_generator_loss',
'minimax_discriminator_loss',
'minimax_generator_loss',
'wasserstein_discriminator_loss',
'wasserstein_hinge_discriminator_loss',
'wasserstein_hinge_generator_loss',
'wasserstein_generator_loss',
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
'cycle_consistency_loss',
'stargan_generator_loss_wrapper',
'stargan_discriminator_loss_wrapper',
'stargan_gradient_penalty_wrapper'
]
2.cyclegan_loss:
函数原型:
def cyclegan_loss(
model,
# Loss functions.
generator_loss_fn=tuple_losses.least_squares_generator_loss,
discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss,
# Auxiliary losses.
cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss,
cycle_consistency_loss_weight=10.0,
# Options
**kwargs)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
cycle_consistency_loss_fn:循环一致性损失函数。
cycle_consistency_loss_weight:循环一致性损失的权值。
**kwargs:这里的参数将直接传递给cyclegan_loss函数内部调用的gan_loss函数。
返回值:
返回“CycleGANLoss 命名元组”。
函数内部实现:
循环一致性损失函数与权值的定义:
# Defines cycle consistency loss.
cycle_consistency_loss = cycle_consistency_loss_fn(
model, add_summaries=kwargs.get('add_summaries', True))
cycle_consistency_loss_weight = _validate_aux_loss_weight(
cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss
**kwargs的实现:
# Defines losses for each partial model.
def _partial_loss(partial_model):
partial_loss = gan_loss(
partial_model,
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
return partial_loss._replace(generator_loss=partial_loss.generator_loss +
aux_loss) with tf.compat.v1.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
with tf.compat.v1.name_scope('cyclegan_loss_y2x'):
loss_y2x = _partial_loss(model.model_y2x)
其他说明:
- cycle-gan实际上是由两个普通gan组合而成的,其loss是普通gan的loss加上循环一致性损失。
- 循环一致性损失权值越大,则X->Y->X循环的相似性方面学习的越快。
3.stargan_loss:
函数原型:
def stargan_loss(
model,
generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
losses_wargs.wasserstein_generator_loss),
discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
losses_wargs.wasserstein_discriminator_loss),
gradient_penalty_weight=10.0,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
reconstruction_loss_weight=10.0,
classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
classification_loss_weight=1.0,
classification_one_hot=True,
add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
reconstruction_loss_fn:重建损失函数。
reconstruction_loss_weight:重建损失的权重。
classification_loss_fn:分类损失函数。
classification_loss_weight:分类损失的权重。
classification_one_hot:分类的one_hot_label。
add_summaries:是否向tensorboard添加总结。
返回值:
返回“StarGANLoss 命名元组”。
函数内部实现:
梯度惩罚函数与权值的定义:
# Gradient Penalty.
if _use_aux_loss(gradient_penalty_weight):
gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
losses_wargs.wasserstein_gradient_penalty)
discriminator_loss += gradient_penalty_fn(
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries) * gradient_penalty_weight
重建损失函数与权值的定义:
# Reconstruction Loss.
reconstruction_loss = reconstruction_loss_fn(model.input_data,
model.reconstructed_data)
generator_loss += reconstruction_loss * reconstruction_loss_weight
if add_summaries:
tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)
分类损失函数与权值定义:
# Classification Loss.
generator_loss += _classification_loss_helper(
true_labels=model.generated_data_domain_target,
predict_logits=model.discriminator_generated_data_domain_predication,
scope_name='generator_classification_loss') * classification_loss_weight
discriminator_loss += _classification_loss_helper(
true_labels=model.input_data_domain_label,
predict_logits=model.discriminator_input_data_domain_predication,
scope_name='discriminator_classification_loss'
) * classification_loss_weight
其他说明:
无
tfgan折腾笔记(三):核心函数详述——gan_loss族的更多相关文章
- tfgan折腾笔记(二):核心函数详述——gan_model族
定义model的函数有: 1.gan_model 函数原型: def gan_model( # Lambdas defining models. generator_fn, discriminator ...
- Typescript 学习笔记三:函数
中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...
- ES6学习笔记<三> 生成器函数与yield
为什么要把这个内容拿出来单独做一篇学习笔记? 生成器函数比较重要,相对不是很容易理解,单独做一篇笔记详细聊一聊生成器函数. 标题为什么是生成器函数与yield? 生成器函数类似其他服务器端语音中的接口 ...
- tfgan折腾笔记(一):核心功能简要概述
tfgan是什么? tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow.在tensorflow1.x版 ...
- python学习笔记三:函数及变量作用域
一.定义 def functionName([arg1,arg2,...]): code 二.示例 #!/usr/bin/python #coding:utf8 #coding=utf8 #encod ...
- python 学习笔记三 (函数)
1.把函数视为对象 def factorial(n): '''return n!''' return 1 if n < 2 else n*factorial(n-1) print(factori ...
- MySql学习笔记(三) —— 聚集函数的使用
1.AVG() 求平均数 select avg(prod_price) as avg_price from products; --返回商品价格的平均值 ; --返回生产商id为1003的商品价格平均 ...
- wr720n v4 折腾笔记(三):网络配置与扩充USB
0x01 前言 网络配置比较简单,但是USB拓展就麻烦许多了,这里由于overlay的内存分配问题导致软件安装失败,这里找到了一种方法就是直接从uboot刷入南浦月大神的wr720n的openwrt固 ...
- Python 学习笔记三
笔记三:函数 笔记二已取消置顶链接地址:http://www.cnblogs.com/dzzy/p/5289186.html 函数的作用: 給代码段命名,就像变量給数字命名一样 可以接收参数,像arg ...
随机推荐
- “全隐藏式3D摄像头”亮相,FindX如何将设计与体验融为一体
北京时间6月20日,OPPO在卢浮宫发布暌违四年之久的Find旗舰系列新手机--Find X.在Find X背后,我认为其设计值得深思.尤其是Find X为突破传统设计束缚,首创双轨潜望结构有着重要启 ...
- Spring @Column的注解详解
就像@Table注解用来标识实体类与数据表的对应关系类似,@Column注解来标识实体类中属性与数据表中字段的对应关系. 该注解的定义如下: @Target({METHOD, FIELD}) @Ret ...
- 66)PHP,会话技术
其实刷新(F5)就是一个新的请求. 会话技术的实现:1.Cookie 2.Session(其实cookie能做的,session也能做.session能做的,cookie也能做.就是cookie ...
- [Linux] Ubuntu 配置nfs
安装NFS Server: 1. 执行命令 "$ sudo apt-get install nfs-kernel-server",安装nfs server 端 2. 创建需要用来分 ...
- java5的静态导入import static
在Java 5中,import语句得到了增强,以便提供甚至更加强大的减少击键次数功能,虽然一些人争议说这是以可读性为代价的.这种新的特性成为静态导入. 1.静态导入的与普通import的区别: imp ...
- python--mysql的CURD操作
from pymysql import * def main(): # 创建Connextion连接 conn = connect(host='localhost', port=3306, user= ...
- Java实现Luhm算法--银行卡号合法性校验
银行卡是由"发卡行标识代码 + 自定义 + 校验码 "等部分组成的. 银联标准卡与以往发行的银行卡最直接的区别就是其卡号前6位数字的不同. 银行卡卡号的前6位是用来表示发卡银行 ...
- MyBatisUtil
package com.it.util; import java.io.IOException; import java.io.Reader; import org.apache.ibatis.io. ...
- python 简单主机批量管理工具
需求: 主机分组 主机信息配置文件用configparser解析 可批量执行命令.发送文件,结果实时返回,执行格式如下 batch_run -h h1,h2,h3 -g web_cluster ...
- 将Hexo网站托管到Coding.net
只需要注册coding.net,然后建立一个名为用户名+coding.me的仓库即可,需要注意的是 coding.net的pages仓库只能有一个master分支 开始使用 Coding Pages官 ...