神经网络可视化《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》
神经网络已经在很多场景下表现出了很好的识别能力,但是缺乏解释性一直所为人诟病。《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》这篇论文基于梯度为其可解释性做了一些工作,它可以显著描述哪块图片区域对识别起了至关重要的作用,以热度图的方式可视化神经网络的注意力。本博客主要是基于pytorch的简单工程复现。原文见这里,本代码基于这里。
1 import torch
2 import torchvision
3 from torchvision import models
4 from torchvision import transforms
5 from PIL import Image
6 import pylab as plt
7 import numpy as np
8 import cv2
9
10
11 class Extractor():
12 """
13 pytorch在设计时,中间层的梯度完成回传后就释放了
14 这里用hook工具在保存中间参数的梯度
15 """
16 def __init__(self, model, target_layer):
17 self.model = model
18 self.target_layer = target_layer
19 self.gradient = None
20
21 def save_gradient(self, grad):
22 self.gradient=grad
23
24 def __call__(self, x):
25 outputs = []
26 self.gradients = []
27 for name,module in self.model.features._modules.items():
28 x = module(x)
29 if name == self.target_layer:
30 x.register_hook(self.save_gradient)
31 target_activation=x
32 x=x.view(1,-1)
33 for name,module in self.model.classifier._modules.items():
34 x = module(x)
35 # 维度为(1,c, h, w) , (1,class_num)
36 return target_activation, x
37
38
39 def preprocess_image(path):
40 means=[0.485, 0.456, 0.406]
41 stds=[0.229, 0.224, 0.225]
42 m_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(means,stds)])
43 img=Image.open(path)
44 return m_transform(img).reshape(1,3,224,224)
45
46
47 class GradCam():
48 def __init__(self, model, target_layer_name, use_cuda):
49 self.model = model
50 self.model.eval()
51 self.cuda = use_cuda
52 if self.cuda:
53 self.model = model.cuda()
54
55 self.extractor = Extractor(self.model, target_layer_name)
56
57
58 def __call__(self, input, index = None):
59 if self.cuda:
60 target_activation, output = self.extractor(input.cuda())
61 else:
62 target_activation, output = self.extractor(input)
63
64 # index是想要查看的类别,未指定时选择网络做出的预测类
65 if index == None:
66 index = np.argmax(output.cpu().data.numpy())
67
68 # batch维为1(我们默认输入的是单张图)
69 one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
70 one_hot[0][index] = 1.0
71 one_hot = torch.tensor(one_hot)
72 if self.cuda:
73 one_hot = torch.sum(one_hot.cuda() * output)
74 else:
75 one_hot = torch.sum(one_hot * output)
76
77 self.model.zero_grad()
78 one_hot.backward(retain_graph=True)
79
80 grads_val = self.extractor.gradient.cpu().data.numpy()
81 # 维度为(c, h, w)
82 target = target_activation.cpu().data.numpy()[0]
83 # 维度为(c,)
84 weights = np.mean(grads_val, axis = (2, 3))[0, :]
85 # cam要与target一样大
86 cam = np.zeros(target.shape[1 : ], dtype = np.float32)
87 for i, w in enumerate(weights):
88 cam += w * target[i, :, :]
89
90 # 每个位置选择c个通道上最大的最为输出
91 cam = np.maximum(cam, 0)
92 cam = cv2.resize(cam, (224, 224))
93 cam = cam - np.min(cam)
94 cam = cam / np.max(cam)
95 return cam
96
97
98 def show_cam_on_image(img, mask):
99 heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
100 heatmap = np.float32(heatmap) / 255
101 cam = heatmap + np.float32(img)
102 cam = cam / np.max(cam)
103 cv2.imwrite("cam2.jpg", np.uint8(255 * cam))
104
105
106 #target_layer 越靠近分类层效果越好
107 grad_cam = GradCam(model = models.vgg19(pretrained=True), target_layer_name = "35", use_cuda=True)
108 input = preprocess_image("both.png")
109 mask = grad_cam(input, None)
110 img = cv2.imread("both.png", 1)
111 #热度图是直接resize加到输入图上的
112 img = np.float32(cv2.resize(img, (224, 224))) / 255
113 show_cam_on_image(img, mask)
原图:

可视化图:

神经网络可视化《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》的更多相关文章
- Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization
目录 Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization 1.Abstract 2.Intr ...
- 【论文简读】 Deep web data extraction based on visual
<Deep web data extraction based on visual information processing>作者 J Liu 上海海事大学 2017 AIHC会议登载 ...
- 深度卷积神经网络用于图像缩放Image Scaling using Deep Convolutional Neural Networks
This past summer I interned at Flipboard in Palo Alto, California. I worked on machine learning base ...
- 论文笔记:SiamRPN++: Evolution of Siamese Visual Tracking with Very Deep Networks
SiamRPN++: Evolution of Siamese Visual Tracking with Very Deep Networks 2019-04-02 12:44:36 Paper:ht ...
- 论文笔记之:Action-Decision Networks for Visual Tracking with Deep Reinforcement Learning
论文笔记之:Action-Decision Networks for Visual Tracking with Deep Reinforcement Learning 2017-06-06 21: ...
- Distill详述「可微图像参数化」:神经网络可视化和风格迁移利器!
近日,期刊平台 Distill 发布了谷歌研究人员的一篇文章,介绍一个适用于神经网络可视化和风格迁移的强大工具:可微图像参数化.这篇文章从多个方面介绍了该工具. 图像分类神经网络拥有卓越的图像生成能力 ...
- WPF中的可视化对象(Visual)
原文:WPF中的可视化对象(Visual) 这是MSDN对Visual的解释:Visual class:Provides rendering support in WPF, which include ...
- TensorSpace:超酷炫3D神经网络可视化框架
TensorSpace:超酷炫3D神经网络可视化框架 TensorSpace - 一款 3D 模型可视化框架,支持多种模型,帮助你可视化层间输出,更直观地展示模型的输入输出,帮助理解模型结构和输出方法 ...
- Deep Learning 8_深度学习UFLDL教程:Stacked Autocoders and Implement deep networks for digit classification_Exercise(斯坦福大学深度学习教程)
前言 1.理论知识:UFLDL教程.Deep learning:十六(deep networks) 2.实验环境:win7, matlab2015b,16G内存,2T硬盘 3.实验内容:Exercis ...
随机推荐
- Python工程打包
Python项目打包 我是自己写了一个项目,然后需要打包成问一个exe文件,这样直接打开这个文件就可以运行,而不需要在pycharm中打开相应文件才能运行,也可以将打包好的文件发给其他人,不需要pyc ...
- jmeter工具初探
jmeter工具初探 一.jmeter工具介绍 1.一种免费的java开源工具,可以进行二次开发 2.运行环境:java运行环境,需要安装JDK,配置JAVAHOME 环境变量 3.下载jmeter: ...
- insert语句生成的存储过程
问题: 1.如何配置数据库数据: 方式一:图形界面点击输入数据,导出成sql. 缺点:表多,数据多的时候非常繁琐,字段含义需要另外开窗口对照. 方式二:徒手写或者修改已有语句:insert table ...
- 干货|SQL语句大全,所有的SQL都在这里了(建议收藏)
一个执着于技术的公众号 一.基础 1.登录数据库 mysql -uroot -p123123 2.创建数据库 create database <数据库名> 3.删除数据库 drop dat ...
- linux项目部署(非前后端分离crm)
参考博客 参考博客2---部署过程 导论:看参考博客1 WSGI是Web服务器网关接口.它是一个规范,描述了Web服务器如何与Web应用程序通信,以及Web应用程序如何链接在一起以处理一个请求,(接收 ...
- 代码审计VauditDemo程序到exp编写
要对一个程序做系统的审计工作,很多人都认为代码审计工作是在我们将CMS安装好之后才开始的,其实不然,在安装的时候审计就已经开始了! 一般安装文件为install.php或install/或includ ...
- 记录一次用宝塔部署微信小程序Node.js后端接口代码的详细过程
一直忙着写毕设,上一次写博客还是元旦,大半年过去了.... 后面会不断分享各种新项目的源码与技术.欢迎关注一起学习哈! 记录一次部署微信小程序Node.js后端接口代码的详细过程,使用宝塔来部署. 我 ...
- Java 16 新特性:record类
以前我们定义类都是用class关键词,但从Java 16开始,我们将多一个关键词record,它也可以用来定义类.record关键词的引入,主要是为了提供一种更为简洁.紧凑的final类的定义方式. ...
- 基于dhtmlxGantt的Blazor甘特图组件
基于dhtmlxGantt实现的甘特图组件,目前仅做到了数据展现,方法及插槽暂未实现,若需可按照dhtmlxGantt的文档及微软的Balzor文档,自行扩展. 数据发生变化后甘特图会立即发生变化. ...
- 442. Find All Duplicates in an Array - LeetCode
Question 442. Find All Duplicates in an Array Solution 题目大意:在数据中找重复两次的数 思路:数组排序,前一个与后一个相同的即为要找的数 Jav ...