【语义分割】Stacked Hourglass Networks 以及 PyTorch 实现
Stacked Hourglass Networks(级联漏斗网络)
姿态估计(Pose Estimation)是 CV 领域一个非常重要的方向,而级联漏斗网络的提出就是为了提升姿态估计的效果,但是其中的经典思想可以扩展到其他方向,比如目标识别方向,代表网络是 CornerNet(预测目标的左上角和右下角点,再进行组合画框)。
CNN 之所以有效,是因为它能自动提取出对分类、检测和识别等任务有帮助的特征,并且随着网络层数的增加,所提取的特征逐渐变得抽象。以人脸识别为例,低层卷积网络能够提取出一些简单的特征,如轮廓;中间卷积网络能够提取出抽象一些的特征,如眼睛鼻子;较高层的卷积网络则能提取出更加抽象的特征,比如完整的人脸。这些将有助于我们理解级联漏斗模型(Stacked Hourglass Model,简称SHM)为什么有效。
做姿态估计,需要预测身体不同的关节点,手臂这种线条简单的结构,可能在中间卷积网络更容易被识别;而面部这种线条复杂的结构,可能在高层卷积网络才更容易被识别。因此,如果我们只使用最后一层的 feature map,就会造成一些信息的丢失。SHN 的主要贡献——利用多尺度特征来识别姿态。
Single Hourglass Network

上图是单个漏斗网络的结构。该结构与全卷积网络和其它设计(以多尺度方式处理空间信息,并进行密集预测)紧密相连。然而漏斗网络与其它设计有什么不同呢?由图可以看出,其自底向上(从高分辨率到低分辨率)处理和自顶向下(从低分辨率到高分辨率)处理之间的容量分布(这里实在不知道怎么翻译。。。)更加对称。另外还有一点需要注意,在自顶向下处理过程中,使用的不是 unpooing(一种常见的上采样操作)或者 deconv layers(可称为去卷积层),而是采用nearest neighbor upsampling(最近邻上采样)和 skip connections。这些操作需要在源码中理解。
Stacked Hourglass Networks

上图是单个漏斗网络后面的一些设计以及两个漏斗网络的连接细节。
块1 是上面介绍的单个沙漏网络,在它后面是一个 1$\times\(1 的全卷积网络,即块2;块2 后面分离出上下两个分支(块3 和块4):上分支(块3)依然是一个 1\)\times$1 的全卷积网络,下分支(块4)为 Heat map(下面重点介绍)。块5 是对块4 进行 channal 上的扩增,以方便块3、块5 和 上个漏斗网络的输出进行合并,一起作为当前漏斗网络的输出,同时是下一个漏斗网络的输入。
这里对 Heat map 进行解释:大部分姿态检测的最后一步是对 feature map 上的每个像素做概率预测,计算该像素是某个关节点的概率,而这里的 feature map 就是上面输出的 Heat map。使用它与真值进行误差计算。应用中,如果多个 Hourglass Module 组合在一起进行梯度下降,输出层的误差经过多层反向传播会大幅减小,也就是发生了梯度消失。因此,在整个网络中每个Hourglass Module 后面都会输出 Heat map 来计算损失。这种方法称为 中间监督(Intermediate Supervision),可以保证底层参数正常更新。
之所以使用多个 Stack Hourglass,是为了重复自下而上和自上而下的推理机制,允许重新评估整个图像的初始估计和特征,实现这一过程的核心就是预测中间的 Heat map,并让中间 Heat map 参与 loss 计算。
PyTorch 实现 Model
首先定义残差网络的基本模块:

import torch.nn as nn class HgResBlock(nn.Module): def __init__(self, inplanes, outplanes, stride=1):
super(HgResBlock, self).__init__() self.inplanes = inplanes
self.outplanes = outplanes
midplanes = outplanes // 2 self.bn_1 = nn.BatchNorm2d(inplanes)
self.conv_1 = nn.Conv2d(inplanes, midplanes, kernel_size=1, stride=stride)
self.bn_2 = nn.BatchNorm2d(midplanes)
self.conv_2 = nn.Conv2d(midplanes, midplanes, kernel_size=3, stride=1, padding=1)
self.bn_3 = nn.BatchNorm2d(midplanes)
self.conv_3 = nn.Conv2d(midplanes, outplanes, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplanes=True)
if inplanes != outplanes:
self.conv_skip = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1) # Bottle neck
def forward(self, x):
residual = x out = self.bn_1(x)
out = self.conv_1(out)
out = self.relu(out) out = self.bn_2(out)
out = self.conv_2(out)
out = self.relu(out) out = self.bn_3(out)
out = self.conv_3(out)
out = self.relu(out) if self.inplanes != self.outplanes:
residual = self.conv_skip(residual)
out += residual return out
定义单个的 Hourglass Module(注意这里用到了递归):

import torch.nn as nn class Hourglass(nn.Module): def __init__(self, depth, nFeat, nModules, resBlocks):
super(Hourglass, self).__init__() self.depth = depth
self.nFeat = nFeat
self.nModules = nModules
self.resBlocks = resBlocks self.hg = self._make_hourglass()
self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest') def _make_residual(self, n):
return nn.Sequential(*[self.resBlocks(self.nFeat, self.nFeat) for _ in range(n)]) def _make_hourglass(self):
hg = [] for i in range(self.depth):
res = [self._make_residual(self.nModules) for _ in range(3)]
if i == (self.depth - 1):
res.append(self._make_residual(self.nModules)) # extra one for the middle
hg.append(nn.ModuleList(res)) return nn.ModuleList(hg) def _hourglass_forward(self, depth_id, x):
up_1 = self.hg[depth_id][0](x)
low_1 = self.downsample(x)
low_1 = self.hg[depth_id][1](low_1) if depth_id == (self.depth - 1):
low_2 = self.hg[depth_id][3](low_1)
else:
low_2 = self._hourglass_forward(depth_id+1, low_1) low_3 = self.hg[depth_id][2](low_2)
up_2 = self.upsample(low_3) return up_1 + up_2 def forward(self, x):
return self._hourglass_forward(0, x)
定义 Stacked Hourglass Network:

import torch.nn as nn from Model.HgResBlock import HgResBlock
from Model.SingleHourglass import Hourglass class HourglassNet(nn.Module): def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3):
super(HourglassNet, self).__init__() self.nStacks = nStacks
self.nModules = nModules
self.nFeat = nFeat
self.nClasses = nClasses
self.resBlock = resBlock
self.inplanes = inplanes hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] for i in range(nStacks):
hg.append(Hourglass(depth=4, nFeat=nFeat, nModules=nModules, resBlocks=resBlock))
res.append(self._make_residual(nModules))
fc.append(self._make_fc(nFeat, nFeat))
score.append(nn.Conv2d(nFeat, nClasses, kernel_size=1))
if i < (nStacks - 1):
fc_.append(nn.Conv2d(nFeat, nFeat, kernel_size=1))
score_.append(nn.Conv2d(nClasses, nFeat, kernel_size=1)) self.hg = nn.ModuleList(hg)
self.res = nn.ModuleList(res)
self.fc = nn.ModuleList(fc)
self.score = nn.ModuleList(score)
self.fc_ = nn.ModuleList(fc_)
self.score_ = nn.ModuleList(score_) def _make_head(self):
self.conv_1 = nn.Conv2d(self.inplanes, 64, kernel_size=7, stride=2, padding=3)
self.bn_1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.res_1 = self.resBlock(64, 128)
self.pool = nn.MaxPool2d(2, 2)
self.res_2 = self.resBlock(128, 128)
self.res_3 = self.resBlock(128, self.nFeat) def _make_residual(self, n):
return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)]) def _make_fc(self, inplanes, outplanes):
return nn.Sequential(
nn.Conv2d(inplanes, outplanes, kernel_size=1),
nn.BatchNorm2d(outplanes),
nn.ReLU(True)) def forward(self, x):
# head
x = self.conv_1(x)
x = self.bn_1(x)
x = self.relu(x) x = self.res_1(x)
x = self.pool(x)
x = self.res_2(x)
x = self.res_3(x) out = [] for i in range(self.nStacks):
y = self.hg[i](x)
y = self.res[i](y)
y = self.fc[i](y)
score = self.score[i](y)
out.append(score)
if i < (self.nStacks - 1):
fc_ = self.fc_[i](y)
score_ = self.score_[i](score)
x = x + fc_ + score_ return out
References:
[1] Stacked Hourglass Networks for Human Pose Estimation
[2] [hourglass pytorch 实现]
(https://blog.csdn.net/github_36923418/article/details/81030883)
【语义分割】Stacked Hourglass Networks 以及 PyTorch 实现的更多相关文章
- 论文阅读理解 - Stacked Hourglass Networks for Human Pose Estimation
http://blog.csdn.net/zziahgf/article/details/72732220 keywords 人体姿态估计 Human Pose Estimation 给定单张RGB图 ...
- 论文笔记 Stacked Hourglass Networks for Human Pose Estimation
Stacked Hourglass Networks for Human Pose Estimation key words:人体姿态估计 Human Pose Estimation 给定单张RGB ...
- PyTorch中的MIT ADE20K数据集的语义分割
PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...
- 使用LabVIEW实现基于pytorch的DeepLabv3图像语义分割
前言 今天我们一起来看一下如何使用LabVIEW实现语义分割. 一.什么是语义分割 图像语义分割(semantic segmentation),从字面意思上理解就是让计算机根据图像的语义来进行分割,例 ...
- 【Semantic segmentation Overview】一文概览主要语义分割网络(转)
文章来源:https://www.tinymind.cn/articles/410 本文来自 CSDN 网站,译者蓝三金 图像的语义分割是将输入图像中的每个像素分配一个语义类别,以得到像素化的密集分类 ...
- 语义分割丨PSPNet源码解析「训练阶段」
引言 之前一段时间在参与语义分割的项目,最近有时间了,正好把这段时间的所学总结一下. 在代码上,语义分割的框架会比目标检测简单很多,但其中也涉及了很多细节.在这篇文章中,我以PSPNet为例,解读一下 ...
- caffe初步实践---------使用训练好的模型完成语义分割任务
caffe刚刚安装配置结束,乘热打铁! (一)环境准备 前面我有两篇文章写到caffe的搭建,第一篇cpu only ,第二篇是在服务器上搭建的,其中第二篇因为硬件环境更佳我们的步骤稍显复杂.其实,第 ...
- 笔记︱图像语义分割(FCN、CRF、MRF)、论文延伸(Pixel Objectness、)
图像语义分割的意思就是机器自动分割并识别出图像中的内容,我的理解是抠图- 之前在Faster R-CNN中借用了RPN(region proposal network)选择候选框,但是仅仅是候选框,那 ...
- 笔记:基于DCNN的图像语义分割综述
写在前面:一篇魏云超博士的综述论文,完整题目为<基于DCNN的图像语义分割综述>,在这里选择性摘抄和理解,以加深自己印象,同时达到对近年来图像语义分割历史学习和了解的目的,博古才能通今!感 ...
随机推荐
- 第K个语法符号
在第一行我们写上一个 0.接下来的每一行,将前一行中的0替换为01,1替换为10. 给定行数 N 和序数 K,返回第 N 行中第 K个字符.(K从1开始) 例子: 输入: N = 1, K = 1输出 ...
- 剑指Offer-37.二叉树的深度(C++/Java)
题目: 输入一棵二叉树,求该树的深度.从根结点到叶结点依次经过的结点(含根.叶结点)形成树的一条路径,最长路径的长度为树的深度. 分析: 递归求解左右子树的最大值即可,每遍历到一个结点,深度加1,最后 ...
- Python 变量与运算符
变量 基本概念: 1. 变量,名字,数据的唯一标识2.变量命名: 字母.数字.下划线: 不能以数字开头: 区分大小写: 不能使用保留字和关键字: 命名要有意义:(多个单词时,推荐使用下划线连接) 3. ...
- 最强Linux shell工具Oh My Zsh 指南
引言 笔者已经使用zsh一年多了,发现这个东东的功能太强大了.接下来,给大家推荐一下. 以下是oh-my-zsh部分功能 命令验证 在所有正在运行的shell中共享命令历史记录 拼写纠正 主题提示(A ...
- IT兄弟连 HTML5教程 多媒体应用 小结及习题
小结 在互联网上,图像和链接则是通过URL唯一确定信息资源的位置.URL分为绝对URL和相对URL.通过使用<img />标记在浏览器中显示一张图像.超文本具有的链接能力,可层层链接相关文 ...
- easyui treegrid数据重复加载问题
在使用easyui的时候,出现了数据重复加载的问题.如下图 关于这个问题有两种说法,第一种说法是 easyui-datagrid 类在html和js中重复定义,数据渲染时会加载两次.另一种是$(&qu ...
- Linux中,Tomcat 怎么承载高并发(深入Tcp参数 backlog)
一.前言 这两天看tomcat,查阅 tomcat 怎么承载高并发时,看到了backlog参数.我们知道,服务器端一般使用mq来减轻高并发下的洪峰冲击,将暂时不能处理的请求放入队列,后续再慢慢处理.其 ...
- (六十一)c#Winform自定义控件-信号灯(工业)-HZHControls
官网 http://www.hzhcontrols.com 前提 入行已经7,8年了,一直想做一套漂亮点的自定义控件,于是就有了本系列文章. GitHub:https://github.com/kww ...
- (转)深入解析TensorFlow中滑动平均模型与代码实现
本文链接:https://blog.csdn.net/m0_38106113/article/details/81542863 指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动 ...
- 并发容器之ConcurrentLinkedQueue
本人免费整理了Java高级资料,涵盖了Java.Redis.MongoDB.MySQL.Zookeeper.Spring Cloud.Dubbo高并发分布式等教程,一共30G,需要自己领取.传送门:h ...