[开发技巧]·AdaptivePooling与Max/AvgPooling相互转换

个人网站--> http://www.yansongsong.cn/

1.问题描述

自适应池化Adaptive Pooling是PyTorch的一种池化层,根据1D,2D,3D以及Max与Avg可分为六种形式。

自适应池化Adaptive Pooling与标准的Max/AvgPooling区别在于,自适应池化Adaptive Pooling会根据输入的参数来控制输出output_size,而标准的Max/AvgPooling是通过kernel_size,stride与padding来计算output_size:

                        output_size = ceil ( (input_size+2∗padding−kernel_size)/stride)+1

Adaptive Pooling仅存在与PyTorch,如果需要将包含Adaptive Pooling的代码移植到Keras或者TensorFlow就会遇到问题。

本文将提供一个公式,可以简便的将AdaptivePooling准换为Max/AvgPooling,便于大家移植使用。

2.原理讲解

我们已经知道了普通Max/AvgPooling计算公式为:output_size = ceil ( (input_size+2∗padding−kernel_size)/stride)+1 

当我们使用Adaptive Pooling时,这个问题就变成了由已知量input_size,output_size求解kernel_size与stride

为了简化问题,我们将padding设为0(后面我们可以发现源码里也是这样操作的c++源码部分

stride = floor ( (input_size / (output_size) )

kernel_size = input_size − (output_size−1) * stride

3.实战演示

下面我们通过一个实战来操作一下,验证公式的正确性

import torch as t
import math
import numpy as np alist = t.randn(2,6,7) inputsz = np.array(alist.shape[1:])
outputsz = np.array([2,3]) stridesz = np.floor(inputsz/outputsz).astype(np.int32) kernelsz = inputsz-(outputsz-1)*stridesz adp = t.nn.AdaptiveAvgPool2d(list(outputsz))
avg = t.nn.AvgPool2d(kernel_size=list(kernelsz),stride=list(stridesz))
adplist = adp(alist)
avglist = avg(alist) print(alist)
print(adplist)
print(avglist)

输出结果

tensor([[[ 0.9095,  0.8043,  0.4052,  0.3410,  1.8831,  0.8703, -0.0839],
[ 0.3300, -1.2951, -1.8148, -1.1118, -1.1091, 1.5657, 0.7093],
[-0.6788, -1.2790, -0.6456, 1.9085, 0.8627, 1.1711, 0.5614],
[-0.0129, -0.6447, -0.6685, -1.2087, 0.8535, -1.4802, 0.5274],
[ 0.7347, 0.0374, -1.7286, -0.7225, -0.4257, -0.0819, -0.9878],
[-1.2553, -1.0774, -0.1936, -1.4741, -0.9028, -0.1584, -0.6612]], [[-0.3473, 1.0599, -1.5744, -0.2023, -0.5336, 0.5512, -0.3200],
[-0.2518, 0.1714, 0.6862, 0.3334, -1.2693, -1.3348, -0.0878],
[ 1.0515, 0.1385, 0.4050, 0.8554, 1.0170, -2.6985, 0.3586],
[-0.1977, 0.8298, 1.6110, -0.9102, 0.7129, 0.2088, 0.9553],
[-0.2218, -0.7234, -0.4407, 1.0369, -0.8884, 0.3684, 1.2134],
[ 0.5812, 1.1974, -0.1584, -0.0903, -0.0628, 3.3684, 2.0330]]]) tensor([[[-0.3627, 0.0799, 0.7145],
[-0.5343, -0.7190, -0.3686]], [[ 0.1488, -0.0314, -0.4797],
[ 0.2753, 0.0900, 0.8788]]]) tensor([[[-0.3627, 0.0799, 0.7145],
[-0.5343, -0.7190, -0.3686]], [[ 0.1488, -0.0314, -0.4797],
[ 0.2753, 0.0900, 0.8788]]])

可以发现adp = t.nn.AdaptiveAvgPool2d(list(outputsz))与avg = t.nn.AvgPool2d(kernel_size=list(kernelsz),stride=list(stridesz))结果一致

为了防止这是偶然现象,修改参数,使用AdaptiveAvgPool1d进行试验

import torch as t
import math
import numpy as np alist = t.randn(2,3,9) inputsz = np.array(alist.shape[2:])
outputsz = np.array([4]) stridesz = np.floor(inputsz/outputsz).astype(np.int32) kernelsz = inputsz-(outputsz-1)*stridesz adp = t.nn.AdaptiveAvgPool1d(list(outputsz))
avg = t.nn.AvgPool1d(kernel_size=list(kernelsz),stride=list(stridesz))
adplist = adp(alist)
avglist = avg(alist) print(alist)
print(adplist)
print(avglist)

  

输出结果

tensor([[[ 1.3405,  0.3509, -1.5119, -0.1730,  0.6971,  0.3399, -0.0874,
-1.2417, 0.6564],
[ 2.0482, 0.3528, 0.0703, 1.2012, -0.8829, -0.3156, 1.0603,
-0.7722, -0.6086],
[ 1.0470, -0.9374, 0.3594, -0.8068, 0.5126, 1.4135, 0.3538,
-1.0973, 0.3046]], [[-0.1688, 0.7300, -0.3457, 0.5645, -1.2507, -1.9724, 0.4469,
-0.3362, 0.7910],
[ 0.5676, -0.0614, -0.0243, 0.1529, 0.8276, 0.2452, -0.1783,
0.7460, 0.2577],
[-0.1433, -0.7047, -0.4883, 1.2414, -1.4316, 0.9704, -1.7088,
-0.0094, -0.3739]]]) tensor([[[ 0.0598, -0.3293, 0.3165, -0.2242],
[ 0.8237, 0.1295, -0.0461, -0.1069],
[ 0.1563, 0.0217, 0.7600, -0.1463]], [[ 0.0718, -0.3440, -0.9254, 0.3006],
[ 0.1606, 0.3187, 0.2982, 0.2751],
[-0.4454, -0.2262, -0.7233, -0.6973]]]) tensor([[[ 0.0598, -0.3293, 0.3165, -0.2242],
[ 0.8237, 0.1295, -0.0461, -0.1069],
[ 0.1563, 0.0217, 0.7600, -0.1463]], [[ 0.0718, -0.3440, -0.9254, 0.3006],
[ 0.1606, 0.3187, 0.2982, 0.2751],
[-0.4454, -0.2262, -0.7233, -0.6973]]])

可以发现adp = t.nn.AdaptiveAvgPool1d(list(outputsz))与avg = t.nn.AvgPool1d(kernel_size=list(kernelsz),stride=list(stridesz))结果也是相同的。

4.总结分析

在以后遇到别人代码使用Adaptive Pooling,可以通过这两个公式转换为标准的Max/AvgPooling,从而应用到不同的学习框架中

stride = floor ( (input_size / (output_size) )

kernel_size = input_size − (output_size−1) * stride

只需要知道输入的input_size ,就可以推导出stride 与kernel_size ,从而替换为标准的Max/AvgPooling

Hope this helps

[开发技巧]·AdaptivePooling与Max/AvgPooling相互转换的更多相关文章

  1. iOS开发技巧系列---详解KVC(我告诉你KVC的一切)

    KVC(Key-value coding)键值编码,单看这个名字可能不太好理解.其实翻译一下就很简单了,就是指iOS的开发中,可以允许开发者通过Key名直接访问对象的属性,或者给对象的属性赋值.而不需 ...

  2. Unity 游戏开发技巧集锦之制作一个望远镜与查看器摄像机

    Unity 游戏开发技巧集锦之制作一个望远镜与查看器摄像机 Unity中制作一个望远镜 本节制作的望远镜,在鼠标左键按下时,看到的视图会变大:当不再按下的时候,会慢慢缩小成原来的视图.游戏中时常出现的 ...

  3. [开发技巧]·Python极简实现滑动平均滤波(基于Numpy.convolve)

    [开发技巧]·Python极简实现滑动平均滤波(基于Numpy.convolve) ​ 1.滑动平均概念 滑动平均滤波法(又称递推平均滤波法),时把连续取N个采样值看成一个队列 ,队列的长度固定为N ...

  4. Mysql - 开发技巧(二)

    本文中的涉及到的表在https://github.com/YangBaohust/my_sql中 本文衔接Mysql - 巧用join来优化sql(https://www.cnblogs.com/dd ...

  5. javascript的10个开发技巧

    总结10个提高开发效率的JavaScript开发技巧. 1.生成随机的uid. const genUid = () => { var length = 20; var soupLength = ...

  6. SQL开发技巧(二)

    本系列文章旨在收集在开发过程中遇到的一些常用的SQL语句,然后整理归档,本系列文章基于SQLServer系列,且版本为SQLServer2005及以上-- 文章系列目录 SQL开发技巧(一) SQL开 ...

  7. DelphiXE2 DataSnap开发技巧收集

    DelphiXE2 DataSnap开发技巧收集 作者:  2012-08-07 09:12:52     分类:Delphi     标签: 作为DelphiXE2 DataSnap开发的私家锦囊, ...

  8. delphi XE5下安卓开发技巧

    delphi XE5下安卓开发技巧 一.手机快捷方式显示中文名称 project->options->Version Info-label(改成需要显示的中文名即可),但是需要安装到安卓手 ...

  9. 经典收藏 50个jQuery Mobile开发技巧集萃

    http://www.cnblogs.com/chu888chu888/archive/2011/11/10/2244181.html 1.Backbone移动实例 这是在Safari中运行的一款Ba ...

随机推荐

  1. Oracle的nvl

    在Oracle中nvl(字段名,value)函数用于对没有值的字段做处理在MySql中ifnull(字段名,value)是一样的功能

  2. 种树 BZOJ2151 模拟费用流

    分析: 我们如果选择点i,那么我们不能选择i-1和i+1,如果没有这个限制,直接贪心就可行,而加上这个限制,我们考虑同样贪心,每次选择i后,将点i-1,i+1从双向链表中删除,并且将-a[i]+a[i ...

  3. 读取txt内文件内容

    命令如下: f = open("c:\\1.txt","r")  lines = f.readlines()#读取全部内容  for line in lines ...

  4. Linux安装任意版本的dotnet环境

    下载地址 https://www.microsoft.com/net/download/dotnet-core/2.1 安装符合服务器CPU架构的二进制包. 如果架构不对,会出现一下错误: -bash ...

  5. Spark学习之数据读取与保存总结(一)

    一.动机 我们已经学了很多在 Spark 中对已分发的数据执行的操作.到目前为止,所展示的示例都是从本地集合或者普通文件中进行数据读取和保存的.但有时候,数据量可能大到无法放在一台机器中,这时就需要探 ...

  6. 你可能不知道的jvm的类加载机制

    引言:在java代码中,类型的加载.连接与初始化过程都是在程序运行期间完成的. 加载:查找并加载类的二进制数据(class文件加载到内存中) 连接:a 验证:确保被加载类的正确性. b准备:为类的静态 ...

  7. ab性能测试工具的使用

    一.什么是ab ab,即Apache Benchmark,是一种用于测试Apache超文本传输协议(HTTP)服务器的工具. ab命令会创建很多的并发访问线程,模拟多个访问者同时对某一URL地址进行访 ...

  8. SpringBoot进阶教程(二十九)整合Redis 发布订阅

    SUBSCRIBE, UNSUBSCRIBE 和 PUBLISH 实现了 发布/订阅消息范例,发送者 (publishers) 不用编程就可以向特定的接受者发送消息 (subscribers). Ra ...

  9. 阿里云重磅发布DMS数据库实验室 免费体验数据库引擎

    2月27日,阿里云数据管理DMS发布年度巨献——数据库实验室,用户可在该实验室环境下免费体验数据库引擎.以及DMS各项产品功能.数据库实验室是DMS所提供的体验空间,免费赠送数据库引擎资源. 用户只需 ...

  10. 在编写Arcgis Engine 过程中对于接口引用和实现过程过产生的感悟

    Engine10.2版本 在vs里面新建类GeoMaoAO,并定义接口,在class中定义并实现,如下代码 以平时练习为例,我定义了一个接口,在里面定义了许多的控件,并在类中想要实现这一接口.如果在v ...