Dynamic Routing Between Capsules
概
虽然11年就提出了capsule的概念, 但是走入人们视线的应该还是这篇文章吧. 虽然现阶段, capsule没有体现出什么优势. 不过, capsule相较于传统的CNN融入了很多先验知识, 更能够拟合人类的视觉系统(我不知), 或许有一天它会大放异彩.
主要内容

直接从这个结构图讲起吧.
- Input: 1 x 28 x 28 的图片 经过 9 x 9的卷积核(stride=1, padding=0, out_channels=256)作用;
- 256 x 20 x 20的特征图, 经过primarycaps作用(9 x 9 的卷积核(strde=2, padding=0, out_channels=256);
- (32 x 8) x 6 x 6的特征图, 理解为32 x 6 x 6 x 8 = 1152 x 8, 即1152个胶囊, 每个胶囊由一个8D的向量表示\(u_{i}\); (这个地方要不要squash, 大部分实现都是要的.)
- 接下来digitcaps中有10个caps(对应10个类别), 1152caps和10个caps一一对应, 分别用\(i, j\)表示, 前一层的caps为后一层提供输入, 输入为
\]
可见, 应当有1152 x 10个\(W_{ij}\in \mathbb{R}^{16\times 8}\), 其中16是输出胶囊的维度. 最后10个caps的输出为
\]
其中\(c_{ij}\)是通过一个路由算法决定的, \(v_j\), 即最后的输入如此定义是出于一种直觉, 即保持原始输出(\(s\))的方向, 同时让\(v\)的长度表示一个概率(这一步称为squash).
首先初始化\(b_{ij}=0\) (这里在程序实现的时候有一个考量, 是每一次都要初始化吗, 我看大部分的实现都是如此的).

上面的Eq.3就是
c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}.
\]
另外\(\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j\)是一种cos相似度度量.
损失函数
损失函数采用的是margin loss:
L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2.
\]
\(m^+, m^-\)通常取0.9和0.1, \(\lambda\)通常取0.5.
代码
我的代码, 在sgd下可以训练(但是准确率只有98), 在adam下就死翘翘了, 所以代码肯定是有问题, 但是我实在是找不出来了, 这里有很多实现的汇总.
"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def squash(s):
temp = s.norm(dim=-1, keepdim=True)
return (temp / (1. + temp ** 2)) * s
class PrimaryCaps(nn.Module):
def __init__(
self, in_channel, out_entities,
out_dims, kernel_size, stride, padding
):
super(PrimaryCaps, self).__init__()
self.conv = nn.Conv2d(in_channel, out_entities * out_dims,
kernel_size, stride, padding)
self.out_entities = out_entities
self.out_dims = out_dims
def forward(self, inputs):
conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
return squash(outs)
class AgreeRouting(nn.Module):
def __init__(self, in_caps, out_caps, out_dims, iterations=3):
super(AgreeRouting, self).__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.out_dims = out_dims
self.iterations = iterations
@staticmethod
def softmax(inputs, dim=-1):
return F.softmax(inputs, dim=dim)
def forward(self, inputs):
# inputs N x in_caps x out_caps x out_dims
b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
for r in range(self.iterations):
c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
v = squash(s) # N x out_caps x out_dims
b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
return v
class CapsLayer(nn.Module):
def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
super(CapsLayer, self).__init__()
self.in_caps = in_caps
self.in_dims = in_dims
self.routing = routing
self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
nn.init.kaiming_uniform_(self.weights)
def forward(self, inputs):
# inputs: N x in_caps x in_dims
inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
outs = self.routing(u_pres) # N x out_caps x out_dims
return outs
class CapsNet(nn.Module):
def __init__(self):
super(CapsNet, self).__init__()
# N x 1 x 28 x 28
self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)
def forward(self, inputs):
conv_outs = F.relu(self.conv(inputs))
pri_outs = self.primarycaps(conv_outs)
outs = self.digitlayer(pri_outs)
probs = outs.norm(dim=-1)
return probs
if __name__ == "__main__":
x = torch.randn(4, 1, 28 ,28)
capsnet = CapsNet()
print(capsnet(x))
def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
# outs: N x num_classes x dim
# labels: N
temp1 = F.relu(m - logits) ** 2
temp2 = F.relu(logits + m - 1) ** 2
T = F.one_hot(labels.long(), logits.size(-1))
loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
if adverage:
loss = loss / logits.size(0)
# Another implement is using scatter_
# T = torch.zero(logits.size()).long()
# T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
return loss
Dynamic Routing Between Capsules的更多相关文章
- Hinton's paper Dynamic Routing Between Capsules 的 Tensorflow , Keras ,Pytorch实现
Tensorflow 实现 A Tensorflow implementation of CapsNet(Capsules Net) in Hinton's paper Dynamic Routing ...
- 【论文笔记】Dynamic Routing Between Capsules
Dynamic Routing Between Capsules 2018-09-16 20:18:30 Paper:https://arxiv.org/pdf/1710.09829.pdf%20 P ...
- Dynamic Routing Based On Redis
Dynamic Routing Based On Redis Ngnix技术研究系列2-基于Redis实现动态路由 上篇博文我们写了个引子: Ngnix技术研究系列1-通过应用场景看Nginx的反 ...
- dynamic routing between captual
对于人脑 决策树形式 对于CNN 层级与层级间的传递 人在识别物体的时候会进行坐标框架的设置 CNN无法识别,只能通过大量训练 胶囊 :一个神经元集合,有一个活动的向量,来表示物体的各类信息,向量的长 ...
- Paper | SkipNet: Learning Dynamic Routing in Convolutional Networks
目录 1. 概括 2. 相关工作 3. 方法细节 门限模块的结构 训练方法 4. 总结 作者对residual network进行了改进:加入了gating network,基于上一层的激活值,得到一 ...
- 总结近期CNN模型的发展(一)---- ResNet [1, 2] Wide ResNet [3] ResNeXt [4] DenseNet [5] DPNet [9] NASNet [10] SENet [11] Capsules [12]
总结近期CNN模型的发展(一) from:https://zhuanlan.zhihu.com/p/30746099 余俊 计算机视觉及深度学习 1.前言 好久没有更新专栏了,最近因为项目的原因接 ...
- 百年老图难倒谷歌AI,兔还是鸭?这是个问题
上面这张图,画的是鸭子还是兔子? 自从1892年首次出现在一本德国杂志上之后,这张图就一直持续引发争议.有些人只能看到一只兔子,有些人只能看到一只鸭子,有些人两个都能看出来. 心理学家用这张图证明了一 ...
- 浅析 Hinton 最近提出的 Capsule 计划
[原文] 浅析 Hinton 最近提出的 Capsule 计划 关于最新的 Hinton 的论文 Dynamic Routing Between Capsules,参见 https:// ...
- Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时
Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型 ...
随机推荐
- absorb
absorb 物理的absorb比较直观.被书本/知识absorb也好理解.涉及到money/time时有点抽象,但汉语也有"吸金"的说法,consume, use up.可以吸收 ...
- day13 iptables防火墙
day13 iptables防火墙 一.防火墙的概述 1.什么是防火墙 防止恶意流量访问的软件就叫做防火墙. 2.防火墙的种类 软件防火墙:firewalld.iptables 硬件防火墙:F5 fi ...
- 商业爬虫学习笔记day5
一. 发送post请求 import requests url = "" # 发送post请求 data = { } response = requests.post(url, d ...
- 【bfs】洛谷 P1443 马的遍历
题目:P1443 马的遍历 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 记录一下第一道ac的bfs,原理是利用队列queue记录下一层的所有点,然后一层一层遍历: 其中: 1.p ...
- 【swift】Xcode未响应(卡死、卡住、CPU满载、忙碌、转圈圈)
在尝试了网上的方法,依然没能解决问题,尝试如下: 1.去自己项目的路径,找到<你的项目名.xcodeproj>,点击[显示包内容],删除xcuserdata文件夹 2.去Library,把 ...
- Spring同一个类中的注解方法调用AOP失效问题总结
public interface XxxService { // a -> b void a(); void b(); } @Slf4j public class XxxServiceImpl ...
- mybatis-plus条件构造用is开头的Boolean类型字段时遇到的问题
is打头的Boolean字段导致的代码生成器与lambda构造器的冲突 https://gitee.com/baomidou/mybatis-plus/issues/I171DD?_from=gite ...
- 【Elasticsearch-Java】Java客户端搭建
Elasticsearch Java高级客户端 1. 概述 Java REST Client 有两种风格: Java Low Level REST Client :用于Elasticsearch ...
- SprignBoot整合Spring Data Elasticsearch
一.原生java整合elasticsearch的API地址 https://www.elastic.co/guide/en/elasticsearch/client/java-api/6.2/java ...
- 【Java 与数据库】How to Timeout JDBC Queries
How to Timeout JDBC Queries JDBC queries by default do not have any timeout, which means that a quer ...