[论文理解] Spatial Transformer Networks
Spatial Transformer Networks
简介
本文提出了能够学习feature仿射变换的一种结构,并且该结构不需要给其他额外的监督信息,网络自己就能学习到对预测结果有用的仿射变换。因为CNN的平移不变性等空间特征一定程度上被pooling等操作破坏了,所以,想要网络能够应对平移的object或者其他仿射变换后的object有更好的表示,就需要设计一种结构来学习这种变换,使得作用了这种变换后的feature能够能好的表示任务。
网络结构

上图中U表示输入feature map,通过spatial transformer 分支学习到transform,然后通过差值或其他sampler映射到输出feature,这样输出的feature会有一种更加健壮的表示。
spatial transform的结构由三个部分组成,下面会详细介绍。
仿射变换
仿射变换分为平移、缩放、翻转、旋转和裁剪这几种变换,其中二维的变换可以用矩阵来表示:
x' \\
y'
\end{matrix}\right) =
\left[\begin{matrix}
\theta_1 & \theta_2 & \theta_3 \\
\theta_4 & \theta_5 & \theta_6 \\
\end{matrix}\right]
\left(\begin{matrix}
x \\
y\\
1
\end{matrix}\right)
\]
其中theta对应取不同的值会对应不同的变换。所以网络同学学习到这种变换,帮助feature得到一种更加有效的表示。
Localisation Network
该部分对应与上图中的localisation net部分,目的是为了学习到上面公式中的theta参数,也就是说,这一部分的结构可以直接全连接6个theta或者使用conv结构,只要能映射到6个theta就可以了。这一部分比较简单。
Parameterised Sampling Grid
这一部分对应于上图的Grid Generator部分,这一部分的作用是建立输入图像位置到输出图像位置的映射,也就是对应于我们上面提到的仿射变换,我们在这一结构下可以通过上面学习到的参数theta来通过矩阵形式对输入进行放放射变换,注意变换的时候每个channel的变换应该是一致的。公式表示为:

我们可以通过限定theta的取值来限定网络只学习某种变换,也就是只学习一部分theta参数。
Differentiable Image Sampling
上面放射变换只是定义了变换前到变换后的位置映射,这个映射其实并不完整,这就意味着有些点是没有值的,如果要给值,就要使用插值的方法了。论文中提到了最邻近插值和双线性插值两种插值方法。
对于最邻近插值给出了这样的定义:

这样对于输出feature的第i个值,其对应的输入feature的位置取决于m和n,由krnoecker delta函数定义知,当且仅当自变量为0时输出为1.所以上式只有在m取得x方向上距离对应点最近的整数点以及n取得y方向上距离最近的整数点时有值,其值就为对应两个方向都最近的点的值。
对于双线性插值给出了这样的定义:

由上式可以知道,只有当m和n取值为对应点xy方向上距离为1以内的整数时才有值,而距离对应点最近的整数点是有四个的,比如(0.5,0.5)距离其最近的四个点分别为(0,0),(0,1),(1,1),(1,0),后面两个取值就成了距离权重,前面U取值为四个点之一的整数点的值,所以这个式子可以解释为以距离作为权重,取最近的四个点的值的加权求和。
反向传播
定义了上面的对应函数,作者证明了输出到输入是可以进行反向传播的,以双线性插值为例:

import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
from torchsummary import summary
class STN(nn.Module):
def __init__(self):
super(STN,self).__init__()
self.feature_extractor = vgg16(pretrained = False).features
self.conv = nn.Conv2d(512,256,7)
self.fc = nn.Sequential(
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,6)
)
def forward(self,x):
features = self.feature_extractor(x) # (b,c,h,w) h = w = 7 c = 512
theta = self.conv(features).view(-1,256) # b,256
theta = self.fc(theta).view(-1,2,3) # b,2
transformed = F.affine_grid(theta,x.size()) # theta (n,2,3) size (n,c,h,w) ,这一步是得到仿射变换的映射
x = F.grid_sample(x,transformed) # 这一步就是根据映射关系,去做插值,得到变换后的图像
return x
if __name__ == "__main__":
net = STN()
summary(net,(3,224,224),device = "cpu")
[Running] python -u "/media/xueaoru/DATA/ubuntu/six/STN.py"
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Conv2d-32 [-1, 256, 1, 1] 6,422,784
Linear-33 [-1, 512] 131,584
ReLU-34 [-1, 512] 0
Linear-35 [-1, 6] 3,078
================================================================
Total params: 21,272,134
Trainable params: 21,272,134
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.40
Params size (MB): 81.15
Estimated Total Size (MB): 300.13
----------------------------------------------------------------
[Done] exited with code=0 in 2.511 seconds
论文原文:https://arxiv.org/pdf/1506.02025.pdf
[论文理解] Spatial Transformer Networks的更多相关文章
- (转载)理解Spatial Transformer Networks
理解Spatial Transformer Networks 转载于:知乎-SIGAI 书的购买链接 书的勘误,优化,源代码资源 获取全文PDF请查看:理解Spatial Transformer Ne ...
- spatial transformer networks 这篇论文
大致看了看这个paper, 很novel. 我的观点: 在traditional convolutional neural netwoks 中,我们通常会depend 于 extracting fea ...
- 论文笔记:空间变换网络(Spatial Transformer Networks)
2015, NIPS Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu Google DeepMind 为什么提出( ...
- Spatial Transformer Networks(空间变换神经网络)
Reference:Spatial Transformer Networks [Google.DeepMind]Reference:[Theano源码,基于Lasagne] 闲扯:大数据不如小数据 这 ...
- 深度学习方法(十二):卷积神经网络结构变化——Spatial Transformer Networks
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.机器学习技术感兴趣的同学加入. 今天具体介绍一个Google ...
- STN(Spatial Transformer Networks)
url: https://arxiv.org/abs/1506.02025 year:2015 blog: https://kevinzakka.github.io/2017/01/10/stn-pa ...
- [论文理解]Region-Based Convolutional Networks for Accurate Object Detection and Segmentation
Region-Based Convolutional Networks for Accurate Object Detection and Segmentation 概括 这是一篇2016年的目标检测 ...
- Spatial Transformer Networks
参考:http://blog.csdn.net/xbinworld/article/details/69049680 论文的关键在理解这句话: 先通过V中坐标(xtarget,ytarget)以此找到 ...
- Graph Transformer Networks 论文分享
论文地址:https://arxiv.org/abs/1911.06455 实现代码地址:https://github.com/ seongjunyun/Graph_Transformer_Netwo ...
随机推荐
- react 基础语法使用
刚开始不久react,在菜鸟上及其他前辈网站上学习,下面开始我的自学笔记. 包括: 渲染元素 组件(函数方法定义.es6 class定义) 事件处理 条件渲染 列表 下面代码部分将不会再写html部分 ...
- python学习-输入输出
Python的输入和输出非常方便,下面详细记录一下 任何计算机程序都是为了执行一个特定的任务,有了输入,用户才能告诉计算机程序所需的信息,有了输出,程序运行后才能告诉用户任务的结果.输入是Input, ...
- Could not determine which “make” command to run. Check the “make” step in the build configuration
环境: QT5.10 VisualStudio2015 错误1: Could not determine which “make” command to run. Check the “make” s ...
- PAT Basic 1014 福尔摩斯的约会 (20 分)
大侦探福尔摩斯接到一张奇怪的字条:我们约会吧! 3485djDkxh4hhGE 2984akDfkkkkggEdsb s&hgsfdk d&Hyscvnm.大侦探很快就明白了,字条上奇 ...
- Python内部执行过程
一.编译过程概述 当我们执行Python代码的时候,在Python解释器用四个过程“拆解”我们的代码,最终被CPU执行返回给用户. 首先当用户键入代码交给Python处理的时候会先进行词法分析,例如用 ...
- Windows10永久激活的工具
最近发现一个很好用的Windows10 永久激活的工具,比KMS什么的管用,而且无毒无公害.几乎支持所有的win10版本.感兴趣的朋友可以试试.之前win10没洗白的同学,也试试吧,说不定就洗白了呢. ...
- office visio
画 流程图软件 UML 是否要用做类图.时序图?????
- Centos 7搭建Gitlab服务器以及操作(创建项目,创建群组,创建用户,添加密钥)
一. 安装并配置依赖包 在CentOS系统上安装所需的依赖:ssh,防火墙,postfix(用于邮件通知) ,wget,以下这些命令也会打开 系统防火墙中的HTTP和SSH端口访问 安装前准备 命令: ...
- SELinux 对nginx访问目录的影响
centos新装的系统,用yum 安装的nginx . 因用yum 安装的nginx 默认目录在/usr下面. 当SELinux开启时,将会禁止访问设置在其他路径下的地址.比如我设置server 中 ...
- nginx反向代理原理及配置详解
nginx概述nginx是一款自由的.开源的.高性能的HTTP服务器和反向代理服务器:同时也是一个IMAP.POP3.SMTP代理服务器:nginx可以作为一个HTTP服务器进行网站的发布处理,另外n ...