[NN] Guided Backpropgation 可视化
Pytorch Guided Backpropgation
Intro
guided backpropgation通过修改RELU的梯度反传,使得小于0的部分不反传,只传播大于0的部分,这样到第一个conv层的时候得到的梯度就是对后面relu激活起作用的梯度,这时候我们对这些梯度进行可视化,得到的就是对网络起作用的区域。(实际上可视化的是梯度)。
简单记一下。用到hook的神经网络可视化方法。
code
import torch
import torch.nn as nn
from torchvision import transforms,models
import re
from models.densenet import densenet121
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class Guided_Prop():
def __init__(self,model):
self.model = model
self.model.eval()
self.out_img = None
self.activation_maps = []
def register_hooks(self):
def register_first_layer_hook(module,grad_in,grad_out):
self.out_img = grad_in[0] #(b,c,h,w) -> (c,h,w)
def forward_hook_fn(module,input_feature,output_feature):
self.activation_maps.append(output_feature)
def backward_hook_fn(module,grad_in,grad_out):
grad = self.activation_maps.pop()
grad[grad > 0] = 1
g_positive = torch.clamp(grad_out[0],min = 0.)
result_grad = grad * g_positive
return (result_grad,)
modules = list(self.model.features.named_children())
for name,module in modules:
if isinstance(module,nn.ReLU):
module.register_forward_hook(forward_hook_fn)
module.register_backward_hook(backward_hook_fn)
first_layer = modules[0][1]
first_layer.register_backward_hook(register_first_layer_hook)
def visualize(self,input_image):
softmax = nn.Softmax(dim = 1)
idx_tensor = torch.tensor([float(i) for i in range(61)])
self.register_hooks()
self.model.zero_grad()
out = self.model(input_image) # [[b,n],[b,n],[b,n]]
yaw = softmax(out[0])
yaw = torch.sum(yaw * idx_tensor,dim = 1) * 3 - 90.
pitch = softmax(out[1])
pitch = torch.sum(pitch * idx_tensor,dim = 1) * 3 - 90.
roll = softmax(out[2])
roll = torch.sum(roll * idx_tensor,dim = 1) * 3 - 90.
#print(yaw)
out = yaw + pitch + roll
out.backward()
result = self.out_img.data[0].permute(1,2,0) # chw -> hwc(opencv)
return result.numpy()
def normalize(I):
norm = (I-I.mean())/I.std()
norm = norm * 0.1
norm = norm + 0.5
norm = norm.clip(0, 1)
return norm
if __name__ == "__main__":
input_size = 224
model = densenet121(pretrained = False,num_classes = 61)
model.load_state_dict(torch.load("./ckpt/DenseNet/model_2692_.pkl"))
img = Image.open("/media/xueaoru/其他/ML/head_pose_work/brick/head_and_heads/test/BIWI00009409_-17_+1_+17.png")
transform = transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
tensor = transform(img).unsqueeze(0).requires_grad_()
viz = Guided_Prop(model)
result = viz.visualize(tensor)
result = normalize(result)
plt.imshow(result)
plt.show()
由于是多任务问题,所以直接拿结果反传,对于一般的分类问题,可以给定target来用gt用one-hot反传。
head pose estimation 的梯度可视化。
[NN] Guided Backpropgation 可视化的更多相关文章
- 吴裕雄 python神经网络 水果图片识别(3)
import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...
- TF之NN:matplotlib动态演示深度学习之tensorflow将神经网络系统自动学习并优化修正并且将输出结果可视化—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add_layer(inputs, in_ ...
- MySQL 慢查询日志分析及可视化结果
MySQL 慢查询日志分析及可视化结果 MySQL 慢查询日志分析 pt-query-digest分析慢查询日志 pt-query-digest --report slow.log 报告最近半个小时的 ...
- mininet之miniedit可视化操作
Mininet 2.2.0之后的版本内置了一个mininet可视化工具miniedit,使用Mininet可视化界面方便了用户自定义拓扑创建,为不熟悉python脚本的使用者创造了更简单的环境,界面直 ...
- python之gui-tkinter可视化编辑界面 自动生成代码
首先提供资源链接 http://pan.baidu.com/s/1kVLOrIn#list/path=%2F
- 学习TensorFlow,TensorBoard可视化网络结构和参数
在学习深度网络框架的过程中,我们发现一个问题,就是如何输出各层网络参数,用于更好地理解,调试和优化网络?针对这个问题,TensorFlow开发了一个特别有用的可视化工具包:TensorBoard,既可 ...
- 数据分析之---Python可视化工具
1. 数据分析基本流程 作为非专业的数据分析人员,在平时的工作中也会遇到一些任务:需要对大量进行分析,然后得出结果,解决问题. 所以了解基本的数据分析流程,数据分析手段对于提高工作效率还是非常有帮助的 ...
- AI - TensorFlow - 可视化工具TensorBoard
TensorBoard TensorFlow自带的可视化工具,能够以直观的流程图的方式,清楚展示出整个神经网络的结构和框架,便于理解模型和发现问题. 可视化学习:https://www.tensorf ...
- 【TensorFlow篇】--Tensorflow框架可视化之Tensorboard
一.前述 TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法. 二.代码 设计一个MLP多层神经网 ...
随机推荐
- mybatis一对多关联关系映射
mybatis一对多关联关系映射 一对多关联关系只需要在多的一方引入少的一方的主键作为外键即可.在实体类中就是反过来,在少的一方添加多的一方,声明一个List 属性名 作为少的一方的属性. 用户和订单 ...
- Vue2 & ElementUI实现管理后台之input获得焦点
Vue.directive('focus', function (el, option) { var defClass = 'el-input', defTag = 'input'; var valu ...
- Vue+ElementUI学习总结(转载)
Vue框架简介 Vue是一套构建用户界面的框架, 开发只需要关注视图层, 它不仅易于上手,还便于与第三方库或既有项目的整合.是基于MVVM(Model-View-ViewModel)设计思想.提供MV ...
- php通过反射方法调用私有方法
PHP 5 具有完整的反射 API,添加了对类.接口.函数.方法和扩展进行反向工程的能力. 下面我们演示一下如何通过反射,来调用执行一个类中的私有方法: <?php //MyClass这个类中包 ...
- INSERT - 在表中创建新行
SYNOPSIS INSERT INTO table [ ( column [, ...] ) ] { DEFAULT VALUES | VALUES ( { expression | DEFAULT ...
- juniper 命令
show chassis hardware 查看系统硬件配置,fpc表示板卡,pic表示板卡中的槽位,xcvr表示板卡中的槽位的端口位置 show chassis envirmonent 查看系统运行 ...
- Codeforces 902 树同型构造 多项式长除法构造(辗转相除法)
A #include <bits/stdc++.h> #define PI acos(-1.0) #define mem(a,b) memset((a),b,sizeof(a)) #def ...
- Linux中关闭SSH的DNS解析
在操作中,我们都会用SSH协议来远程控制虚拟机,但是在输入用户名时候,会有一段时间的卡顿,此时正在进行SSH协议的DNS解析,我们为了快速的连接到虚拟机上,就要关闭这个解析过程,如下是具体配置: 1. ...
- Tronado【第2篇】:tronado自定义Form组件
Tronado自定义Form组件 一.获取类里面的静态属性以及动态属性的方法 方式一: # ===========方式一================ class Foo(object): user ...
- javascript的基础知识点
一:鼠标提示框 需求描述:鼠标移入都input上,div显示,移出,div消失 分析:控制display=block/none 鼠标移入,鼠标移出事件 <input type="bu ...