[Pytorch框架] 4.2.3 可视化理解卷积神经网络
文章目录
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision import models,datasets
import matplotlib.pyplot as plt
torch.__version__
'1.0.0'
4.2.3 可视化理解卷积神经网络
在上一节中我们已经通过一个预训练的VGG16模型对一张图片进行了分类,下面我们粘贴上一节的代码
cat_img=Image.open('./1280px-Felis_silvestris_catus_lying_on_rice_straw.jpg')
transform_224= transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop((224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
cat_img_224=transform_224(cat_img)
上面的代码是我们读取了一张图片,并对图片进行了一些预处理,下面我们来创建vgg16的预训练好网络模型
net = models.vgg16(pretrained=True)# 修改这里可以更换其他与训练的模型
inputs=cat_img_224[np.newaxis] #这两个方法都可以cat_img_224[None,::]
进行一次前向的传播,看看得到了什么结果
out = net(inputs)
_, preds = torch.max(out.data, 1)
preds
label=preds.numpy()[0]
label
287
我们看到了,这里返回的是285,代码几乎一样,但是返回的结果与上一节的样例有差别,这是什么原因呢?
首先我们先看一下这个数字的含义,我们使用的是通过imagenet来作为预训练的模型,imagenet里面有1000个分类,我们如何去找这个含义呢?
有好心人已经给我们准备好了 这个连接
我们找一下 285: ‘Egyptian cat’, 说明识别出了是一只猫,种类还是埃及猫,应该还是比较准确的,但是这张图片是我特意寻找的,里面包含了很多隐藏的细节,这里就不多介绍了,大家如果有兴趣,可以换一个模型,或者修改下transforms方法,看看模型都会识别出来是什么类别。
注:不同的预训练权重也会出现不同的结果,我测试出现过277,282,287等结果
下面我们开始进入正题,卷积神经网络的可视化
背景
CNN模型虽然在图像处理上表现出非常良好的性能和准确性,但一直以来都被认为是一个黑盒模型,人们无法了解里面的工作机制。
针对这个问题,研究人员除了从理论层面去寻找解释外,也提出了一些可视化的方法直观地理解CNN的内部机理,毕竟眼见为实,看到了大家就相信了。
这里介绍两类方法,一种是基于Deconvolution, 另一种则是基于反向传播的方法。我们主要使用代码实现基于反向传播的方法的可视化。
基于Deconvolution的方法
Visualizing and Understanding Convolutional Networks
主要是将激活函数的特征映射回像素空间,来揭示什么样的输入模式能够产生特定的输出,因为网络是有层级关系的,所以越靠近输出的层级学到的特征越抽象,与实际任务越相关,这里就不多介绍了,这里有一个使用 keras的实现,有兴趣的可以看看
基于Backpropagation的方法
另外一类的实现就是基于Backpropagation的方法,这里我们主要进行介绍,在介绍之前,我们首先要引用一下别人写的代码
pytorch-cnn-visualizations,将这个代码的src目录放到与这个notebook同级别目录下,我们后面会直接调用他的代码进行演示操作。
首先,我们做一些准备工作
import sys
sys.path.insert(0, './src/')
def rgb2gray(rgb):
return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
def rescale_grads(map,gradtype="all"):
if(gradtype=="pos"):
map = (np.maximum(0, map) / map.max())
elif gradtype=="neg":
map = (np.maximum(0, -map) / -map.min())
else:
map = map - map.min()
map /= map.max()
return map
Guided-Backpropagation
这个方法来自于ICLR-2015 的文章《Striving for Simplicity: The All Convolutional Net》,文中提出了使用stride convolution 替代pooling 操作,这样整个结构都只有卷积操作。作者为了研究这种结构的有效性,提出了guided-backpropagation的方法。
大致的方法为:选择某一种输出模式,然后通过反向传播计算输出对输入的梯度。这种方式与上一种deconvnet的方式的唯一区别在于对ReLU梯度的处理。
ReLU在反向传播的计算采用的前向传播的特征作为门阀,而deconvnet采用的是梯度值,guided-backpropagation则将两者组合在一起使用,这样有助于得到的重构都是正数。
这段话可能有点绕,具体细节还是看论文吧,我们这里只关注如何实现
inputs.requires_grad=True # 这句话必须要有,否则会报错
from guided_backprop import GuidedBackprop #这里直接引用写好的方法,在src,目录找想对应的文件
GB=GuidedBackprop(net)
gp_grads=GB.generate_gradients(inputs, label)
gp_grads=np.moveaxis(gp_grads,0,-1)
#我们分别计算三类的gp
ag=rescale_grads(gp_grads,gradtype="all")
pg=rescale_grads(gp_grads,gradtype="pos")
ng=rescale_grads(gp_grads,gradtype="neg")
下面我们使用matplotlib看看结果
plt.imshow(cat_img)
<matplotlib.image.AxesImage at 0x23d840392e8>

plt.imshow(ag)
<matplotlib.image.AxesImage at 0x23d8441c7f0>

plt.imshow(ng)
<matplotlib.image.AxesImage at 0x23d84487080>

plt.imshow(ag)
<matplotlib.image.AxesImage at 0x23d854b44e0>

上面三张图是rbg三个通道的展示结果,下面我们合并成一个通道再看一下
gag=rgb2gray(ag)
plt.imshow(gag)
<matplotlib.image.AxesImage at 0x23d8550fe80>

gpg=rgb2gray(pg)
plt.imshow(gpg)
<matplotlib.image.AxesImage at 0x23d85576710>

gng=rgb2gray(ng)
plt.imshow(gng)
<matplotlib.image.AxesImage at 0x23d855d4fd0>

CAM(Class Activation Map)
这个方法严格来说不是基于梯度的,但是后面我们会将反向传播与CAM整合,所以简单的对CAM做个说明。
CAM 来自CVPR 2016 《Learning Deep Features for Discriminative Localization》,作者在研究global average pooling(GAP)时,发现GAP不止作为一种正则,减轻过拟合,在稍加改进后,可以使得CNN具有定位的能力,CAM(class activation map)是指输入中的什么区域能够指示CNN进行正确的识别。
通常特征图上每个位置的值在存在其感知野里面某种模式时被激活,最后的class activation map是这些模式的线性组合,我们可以通过上采样,将class activation map 还原到与原图一样的大小,通过叠加,我们就可以知道哪些区域是与最后分类结果息息相关的部分。
这里就不介绍了
Grad-CAM
Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
顾名思义 Grad-CAM的加权系数是通过反向传播得到的,而CAM的特征加权系数是分类器的权值。
Grad-CAM 与 CAM相比,它的优点是适用的范围更广,Grad-CAM对各类结构,各种任务都可以使用。这两种方法也可以应用于进行弱监督下的目标检测,后续也有相关工作基于它们进行改进来做弱监督目标检测。
import math
from gradcam import GradCam
from guided_gradcam import guided_grad_cam
from guided_backprop import GuidedBackprop
nlayers=len(net.features._modules.items())-1
print(nlayers) # 打印一下一共有多少层
cam_list=[]
#下面我们循环每一层
for layer in range(nlayers):
#GradCam
grad_cam = GradCam(net,target_layer=layer)
cam = grad_cam.generate_cam(inputs, label)
#GuidedBackprop
GBP = GuidedBackprop(net)
guided_grads = GBP.generate_gradients(inputs, label)
# Guided Grad cam
cam_gb = guided_grad_cam(cam, guided_grads)
cam_list.append(rgb2gray(np.moveaxis(cam_gb,0,-1)))
30
我们选个图,看看效果
plt.imshow(cam_list[0])
<matplotlib.image.AxesImage at 0x23d858b7588>

在 Visualizing and Understanding Convolutional Networks 中作者还给出了其他不同的方法,这里就不详细说明了
需要注意的是,在使用 Visualizing and Understanding Convolutional Networks的时候,对网络模型是有要求的,要求网络将模型包含名为features的组合层,这部分是代码中写死的,所以在pytorch的内置模型中,vgg、alexnet、densenet、squeezenet是可以直接使用的,inception(googlenet)和resnet没有名为features的组合层,如果要使用的话是需要对代码进行修改的。
[Pytorch框架] 4.2.3 可视化理解卷积神经网络的更多相关文章
- 用反卷积(Deconvnet)可视化理解卷积神经网络还有使用tensorboard
『cs231n』卷积神经网络的可视化与进一步理解 深度学习小白——卷积神经网络可视化(二) TensorBoard--TensorFlow可视化 原文地址:http://blog.csdn.net/h ...
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 详解卷积神经网络(CNN)
详解卷积神经网络(CNN) 详解卷积神经网络CNN 概揽 Layers used to build ConvNets 卷积层Convolutional layer 池化层Pooling Layer 全 ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...
- 【chainer框架】【pytorch框架】
教程: https://bennix.github.io/ https://bennix.github.io/blog/2017/12/14/chain_basic/ https://bennix.g ...
随机推荐
- 什么是序列化?实体类为什么要实现序列化接口?实体类为什么要指定SerialversionUID?
首先我们说答案:实体类对象在保存在内存中的,而对于web应用程序而言,很多客户端会对服务器后台提交数据请求,如得到某种类型的商品,此时后台程序会从数据库中读取符合条件的记录,并它们打包成对象的集合,再 ...
- 2020/03/25 CSS相关知识点
2020-03-25 16:35:03 又是一个风和日丽的下午!今天的内容比较多 真是令人头大 ,手速又慢所以缺的可能比较多,而且这东西还是多靠实践为好. 文件下载地址: https://share. ...
- ARM-linux的Windows交叉编译环境搭建
交叉编译Arm Linux平台的QT5库 1.准备交叉编译环境 环境说明:Windows10 64位 此过程需要: (1)Qt库开源代码,我使用的是5.13.0版本: (2)Perl语言环境5.12版 ...
- 初次使用Sqoop报错,sqoop命令不能正常使用:hcatalog does not exist!accumulo does not exist!
1.问题描述: (1)问题示例: [hadoop@master Tmp]$ sqoop helpWarning: /home/grid/Sqoop/sqoop-1.4.7/../hcatalog d ...
- C++ condition_variable
一.使用场景 在主线程中创建一个子线程去计数,计数累计100次后认为成功,并告诉主线程:主线程收到计数100次完成的信息后继续往下执行 二.条件变量的成员函数 wait:当前线程调用 wait() 后 ...
- JMeter压测脚本实例:单接口
新建测试计划 添加线程组 添加HTTP请求 配置该请求相关参数 1.请求头部信息 ①HTTP请求同级线程组下添加HTTP信息头部管理器 ②填充该请求所需的头部信息 2.请求体 选中之前增加的HTTP请 ...
- 集训第二周计划:把cf近期的div2除了最后一题给切完
太菜了太菜了,弄个训练计划. 晚上没事干的时候我想把博客园皮肤改一下,搜着搜着不知道怎么回事点进去一些竞赛选手的博客,比如这个 https://www.cnblogs.com/soda-ma/p/13 ...
- Windows 系统下怎么获取 UDP 本机地址
Windows 系统下怎么获取 UDP 本机地址 我们知道 UDP 获取远端地址非常简单,通常接口 recvfrom 就可以直接获取到远端的地址和端口:如果获取 UDP 的本机地址就需要点特殊处理了, ...
- MySQL 中索引是如何实现的,有哪些类型的索引,如何进行优化索引
MySQL 中的索引 前言 索引的实现 哈希索引 全文索引 B+ 树索引 索引的分类 聚簇索引(clustered index) 非聚簇索引(non-clustered index) 联合索引 覆盖索 ...
- 全网最详细中英文ChatGPT-GPT-4示例文档-事实性回答应用从0到1快速入门——官网推荐的48种最佳应用场景(附python/node.js/curl命令源代码,小白也能学)
目录 Introduce 简介 setting 设置 Prompt 提示 Sample response 回复样本 API request 接口请求 python接口请求示例 node.js接口请求示 ...