在CNN上增加一层CAM告诉你CNN到底关注什么
Cam(Class Activation Mapping)是一个很有意思的算法,他能够将神经网络到底在关注什么可视化的表现出来。但同时它的实现却又如此简介,相比NIN,googLenet这些使用GAP(Global Average Pooling)用来代替全连接层,他却将其输出的权重和featuremap相乘,累加,将其用图像表示出来。
其网络架构如下

Class Activation Mapping具体论文
当然Cam的目的并不仅仅是将其表示出来,神经网络所关注的地方,通常就是物体所在的地方,因此它可以辅助训练检测网络。
因此就有了PlacesNet。
网络上基本都是基于AlexNet等网络,其实任何网络,只要加一层全局池化层就可以帮助我们将CNN关注什么表示出来,因此我对Tensorflow官方Mnist的CNN网络进行少量的修改,实现了CAM。只是将最后的全连接层,改为了全局池化层。
CAM的核心公式很简单
\]
将全局平均池化层输出的权重乘上feature map累加
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
/home/lyn/anaconda3/lib/python3.6/importlib/_bootstrap.py:205: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
return f(*args, **kwds)
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
在读入数据后,设定基本的学习参数
# Training Parameters
learning_rate = 0.001
num_steps = 10000
batch_size = 128
display_step = 10
# Network Parameters
num_input = 784 # MNIST data input (img shape: 28*28)
num_classes = 10 # MNIST total classes (0-9 digits)
# tf Graph input
X = tf.placeholder(tf.float32, [None, num_input])
Y = tf.placeholder(tf.int32, [None, num_classes])
# Create some wrappers for simplicity
def conv2d(x, W, b, strides=1):
# Conv2D wrapper, with bias and relu activation
x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
x = tf.nn.bias_add(x, b)
return tf.nn.relu(x)
def maxpool2d(x, k=2):
# MaxPool2D wrapper
return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
padding='SAME')
在实际使用中我们需要获得得feature map,与全局池化层相乘并累加
def conv_layers(x,weights,biases):
x = tf.reshape(x, shape=[-1, 28, 28, 1])
# Convolution Layer
conv1 = conv2d(x, weights['wc1'], biases['bc1'])
# Max Pooling (down-sampling)
conv1 = maxpool2d(conv1, k=2)
# Convolution Layer
conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
# Max Pooling (down-sampling)
conv2 = maxpool2d(conv2, k=2)
return conv2
def out_layer(conv2,weights,biases):
gap = tf.nn.avg_pool(conv2,ksize=[1,7,7,1],strides=[1,7,7,1],padding="SAME")
gap = tf.reshape(gap,[-1,128])
out = tf.add(tf.matmul(gap, weights['out']), biases['out'])
return out
def generate_heatmap(conv2,label,weights):
conv2_resized = tf.image.resize_images(conv2,[28,28])
label_w = tf.gather(tf.transpose(weights['out']),label)
label_w = tf.reshape(label_w,[-1,128,1])
conv2_resized = tf.reshape(conv2_resized,[-1,28*28,128])
classmap = tf.matmul( conv2_resized, label_w )
classmap = tf.reshape( classmap, [-1, 28,28] )
return classmap
# Store layers weight & bias
weights = {
# 5x5 conv, 1 input, 32 outputs
'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
# 5x5 conv, 32 inputs, 64 outputs
'wc2': tf.Variable(tf.random_normal([5, 5, 32, 128])),
'out': tf.Variable(tf.random_normal([128, num_classes]))
}
biases = {
'bc1': tf.Variable(tf.random_normal([32])),
'bc2': tf.Variable(tf.random_normal([128])),
'bd1': tf.Variable(tf.random_normal([1024])),
'out': tf.Variable(tf.random_normal([num_classes]))
}
# Construct model
conv2 = conv_layers(X, weights, biases)
logits = out_layer(conv2,weights, biases)
prediction = tf.nn.softmax(logits)
classmap = generate_heatmap(conv2,tf.argmax(prediction,1),weights)
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)
# Evaluate model
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for step in range(1, num_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Run optimization op (backprop)
sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
if step % display_step == 0 or step == 1:
# Calculate batch loss and accuracy
loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
Y: batch_y})
print("Step " + str(step) + ", Minibatch Loss= " + \
"{:.4f}".format(loss) + ", Training Accuracy= " + \
"{:.3f}".format(acc))
print("Optimization Finished!")
# Calculate accuracy for 256 MNIST test images
print("Testing Accuracy:", \
sess.run(accuracy, feed_dict={X: mnist.test.images[:256],
Y: mnist.test.labels[:256],}
))
sess.run(conv2,feed_dict={X:mnist.test.images[:10],Y:mnist.test.labels[:10]})
classmaps = sess.run(classmap,feed_dict={X:mnist.test.images[:10],Y:mnist.test.labels[:10]})
Testing Accuracy: 0.976562
for i in range(9):
plt.subplot(33*10+i+1)
plt.axis("off")
plt.imshow(classmaps[i],cmap="gray")
plt.title("label is"+str(np.argmax(mnist.test.labels[i])))

黑色为正权重点,白色为负权重点。
当然由于网络太浅,使用了全局平均池化以后,训练时间大大增长,准确率也不如之前。
在CNN上增加一层CAM告诉你CNN到底关注什么的更多相关文章
- Ext JS treegrid 发生的在tree上增加itemclick 与在其它列上增加actioncolumn 发生事件冲突(event conflict)的解决办法
Ext JS treegrid 发生的在tree上增加itemclick 与在其它列上增加actioncolumn 发生事件冲突(event conflict)的解决办法 最近在适用Ext JS4开发 ...
- AIX上增加逻辑卷时报错误0516-787 extendlv: Maximum allocation for logical volume
AIX上增加逻辑卷时报错误0516-787 extendlv: Maximum allocation for logical volume jdelv02 is 512. 在往aix使用chfs -a ...
- 跟我一起学extjs5(05--主界面上增加顶部和底部区域)
跟我一起学extjs5(05--主界面上增加顶部和底部区域) 这一节为主界面加一个顶部区域和底部区域. 一个管理系统的界面能够粗分为顶部标题部分.中间数据展示和处理的部分.底部备注和状 ...
- SQL某个字段在原内容上增加固定内容或replace查找替换内容
今天正好遇到一个SQL小问题,特做备注 在原有的表中数据如pic 在不动原内容的基础上增加../路径,但不能修改原数据值 原数据 SQL: pic字段 需要增加'../'的内容 update Bmps ...
- 在iOS上增加手势锁屏、解锁功能
在iOS上增加手势锁屏.解锁功能 在一些涉及个人隐私的场景下,尤其是当移动设备包含太多私密信息时,为用户的安全考虑是有必要的. 桌面版的QQ在很多年前就考虑到用户离开电脑后隐私泄露的危险,提供了“离开 ...
- sharepoint 2010 在自定义列表的字段上增加功能菜单
sharepoint 2010 在自定义列表的字段上增加功能菜单方法 打开sharepoint designer 2010,找到需要修改的视图页面,例如allitem.aspx,编辑这个页面,点击高级 ...
- 在Linux服务器上增加硬盘没那么简单【转】
运维案例:HP服务器,LINUX系统在保障数据的前提下扩展/home分区 部门需求:研发部门提出需要在现有的服务器上扩容磁盘空间,以满足开发环境的磁盘需求.现有空间1.6T需要增加到2T. 需求调查分 ...
- ASP.NET Web API实践系列06, 在ASP.NET MVC 4 基础上增加使用ASP.NET WEB API
本篇尝试在现有的ASP.NET MVC 4 项目上增加使用ASP.NET Web API. 新建项目,选择"ASP.NET MVC 4 Web应用程序". 选择"基本&q ...
- Mirror--如何在主库上增加文件
由于各种原因,如磁盘不空不足,需要对主库增加数据库文件到其他磁盘上,而镜像服务器上没有对应盘符,很多人会选择删除镜像,重新完备还原来搭建镜像,这种方式耗时耗力. 在做此类操作时,需要对主服务器和镜像服 ...
随机推荐
- 【leetcode】813. Largest Sum of Averages
题目如下: 解题思路:求最值的题目优先考虑是否可以用动态规划.记dp[i][j]表示在数组A的第j个元素后面加上第i+1 (i从0开始计数)个分隔符后可以得到的最大平均值,那么可以得到递归关系式: d ...
- XML 语法
XML 语法规则 本节的目的是想让你了解 XML 中的语法所依据的规则,避免在编写 XML 文档的时候遇到错误. XML 的语法规则很简单,且很有逻辑.这些规则很容易学习,也很容易使用. 所有的 XM ...
- 调整ceph的pg数(pg_num, pgp_num)
https://www.jianshu.com/p/ae96ee24ef6c 调整ceph的pg数 PG全称是placement groups,它是ceph的逻辑存储单元.在数据存储到cesh时,先打 ...
- [14th CSMO Day 1 <平面几何>]
关于LowBee苦思冥想的结果(仅供参考):
- [USACO17FEB]Why Did the Cow Cross the Road III G (树状数组,排序)
题目链接 Solution 二维偏序问题. 现将所有点按照左端点排序,如此以来从左至右便满足了 \(a_i<a_j\) . 接下来对于任意一个点 \(j\) ,其之前的所有节点都满足 \(a_i ...
- iOS7上leftBarButtonItem无法实现滑动返回的完美解决方案
今天遇到了在iOS7上使用leftBarButtonItem却无法响应滑动返回事件的问题,一番谷歌,最后终于解决了,在这里把解决方案分享给大家. 在iOS7之前的系统,如果要自定义返回按钮,直接设置b ...
- 前端每日实战:124# 视频演示如何用纯 CSS 创作一只纸鹤
效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/xagoYb 可交互视频 此视频是可 ...
- 洛谷P1441 砝码称重(搜索,dfs+dp)
洛谷P1441 砝码称重 \(n\) 的范围为 \(n \le 20\) ,\(m\) 的范围为 \(m \le 4\) . 暴力遍历每一种砝码去除情况,共有 \(n^m\) 种情况. 对于剩余砝码求 ...
- leetcode 215. 数组中的第K个最大元素(python)
在未排序的数组中找到第 k 个最大的元素.请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素. 示例 1: 输入: [3,2,1,5,6,4] 和 k = 2输出: 5示 ...
- 007-TreeMap、Map和Bean互转、BeanUtils.copyProperties(A,B)拷贝、URL编码解码、字符串补齐,随机字母数字串
一.转换 1.1.TreeMap 有序Map 无序有序转换 使用默认构造方法: public TreeMap(Map<? extends K, ? extends V> m) 1.2.Ma ...