(Original) Stacked hourglass networks
Please credit the source when reposting:
https://www.cnblogs.com/darkknightzh/p/11486185.html
Paper:
https://arxiv.org/abs/1603.06937
Official Torch code (I have not read it in detail):
https://github.com/princeton-vl/pose-hg-demo
Third-party PyTorch code (in models/StackedHourGlass.py):
https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation
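1. Introduction
The paper uses multi-scale features to estimate human pose. As shown in the figure below, each sub-network is called an hourglass network because of its hourglass shape, and several of them stacked together form the stacked hourglass. Stacking lets every module re-estimate the pose and the features over the whole image. The input image first passes through a fully convolutional network (FCN) to produce features, which then pass through the stacked hourglasses to produce the final heatmaps.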

A single hourglass is shown in the figure below. Each box in it is a residual module, which is shown in the figure after that.


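The hourglass network uses intermediate supervision. Every hourglass produces its own heatmaps (blue in the figure). During training, the MSE loss between each of these heatmaps and the ground-truth heatmaps is computed, and the per-hourglass losses are summed to give the total loss. At inference, only the heatmaps from the last hourglass are used.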
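The summed loss described above can be sketched as follows. This is only a minimal illustration and not the training code of the referenced repository: the function name `intermediate_supervision_loss`, the random tensors and their shapes are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def intermediate_supervision_loss(stack_outputs, target):
    """Sum the MSE between each hourglass's heatmaps and the ground truth.

    stack_outputs: list with one tensor per hourglass, each (N, nJoints, H, W)
    target:        ground-truth heatmaps, (N, nJoints, H, W)
    """
    return sum(F.mse_loss(out, target) for out in stack_outputs)


# Hypothetical shapes: batch of 2, 17 joints, 64x64 heatmaps, 2 stacks.
torch.manual_seed(0)
target = torch.rand(2, 17, 64, 64)
stack_outputs = [torch.rand(2, 17, 64, 64, requires_grad=True) for _ in range(2)]

loss = intermediate_supervision_loss(stack_outputs, target)
loss.backward()  # every hourglass output receives a gradient

# At inference only the last hourglass's heatmaps are used.
final_heatmaps = stack_outputs[-1]
```

Because the per-stack losses are simply added, each hourglass is pushed toward the ground truth directly instead of only through the gradient of the final output.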

2. Stacked hourglass
The stacked hourglass structure is shown in the figure below (nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17):

The code is as follows:
# Assumes `import torch.nn as nn`, and that `M` is the module that defines
# BnReluConv and Residual (see sections 5 to 7).
class StackedHourGlass(nn.Module):
    """Stem followed by nStack hourglasses, each producing nJoints heatmaps."""
    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv
        self.res1 = M.Residual(64, 128)  # in != out channels: 1*1 conv of the input + 3*(BN+ReLU+conv)
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.Residual(128, 128)  # in == out channels: x + 3*(BN+ReLU+conv)
        # x + 3*(BN+ReLU+conv) if nChannels == 128; otherwise 1*1 conv of the input + 3*(BN+ReLU+conv)
        self.res3 = M.Residual(128, self.nChannels)

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]

        for _ in range(self.nStack):  # number of stacked hourglasses
            _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            _ResidualModules = []
            for _ in range(self.nModules):
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))  # in == out channels: x + 3*(BN+ReLU+conv)
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)  # self.nModules residual modules in sequence
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))  # BN+ReLU+conv
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))  # 1*1 conv, features -> heatmaps
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))  # 1*1 conv, channel count unchanged
            _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))  # 1*1 conv, heatmaps -> features

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))  # heatmaps of stack i
            x1 = self.lin2[i](x1)
            x = x + x1 + self.jointstochan[i](out[i])  # sum the features for the next stack

        return (out)  # list of nStack heatmap tensors
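To see what resolution the hourglasses work at, the spatial size can be traced through the stem by hand with the usual convolution output-size formula. The sketch below assumes a 256x256 input, which is only an illustrative value since the code above does not fix the input size, and `conv_out` is a helper name introduced here.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a conv or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


size = 256  # assumed input height and width
after_start = conv_out(size, kernel=7, stride=2, padding=3)       # self.start
after_pool = conv_out(after_start, kernel=2, stride=2, padding=0)  # self.mp

# res1, res2 and res3 only use 1x1 convs and padded 3x3 convs with stride 1,
# so they change the channel count but not the spatial size.
print(after_start, after_pool)  # 128 64
```

Under this assumption every hourglass receives a 64x64 feature map, and since each hourglass returns the same resolution it was given, every entry of `out` is a 64x64 map with nJoints channels.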
3. Hourglass
When numReductions>1, the hourglass calls itself recursively. Its structure is as follows:

The code is as follows:
class Hourglass(nn.Module):
    """One hourglass; nests a smaller Hourglass while numReductions > 1."""
    def __init__(self, nChannels = 256, numReductions = 4, nModules = 2, poolKernel = (2,2), poolStride = (2,2), upSampleKernel = 2):
        super(Hourglass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        # For the skip connection, a residual module (or sequence of residual modules)
        _skip = []
        for _ in range(self.nModules):
            _skip.append(M.Residual(self.nChannels, self.nChannels))  # in == out channels: x + 3*(BN+ReLU+conv)
        self.skip = nn.Sequential(*_skip)

        # First pool to a smaller resolution, then pass the input through a
        # residual module (or sequence of modules). After that, either pass
        # through an Hourglass with numReductions-1 or through another
        # residual module (or sequence of modules).
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)

        _afterpool = []
        for _ in range(self.nModules):
            _afterpool.append(M.Residual(self.nChannels, self.nChannels))  # in == out channels: x + 3*(BN+ReLU+conv)
        self.afterpool = nn.Sequential(*_afterpool)

        if (numReductions > 1):
            self.hg = Hourglass(self.nChannels, self.numReductions-1, self.nModules, self.poolKernel, self.poolStride)  # recursive nesting
        else:
            _num1res = []
            for _ in range(self.nModules):
                _num1res.append(M.Residual(self.nChannels,self.nChannels))  # in == out channels: x + 3*(BN+ReLU+conv)
            self.num1res = nn.Sequential(*_num1res)  # doesn't seem that important?

        # Now another residual module or sequence of residual modules
        _lowres = []
        for _ in range(self.nModules):
            _lowres.append(M.Residual(self.nChannels,self.nChannels))  # in == out channels: x + 3*(BN+ReLU+conv)
        self.lowres = nn.Sequential(*_lowres)

        # Upsampling layer (can we change this?). Newell's paper recommends upsampling.
        self.up = myUpsample()#nn.Upsample(scale_factor = self.upSampleKernel)  # doubles height and width

    def forward(self, x):
        out1 = x
        out1 = self.skip(out1)  # skip branch at the input resolution
        out2 = x
        out2 = self.mp(out2)  # downsample: halve height and width
        out2 = self.afterpool(out2)  # residual modules after pooling
        if self.numReductions>1:
            out2 = self.hg(out2)  # recursive call into the nested hourglass
        else:
            out2 = self.num1res(out2)  # innermost level: residual modules only
        out2 = self.lowres(out2)  # residual modules before upsampling
        out2 = self.up(out2)  # upsample: double height and width

        return out2 + out1  # sum the two branches
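The recursion can be unrolled by hand to see which resolutions the pooled branch visits and how many residual modules one hourglass contains. The sketch below is a bookkeeping helper only: `hourglass_summary` is a name introduced here, and the 64x64 input size is an assumption for illustration.

```python
def hourglass_summary(num_reductions, n_modules, size):
    """Return (sizes seen by the pooled branch, number of Residual modules)."""
    sizes = []
    residuals = 0
    for level in range(num_reductions, 0, -1):
        size //= 2                    # self.mp halves height and width
        sizes.append(size)
        residuals += 3 * n_modules    # skip + afterpool + lowres
        if level == 1:
            residuals += n_modules    # num1res exists only at the innermost level
    return sizes, residuals


sizes, residuals = hourglass_summary(num_reductions=4, n_modules=2, size=64)
print(sizes)      # [32, 16, 8, 4]
print(residuals)  # 26
```

With numReductions=4 the feature map is halved four times before being upsampled back step by step, and each upsampled result is added to the skip branch of the same resolution.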
4. Upsampling: myUpsample
The upsampling code is as follows:
class myUpsample(nn.Module):
    def __init__(self):
        super(myUpsample, self).__init__()

    def forward(self, x):  # doubles height and width by replicating each pixel
        return x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(x.size(0), x.size(1), x.size(2)*2, x.size(3)*2)
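Here x is a tensor of shape (N, C, H, W). x[:, :, :, None, :, None] has shape (N, C, H, 1, W, 1), expand turns it into (N, C, H, 2, W, 2), and the final reshape gives (N, C, 2H, 2W). Each pixel is therefore duplicated twice horizontally and twice vertically, becoming 4 pixels with the same value, which completes the upsampling.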
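The shape bookkeeping above can be checked on a tiny tensor. The sketch below repeats the indexing expression from `myUpsample.forward` and compares it with two built-in ways of replicating pixels; the input values are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# (N, C, H, W) = (2, 3, 2, 2), filled with 0, 1, 2, ...
x = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).reshape(2, 3, 2, 2)

# The indexing trick used in myUpsample.forward.
up = x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(
    x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2)

# Two reference implementations of the same pixel replication.
ref_repeat = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
ref_interp = F.interpolate(x, scale_factor=2, mode="nearest")

print(up.shape)                     # torch.Size([2, 3, 4, 4])
print(torch.equal(up, ref_repeat))  # True
print(torch.equal(up, ref_interp))  # True
print(up[0, 0])                     # rows: [0,0,1,1], [0,0,1,1], [2,2,3,3], [2,2,3,3]
```

In other words, the module performs nearest-neighbour upsampling by a factor of 2.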
5. Residual module
The residual module has the following structure:

The code is as follows:
class Residual(nn.Module):
    """Residual module.

    If in and out channels are equal: x + 3*(BN+ReLU+conv).
    Otherwise: 1*1 conv of the input + 3*(BN+ReLU+conv).
    """
    def __init__(self, inChannels, outChannels):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        # 3 * (BN+ReLU+conv): the first reduces the channels, the second keeps them, the third expands them
        self.cb = ConvBlock(inChannels, outChannels)
        # identity if in and out channels are equal, otherwise a 1*1 conv
        self.skip = SkipLayer(inChannels, outChannels)

    def forward(self, x):
        out = 0
        out = out + self.cb(x)
        out = out + self.skip(x)
        return out
The SkipLayer code is as follows:
class SkipLayer(nn.Module):
    """Identity if in and out channels are equal, otherwise a 1*1 conv."""
    def __init__(self, inChannels, outChannels):
        super(SkipLayer, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        if (self.inChannels == self.outChannels):
            self.conv = None
        else:
            self.conv = nn.Conv2d(self.inChannels, self.outChannels, 1)

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return x
6. Conv (BnReluConv)
class BnReluConv(nn.Module):
    """BN + ReLU + conv, applied in that order."""
    def __init__(self, inChannels, outChannels, kernelSize = 1, stride = 1, padding = 0):
        super(BnReluConv, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.kernelSize = kernelSize
        self.stride = stride
        self.padding = padding

        self.bn = nn.BatchNorm2d(self.inChannels)
        self.conv = nn.Conv2d(self.inChannels, self.outChannels, self.kernelSize, self.stride, self.padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        return x
7. ConvBlock
class ConvBlock(nn.Module):
    """3 * (BN+ReLU+conv): the first reduces the channels, the second keeps them, the third expands them."""
    def __init__(self, inChannels, outChannels):
        super(ConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.outChannelsby2 = outChannels//2

        self.cbr1 = BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0)  # 1*1 conv, reduce to outChannels//2
        self.cbr2 = BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)  # 3*3 conv, channels unchanged
        self.cbr3 = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)  # 1*1 conv, expand to outChannels

    def forward(self, x):
        x = self.cbr1(x)
        x = self.cbr2(x)
        x = self.cbr3(x)
        return x
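As a rough check of what the bottleneck layout costs, the learnable parameters of one residual module can be counted by hand. The sketch below follows the layer definitions above (BatchNorm2d with weight and bias, Conv2d with bias); the helper names `bn_relu_conv_params` and `residual_params` are introduced here for illustration.

```python
def bn_relu_conv_params(c_in, c_out, kernel):
    """Learnable parameters in one BN+ReLU+conv group."""
    bn = 2 * c_in                                  # BatchNorm2d weight and bias
    conv = c_in * c_out * kernel * kernel + c_out  # Conv2d weight and bias
    return bn + conv


def residual_params(c_in, c_out):
    """Learnable parameters in Residual(c_in, c_out) as defined above."""
    half = c_out // 2
    total = bn_relu_conv_params(c_in, half, 1)     # cbr1: 1x1, reduce channels
    total += bn_relu_conv_params(half, half, 3)    # cbr2: 3x3, keep channels
    total += bn_relu_conv_params(half, c_out, 1)   # cbr3: 1x1, expand channels
    if c_in != c_out:
        total += c_in * c_out + c_out              # SkipLayer's 1x1 conv
    return total


print(residual_params(256, 256))  # 214528
print(residual_params(64, 128))   # 58112
```

Most of the parameters sit in the 3x3 conv, which is why it runs on outChannels//2 channels instead of the full width.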