[NN] Guided Backpropgation 可视化
Pytorch Guided Backpropgation
Intro
guided backpropgation通过修改RELU的梯度反传,使得小于0的部分不反传,只传播大于0的部分,这样到第一个conv层的时候得到的梯度就是对后面relu激活起作用的梯度,这时候我们对这些梯度进行可视化,得到的就是对网络起作用的区域。(实际上可视化的是梯度)。
简单记一下。用到hook的神经网络可视化方法。
code
import torch
import torch.nn as nn
from torchvision import transforms,models
import re
from models.densenet import densenet121
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class Guided_Prop():
def __init__(self,model):
self.model = model
self.model.eval()
self.out_img = None
self.activation_maps = []
def register_hooks(self):
def register_first_layer_hook(module,grad_in,grad_out):
self.out_img = grad_in[0] #(b,c,h,w) -> (c,h,w)
def forward_hook_fn(module,input_feature,output_feature):
self.activation_maps.append(output_feature)
def backward_hook_fn(module,grad_in,grad_out):
grad = self.activation_maps.pop()
grad[grad > 0] = 1
g_positive = torch.clamp(grad_out[0],min = 0.)
result_grad = grad * g_positive
return (result_grad,)
modules = list(self.model.features.named_children())
for name,module in modules:
if isinstance(module,nn.ReLU):
module.register_forward_hook(forward_hook_fn)
module.register_backward_hook(backward_hook_fn)
first_layer = modules[0][1]
first_layer.register_backward_hook(register_first_layer_hook)
def visualize(self,input_image):
softmax = nn.Softmax(dim = 1)
idx_tensor = torch.tensor([float(i) for i in range(61)])
self.register_hooks()
self.model.zero_grad()
out = self.model(input_image) # [[b,n],[b,n],[b,n]]
yaw = softmax(out[0])
yaw = torch.sum(yaw * idx_tensor,dim = 1) * 3 - 90.
pitch = softmax(out[1])
pitch = torch.sum(pitch * idx_tensor,dim = 1) * 3 - 90.
roll = softmax(out[2])
roll = torch.sum(roll * idx_tensor,dim = 1) * 3 - 90.
#print(yaw)
out = yaw + pitch + roll
out.backward()
result = self.out_img.data[0].permute(1,2,0) # chw -> hwc(opencv)
return result.numpy()
def normalize(I):
norm = (I-I.mean())/I.std()
norm = norm * 0.1
norm = norm + 0.5
norm = norm.clip(0, 1)
return norm
if __name__ == "__main__":
input_size = 224
model = densenet121(pretrained = False,num_classes = 61)
model.load_state_dict(torch.load("./ckpt/DenseNet/model_2692_.pkl"))
img = Image.open("/media/xueaoru/其他/ML/head_pose_work/brick/head_and_heads/test/BIWI00009409_-17_+1_+17.png")
transform = transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
tensor = transform(img).unsqueeze(0).requires_grad_()
viz = Guided_Prop(model)
result = viz.visualize(tensor)
result = normalize(result)
plt.imshow(result)
plt.show()
由于是多任务问题,所以直接拿结果反传,对于一般的分类问题,可以给定target来用gt用one-hot反传。
head pose estimation 的梯度可视化。


[NN] Guided Backpropgation 可视化的更多相关文章
- 吴裕雄 python神经网络 水果图片识别(3)
import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...
- TF之NN:matplotlib动态演示深度学习之tensorflow将神经网络系统自动学习并优化修正并且将输出结果可视化—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add_layer(inputs, in_ ...
- MySQL 慢查询日志分析及可视化结果
MySQL 慢查询日志分析及可视化结果 MySQL 慢查询日志分析 pt-query-digest分析慢查询日志 pt-query-digest --report slow.log 报告最近半个小时的 ...
- mininet之miniedit可视化操作
Mininet 2.2.0之后的版本内置了一个mininet可视化工具miniedit,使用Mininet可视化界面方便了用户自定义拓扑创建,为不熟悉python脚本的使用者创造了更简单的环境,界面直 ...
- python之gui-tkinter可视化编辑界面 自动生成代码
首先提供资源链接 http://pan.baidu.com/s/1kVLOrIn#list/path=%2F
- 学习TensorFlow,TensorBoard可视化网络结构和参数
在学习深度网络框架的过程中,我们发现一个问题,就是如何输出各层网络参数,用于更好地理解,调试和优化网络?针对这个问题,TensorFlow开发了一个特别有用的可视化工具包:TensorBoard,既可 ...
- 数据分析之---Python可视化工具
1. 数据分析基本流程 作为非专业的数据分析人员,在平时的工作中也会遇到一些任务:需要对大量进行分析,然后得出结果,解决问题. 所以了解基本的数据分析流程,数据分析手段对于提高工作效率还是非常有帮助的 ...
- AI - TensorFlow - 可视化工具TensorBoard
TensorBoard TensorFlow自带的可视化工具,能够以直观的流程图的方式,清楚展示出整个神经网络的结构和框架,便于理解模型和发现问题. 可视化学习:https://www.tensorf ...
- 【TensorFlow篇】--Tensorflow框架可视化之Tensorboard
一.前述 TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法. 二.代码 设计一个MLP多层神经网 ...
随机推荐
- 082、数据收集利器 cAdvisor (2019-04-30 周二)
参考https://www.cnblogs.com/CloudMan6/p/7683190.html cAdvisor 是google 开发的容器监控工具,下面我们开始安装和体验 cAdvisor ...
- check cve
今天想检查一下 Gitlab 11.9.0 产品受哪些 cve 的影响.其实网上已经有很多网站可以查询产品的相关 cve,但就是粒度比较粗.我想在 cve 列表中筛选出特定的版本,已经特定的版本,比如 ...
- 雪花算法生成ID
前言我们的数据库在设计时一般有两个ID,自增的id为主键,还有一个业务ID使用UUID生成.自增id在需要分表的情况下做为业务主键不太理想,所以我们增加了uuid作为业务ID,有了业务id仍然还存在自 ...
- 使用CXF开发WebService程序的总结(三):创建webservice客户端
1.创建一个maven子工程 ws_client,继承父工程 1.1 修改父工程pom配置 <modules> <module>ws_server</module> ...
- Linux常用命令及Shell的简单介绍
一.linux命令 1.查看指令的参数搭配: man 指令名称 2.基础指令 ls 列出当前目录下的所有文档的名称(文档指的是文件和文件夹) 常用参数搭配: ls -l 列出文档详细信息 l ...
- Codeforces 990 调和级数路灯贪心暴力 DFS生成树两子树差调水 GCD树连通块暴力
A 水题 /*Huyyt*/ #include<bits/stdc++.h> #define mem(a,b) memset(a,b,sizeof(a)) using namespace ...
- Hash基础
BKDR Hash: 选取恰当的进制,可以把字符串中的字符看成一个大数字中的每一位数字,不过比较字符串和比较大数字的复杂度并没有什么区别 首先不要把任意字符对应到数字0,比如假如把a对应到数字0,那么 ...
- Python socket服务
套接字(socket)是一个抽象层,应用程序可以通过它发送或接收数据,可对其进行像对文件一样的打开.读写和关闭等操作. 1. 实现客户端发送字符,服务器返回大写的字符: 服务器: import soc ...
- SpringMVC 向前台页面传值-ModelAndView
ModelAndView 该对象中包含了一个model属性和一个view属性 model:其实是一个ModelMap类型.其实ModelMap是一个LinkedHashMap的子类 view:包含了一 ...
- iOS中延迟执行和取消的几种方式
公用延迟执行的方法: - (void)delayMethod { NSLog(@"delayMethodEnd"); } 方法一.performSelector 方法 1.延迟执行 ...