前言

几周前,我在AnitaB.org组织的Hopperx1 London上发表了演讲作为伦敦科技周的一部分。

在演讲结束后,我收到了热烈的反馈,所以我决定写一个稍微长一点的演讲版本来介绍FlashTorch

该软件包可通过pip安装。Github仓库链接(https://github.com/MisaOgura/flashtorch)。你也可以在Google Colab上运行(https://colab.research.google.com/github/MisaOgura/flashtorch/blob/master/examples/visualise_saliency_with_backprop_colab.ipynb),从而无需安装任何东西!

特征可视化简介

特征可视化是一个活跃的研究领域,旨在探索我们观看"神经网络看到的图像"的方法来了解神经网络如何感知图像。它的出现和发展是为了响应人们越来越希望神经网络能够被人类解读。

最早的工作包括分析输入图像中神经网络正在关注的内容。例如,图像特定类显著图(image-specific class saliency maps)(https://arxiv.org/abs/1312.6034)通过反向传播计算相对于输入图像的类的梯度,可视化输入图像中对对应类别贡献最大的区域(在后面有更多关于显著图的内容) 。

特征可视化中的另一个技术是激活最大化(Activation maximization)。这允许我们迭代地更新输入图像(最初由一些随机噪声产生)以生成最大程度地激活目标神经元的图像。它提供了一些关于个体神经元如何响应输入图像的直觉。这就是谷歌推广的所谓的Deep Dream背后的技术。

这是一个巨大的进步,但是有一些缺点,因为它没有提供观察整个网络如何运作的能力,因为神经元不是孤立运作。这导致了对神经元之间相互作用的可视化研究。Olah等人通过在两个神经元之间进行添加或插值来演示激活空间的算术性质。

然后Olah通过分析在给定特定输入时每个神经元在隐藏层内的贡献,进一步定义了更有意义的可视化单元。观察一组同时被强烈激活的神经元,发现似乎有一组神经元负责捕捉诸如耷拉的耳朵、毛茸茸的腿和草之类的概念。

该领域的最新发展之一是Activation Atlas(Carter等,2019)(https://distill.pub/2019/activation-atlas/)。在这项研究中,作者指出了可视化过滤器激活的一个主要缺点,因为它只给出了一个有限的网络如何响应单个输入的视图。为了看到一个大的网络如何感知大量的对象和这些对象之间的联系,他们设计了一种方法,通过显示神经元的常见组合,来创建一个通过神经网络可以看到的全局图。

FlashTorch实现的动机

当我发现特征可视化时,我立即被吸引这项技术使神经网络更易于解释的潜力。然后我很快意识到没有工具可以轻松地将这些技术应用到我在PyTorch中构建的神经网络。

所以我决定建立一个FlashTorch,现在它可以通过pip进行安装!我实现的第一个特征可视化技术是显著图

我们将在下面详细介绍哪些显著图,以及如何使用FlashTorch它们与神经网络一起实现它们。

显著图

人类视觉感知中的显着性是一种主观能力,使得视野中的某些事物脱颖而出并引起我们的注意。计算机视觉中的显著图可以指示图像中最显着的区域。

从卷积神经网络(CNN)创建显著图的方法最初于2013年在Deep Inside Convolutional Networks:Visualizing Image Classification Models and Saliency Maps(https://arxiv.org/abs/1312.6034)中引入。作者报告说,通过计算目标类相对于输入图像的梯度,我们可以可视化输入图像中的区域,这些区域对该类的预测值有影响。

使用FlashTorch的显着性

不用多说,让我们自己使用FlashTorch可视化显著图!

FlashTorch附带了一些utils功能,使数据处理更容易。我们将以灰色的猫头鹰这个图像为例。

然后我们将对图像应用一些变换,使其形状,类型和值适合作为CNN的输入。

我将使用被ImageNet分类数据集进行预训练的AlexNet来进行可视化。事实上,FlashTorch支持所有随torchvision推出的模型,所以我鼓励您也尝试其他模型!

Backprop类的核心就是创建显著图。

在实例化时,它接收模型Backprop(model)并将自定义钩子注册到网络中感兴趣的层,以便我们可以从计算图中获取中间梯度以进行可视化。由于PyTorch的设计方式,这些中间梯度并不是立即可用的。FlashTorch会帮你整理。

现在,在计算梯度之前我们需要的最后一件事,目标类索引。

回顾一下,我们对目标类相对于输入图像的梯度感兴趣。然而,该模型使用ImageNet数据集进行预训练,因此其预测实际上是1000个类别的概率。我们希望从这1000个值中找出目标类的值(在我们的例子中是灰色的猫头鹰),以避免不必要的计算,并且只关注输入图像和目标类之间的关系。

为此,我还实现了一个名为ImageNetIndex的类。如果你不想下载整个数据集,只想根据类名找出类索引,这是一个方便的工具。如果给它一个类名,它将找到相应的类索引target_class = imagenet['great grey owl']。如果你确实要下载数据集,请使用最新版本中提供的ImageNet类torchvision==0.3.0。

现在,我们有输入图像和目标类索引(24),所以我们准备计算梯度!

这两行是关键:

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

默认情况下,每个颜色通道都会计算梯度,所以它的s形与我们的输入图像(3,224,224)相同。有时候,如果我们在不同的颜色通道上设置最大的梯度,就可以更容易地看到梯度。我们可以通过将take_max=True参数传递给方法调用来实现这一点。梯度的形状将是(1,224,224)

最后,让我们看看我们得到了什么!

我们可以看到,动物所在区域的像素对预测值的影响最大。

但是这是一种噪声信号,它不能告诉我们很多关于神经网络对猫头鹰的感知。

有什么方法可以改善这一点吗

通过引导反向传播来改善

答案是肯定的!

在"Striving for Simplicity: The All Convolutional Net"(https://arxiv.org/abs/1412.6806)一文中,作者介绍了一种降低梯度计算噪声的方法。

实质上,在引导反向传播(Guided backproagation)中,对目标类的预测值没有影响或有负面影响的神经元被屏蔽并忽略。通过这样做,我们可以防止梯度通过这样的神经元,从而减少噪音。

你可以FlashTorch通过传递guided=True参数来调用方法calculate_gradients来使用带引导的反向传播,如下所示:

让我们可视化进行引导后的梯度。

差异是惊人的!

现在我们可以清楚地看到网络正在关注凹陷的眼睛和猫头鹰的圆头。这些是说服该神经网络将对象分类为的特征灰色的猫头鹰!

但它并不总是专注于眼睛或头部……

正如你所看到的,网络已经学会专注于特征,这些特征与我们认为这些鸟类最有特色的东西大致相符。

特征可视化的应用

通过特征可视化,我们不仅可以更好地了解神经网络对物体的了解,而且我们还可以更好地:

  • 诊断网络出错的原因

  • 找出并纠正算法中的偏差

  • 从只关注神经网络的精确度向前迈进

  • 了解网络行为的原因

  • 阐明神经网络如何学习的机制

立即使用FlashTorch!

如果你有在PyTorch中使用CNN的项目,FlashTorch可以帮助你使你的项目更具解释性和可解释性。

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

最便捷的神经网络可视化工具之一--Flashtorch的更多相关文章

  1. 高效使用 Python 可视化工具 Matplotlib

    Matplotlib是Python中最常用的可视化工具之一,可以非常方便地创建海量类型的2D图表和一些基本的3D图表.本文主要介绍了在学习Matplotlib时面临的一些挑战,为什么要使用Matplo ...

  2. Python 可视化工具 Matplotlib

    英文出处:Chris Moffitt. Matplotlib是Python中最常用的可视化工具之一,可以非常方便地创建海量类型的2D图表和一些基本的3D图表.本文主要介绍了在学习Matplotlib时 ...

  3. AI - TensorFlow - 可视化工具TensorBoard

    TensorBoard TensorFlow自带的可视化工具,能够以直观的流程图的方式,清楚展示出整个神经网络的结构和框架,便于理解模型和发现问题. 可视化学习:https://www.tensorf ...

  4. Distill详述「可微图像参数化」:神经网络可视化和风格迁移利器!

    近日,期刊平台 Distill 发布了谷歌研究人员的一篇文章,介绍一个适用于神经网络可视化和风格迁移的强大工具:可微图像参数化.这篇文章从多个方面介绍了该工具. 图像分类神经网络拥有卓越的图像生成能力 ...

  5. 教你如何选择BI数据可视化工具

    本文来自网易云社区. 关于如何选择BI数据可视化工具,总体而言,主流BI产品在选择的时候要除了需要考虑从数据到展现.从公司内到公司外等各种场景,结合前面朋友的回答,还需要考虑以下几点:1:以后的数据处 ...

  6. 可能这是Redis可视化工具最全的横向评测

    1 命令行 不知道大家在日常操作redis时用什么可视化工具呢? 以前总觉得没有什么太好的可视化工具,于是问了一个业内朋友.对方回:你还用可视化工具?直接命令行呀,redis提供了这么多命令,操作起来 ...

  7. CNN可视化技术总结(四)--可视化工具与项目

    CNN可视化技术总结(一)-特征图可视化 CNN可视化技术总结(二)--卷积核可视化 CNN可视化技术总结(三)--类可视化 导言: 前面介绍了可视化的三种方法--特征图可视化,卷积核可视化,类可视化 ...

  8. 99%的人都搞错了的java方法区存储内容,通过可视化工具HSDB和代码示例一次就弄明白了

    https://zhuanlan.zhihu.com/p/269134063  番茄番茄我是西瓜 那是我日夜思念深深爱着的人啊~ 已关注   6 人赞同了该文章 前言 本篇是java内存区域管理系列教 ...

  9. MongoDB 安装和可视化工具

    MongoDB 是一款非常热门的NoSQL,面向文档的数据库管理系统,官方下载地址是:MongoDB,博主选择的是 Enterprise Server (MongoDB 3.2.9)版本,安装在Win ...

随机推荐

  1. PyGame学习笔记之壹

    新建窗口 代码 '''PyGame学习笔记之壹''' import pygame # 引入 PyGame 库 pygame.init() # PyGame 库初始化 screen = pygame.d ...

  2. On Fixed-Point Implementation of Log-MPA for SCMA Signals

    目录 论文来源 摘要 基本概念 1.SCMA 2.SCMA编码器 研究内容 1.基于Log-MPA的SCMA解码器实现过程 论文创新点 借鉴之处 论文来源 本论文来自于IEEE WIRELESS CO ...

  3. 新大陆NB-IoT模块烧写详细过程

    NB-IOT 模块板设置 1. NB-IOT 模块板如下 2.将模块上红色开关 1. 2 向下拨, 3. 4 开关向上拨,如下 3.将黑色开关向左侧拨至 M3 芯片处,如下 4.将模块上启动/下载开关 ...

  4. 关于localStorage面试的那点事

    最近面试的时候关于html5API总会被问到localStorage的问题, 对于一般的问题很简单,无非就是 localStorage.sessionStorage和cookie这三个客户端缓存的区别 ...

  5. 一些大厂的css reset 代码

    不同的浏览器对标签的默认值不同,为了避免页面出现浏览器差异,所以要初始化样式表属性.使用通配符*并不可取,因为会遍历到每一个标签,大型网页会加载过慢,影响性能. 雅虎工程师提供的CSS初始化示例代码: ...

  6. 【colab pytorch】使用tensorboardcolab可视化

    import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from ...

  7. Matplotlib数据可视化(7):图片展示与保存

    In [1]: import os import matplotlib.image as mpimg from PIL import Image import matplotlib.pyplot as ...

  8. OpenMP Programming

    一.OpenMP概述 1.OpenMP应用编程接口API是在共享存储体系结构上的一个编程模型 2.包含 编译制导(compiler directive).运行库例程(runtime library). ...

  9. Spark入门(四)--Spark的map、flatMap、mapToPair

    spark的RDD操作 在上一节Spark经典的单词统计中,了解了几个RDD操作,包括flatMap,map,reduceByKey,以及后面简化的方案,countByValue.那么这一节将介绍更多 ...

  10. D2T1服务器需求——毒?瘤题(并不是

    这题我第一眼居然差点错了\(OTZ\) 然后写了线段树,还写挂了-- 写好了\(query\)操作,发现似乎不需要区间查询,然后又删掉-- 看着这熟悉的操作,似乎在哪里见过-- 然后我莫名其妙把一个\ ...