Gumble_Softmax 可以解决的问题

场景:对于一个分类任务,通常会使用softmax函数来将模型的输出值转换为概率的形式,并通过argmax函数取最大的概率值标签作为模型的预测标签。在分类任务中,argmax可以不参与反向传播过程(即直接通过softmax值和true_label计算loss),而在其他任务中,例如利用GAN生成文本则需要利用argmax的结果去生成loss,此时argmax是位于反向传播中的。

  • 解决argmax输出值不能反应模型输出的概率分布的问题

    出现问题的原因: 当训练过程中,softmax的输出值为 p = [0.7, 0.2, 0.1], 其表示 p([1, 0, 0]) = 0.7; p([0, 1, 0]) = 0.2; p([0, 0, 1]) = 0.1; 若采用argmax(p), 其结果为 [1, 0, 0] , 且若 p 的值一直维持这样一种概率分布,那么argmax的值将一直保持 [1, 0, 0], 且概率为 1, 这不符合输出的概率分布[0.7, 0.2, 0.1],即100次取值中, [1, 0, 0] 出现70次, [0, 1, 0] 出现20次, [0, 0, 1] 出现10次。

    解决: 在argmax函数中引入随机性,即在argmax函数在引入一个符合gumble分布的噪声值G, G=-log(-log(ξ)), ξ是从[0,1]均匀分布中随机取样的,最后的 argmax函数变为 y = argmax(log§ + G), 最后加上onehot的操作获得输出的采样值 y = onehot(argmax[log§ + G]); § 是模型的逻辑输出值(非softmax后的概率分布), 这样最终获得的值是符合模型逻辑值§经过softmax后的概率分布的,也就是说,100次采样中, [1, 0, 0] 出现70次, [0, 1, 0] 出现20次, [0, 0, 1] 出现10次。

    证明最后的结果是符合模型输出的概率分布的

    理论证明:

    https://blog.csdn.net/u011345885/article/details/122610352

    实例证明:

    参考: https://zhuanlan.zhihu.com/p/495806343

    logits = torch.Tensor([0.7, 0.2, 0.1])
    q1 = {0:0,1:0,2:0}
    for i in range(10000):
    t = torch.nn.functional.gumbel_softmax(logits, tau=0.2).argmax().item()
    q1[t] += 1
    q2 = torch.softmax(logits, -1)
    print("q1", q1)
    print("q2", q2) >> q1 {0: 4670, 1: 2757, 2: 2573}
    >> q2 tensor([0.464, 0.281, 0.255])
  • 解决argmax不可导,无法进行反向传播的问题

    出现的原因: argmax(x,y)不可导的根本原因是其向量空间不是光滑的,有尖锐的点和面;而是某些任务中,argmax会被插入到反向传播的计算图中。

    解决: 在解决上个问题的基础上,我们可以获得one_hot形式的符合模型输出概率分布的采样值 y = onehot(argmax[log§ + G]), 但是其中的 one(argmax()) 还是不可导的操作,所以可以使用softmax 来近似 one(argmax()), 并增加一个温度函数 tau 来控制最后的结果和 真实onehot的近似程度。为什么softmax操作是可导的,其实softmax 就是 one(argmax()) 的光滑化。



    当 tau 足够小时,采样出来的向量十分接近 onehot 形式(类onehot但是不是真实的onehot), 而 tau 比较大时,采样的值接近于均匀分布。一般在训练初期,设置较大的tau,保证模型的充足的探索性;而在训练后期,一般设置较小的tau,生成比较稳定的 类似onehot向量。

    下图是原论文[https://arxiv.org/pdf/1611.01144.pdf] 中对于 tau 参数大小的实验结果。

    可以看出随着温度参数的增大采样值的分布逐渐由类onehot分布转换为均匀分布。

    在 pytorch的 gumbel_softmax 的源码中可以对于其实现原理有一个清晰的认识。

    其中有一个 hard 参数,当hard = False,函数直接返回采样值,当 hard = True, 函数是对采样值进行了一个 max 的操作,最后再和采样值组合在一起。这样的操作使得,在 forward 阶段, 传播的是 onehot值 y_hard; 而在 backpropagation 阶段,传播的是 y_soft 的梯度信息,因为 detach() 函数截断了其余的梯度传播。

    def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    #########
    gumbels = (
    -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    ) # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim) if hard:
    # Straight through.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft
    else:
    # Reparametrization trick.
    ret = y_soft
    return ret

Gumbel_Softmax 概要的更多相关文章

  1. .Net 分布式云平台基础服务建设说明概要

    1)  背景 建设云平台的基础框架,用于支持各类云服务的业务的构建及发展. 2)  基础服务 根据目前对业务的理解和发展方向,总结抽象出以下几个基础服务,如图所示 3)  概要说明 基础服务的发展会根 ...

  2. 前端MVC学习总结(一)——MVC概要与angular概要、模板与数据绑定

    一.前端MVC概要 1.1.库与框架的区别 框架是一个软件的半成品,在全局范围内给了大的约束.库是工具,在单点上给我们提供功能.框架是依赖库的.AngularJS是框架而jQuery则是库. 1.2. ...

  3. HTML5 学习总结(一)——HTML5概要与新增标签

    一.HTML5概要 1.1.为什么需要HTML5 HTML4陈旧不能满足日益发展的互联网需要,特别是移动互联网.为了增强浏览器功能Flash被广泛使用,但安全与稳定堪忧,不适合在移动端使用(耗电.触摸 ...

  4. CSS3与页面布局学习总结(一)——概要、选择器、特殊性与刻度单位

    web前端开发者最最注的内容是三个:HTML.CSS与JavaScript,他们分别在不同方面发挥自己的作用,HTML实现页面结构,CSS完成页面的表现与风格,JavaScript实现一些客户端的功能 ...

  5. 更改WAS Profiles的概要文件的server1的SDK版本

    WebSphere只能使用IBM JDK 哦,不能使用sun的JDK哦.不过如果只是改jdk的版本的话可以参考如下步骤:(以集群为例,假设具有管理节点Dmgr01,应用概要AppSrv01) 1. 确 ...

  6. HTML5 学习笔记(一)——HTML5概要与新增标签

    目录 一.HTML5概要 1.1.为什么需要HTML5 1.2.什么是HTML5 1.3.HTML5现状及浏览器支持 1.4.HTML5特性 1.5.HTML5优点与缺点 1.5.1.优点 1.5.2 ...

  7. [ASP.NET MVC 大牛之路]03 - C#高级知识点概要(2) - 线程和并发

    本人博客已转移至:http://www.exblr.com/liam  我也想过跳过C#高级知识点概要直接讲MVC,但经过前思后想,还是觉得有必要讲的.我希望通过自己的经验给大家一些指引,带着大家一起 ...

  8. Ajax概要:

    Ajax概要: Ajax不是个全新的技术,它是多种技术合并在一起产生的,包括XHTML,CSS,JavaScript,XmlHttpRequest,XML,JSON,DOM等 优点:(这也解释了为何我 ...

  9. Css概要与选择器,刻度单位

    目录 一.CSS3概要 1.1.特点 1.2.效果演示 1.3.帮助文档与学习 二.选择器 1.1.基础的选择器 1.2.组合选择器 1.3.属性选择器 1.4.伪类 1.5.伪元素 三.特殊性(优先 ...

随机推荐

  1. SpringCloud基础概念学习笔记(Eureka、Ribbon、Feign、Zuul)

    SpringCloud基础概念学习笔记(Eureka.Ribbon.Feign.Zuul) SpringCloud入门 参考: https://springcloud.cc/spring-cloud- ...

  2. 安装Squid到CentOS(YUM)

    运行环境 系统版本:CentOS Linux release 7.3.1611 (Core) 软件版本:无 硬件要求:无 安装过程 1.关闭防火墙和SeLinux [root@localhost ~] ...

  3. 第06组 Beta冲刺 (5/5)

    目录 1.1 基本情况 1.2 冲刺概况汇报 1.郝雷明 2. 方梓涵 3.曾丽莉 4.黄少丹 5. 董翔云 6.鲍凌函 7.杜筱 8.詹鑫冰 9.曹兰英 10.吴沅静 1.3 冲刺成果展示 1.1 ...

  4. 看Spring源码不得不会的@Enable模块驱动实现原理讲解

    这篇文章我想和你聊一聊 spring的@Enable模块驱动的实现原理. 在我们平时使用spring的过程中,如果想要加个定时任务的功能,那么就需要加注解@EnableScheduling,如果想使用 ...

  5. C# 类继承中的私有字段都去了哪里?

    最近在看 C++ 类继承中的字段内存布局,我就很好奇 C# 中的继承链那些 private 字段都哪里去了? 在内存中是如何布局的,毕竟在子类中是无法访问的. 一:举例说明 为了方便讲述,先上一个例子 ...

  6. CSS(九):background(背景属性)

    通过CSS背景属性,可以给页面元素添加背景样式. 背景属性可以设置背景颜色.背景图片.背景平铺.背景图像固定等. background-color(背景颜色) background-color属性定义 ...

  7. Jmeter(五十三) - 从入门到精通高级篇 - 懒人教你在Linux系统中安装Jmeter(详解教程)

    1.简介 我们绝大多数使用的都是Windows操作系统,因此在Windows系统上安装JMeter已经成了家常便饭,而且安装也相对简单,但是服务器为了安全.灵活小巧,特别是前几年的勒索病毒,现在绝大多 ...

  8. 皓远的第二次博客作业(最新pta集,链表练习及期中考试总结)

    前言: 知识点运用:正则表达式,有关图形设计计算的表达式和算法,链表的相关知识,Java类的基础运用,继承.容器与多态. 题量:相较于上次作业,这几周在java方面的练习花了更多的精力和时间,所要完成 ...

  9. NHibernte 4.0.3版本中,使用Queryover().Where().OrderBy().Skip().Take()方法分页获取数据失败

    问题代码如下: var result=repository.QueryOver<modal>() .Where(p=>p.Code==Code) .OrderBy(p=>p.I ...

  10. grafana整合zabbix

    1. 安装grafana wget https://dl.grafana.com/oss/release/grafana-7.5.7-1.x86_64.rpm rpm -i --nodeps graf ...