你的模型到底有多少参数,每秒的浮点运算到底有多少,这些你都知道吗?近日,GitHub 开源了一个小工具,它可以统计 PyTorch 模型的参数量与每秒浮点运算数(FLOPs)。有了这两种信息,模型大小控制也就更合理了。

其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。

此外,机器学习还有很多结构没有参数但存在计算,例如和  等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。

  • PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter

OpCouter

PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。

我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。

对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:

from torchvision.models import resnet50
from thop import profile model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。

最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。

flops: 2914598912.0
parameters: 7978856.0

OpCouter 是怎么算的

我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:

def count_conv2d(m, x, y):
x = x[0] cin = m.in_channels
cout = m.out_channels
kh, kw = m.kernel_size
batch_size = x.size()[0] out_h = y.size(2)
out_w = y.size(3) # ops per output element
# kernel_mul = kh * kw * cin
# kernel_add = kh * kw * cin - 1
kernel_ops = multiply_adds * kh * kw
bias_ops = 1 if m.bias is not None else 0
ops_per_element = kernel_ops + bias_ops # total ops
# num_out_elements = y.numel()
output_elements = batch_size * out_w * out_h * cout
total_ops = output_elements * ops_per_element * cin // m.groups     m.total_ops = torch.Tensor([int(total_ops)])

总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。

定制你的运算统计

有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule here input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
                        custom_ops={YourModule: count_your_model})

最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量:

欢迎关注磐创博客资源汇总站:http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:http://pytorch.panchuang.net/

两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试的更多相关文章

  1. pytorch统计模型参数量

    用resnet50 来举例子 print("resnet50 have {} paramerters in total".format(sum(x.numel() for x in ...

  2. gRaphael——JavaScript 矢量图表库:两行代码实现精美图表

    gRaphael 是一个致力于帮助开发人员在网页中绘制各种精美图表的 Javascript 库,基于强大的 Raphael 矢量图形库.你只需要编写几行简单的代码就能创建出精美的条形图.饼图.点图和曲 ...

  3. 两行代码玩转SUMO!

    两行代码玩转SUMO! 这篇博客很简单,但是内容很丰富 如何生成如下所示的研究型路网结构? 只需要打开ubuntu终端输入如下代码即可,grid.number代表路口数量,grid.length代表路 ...

  4. 深度学习之depthwise separable convolution,计算量及参数量

    目录: 1.什么是depthwise separable convolution? 2.分析计算量.flops 3.参数量 4.与传统卷积比较 5.reference

  5. 深度学习之group convolution,计算量及参数量

    目录: 1.什么是group convolution? 和普通的卷积有什么区别? 2.分析计算量.flops 3.分析参数量 4.相比于传统普通卷积有什么优势以及缺点,有什么改进方法? 5.refer ...

  6. 通过两行代码即可调整苹果电脑 Launchpad 图标大小

    之前用 13 寸 Mac 的时候我还没觉得,后来换了 16 寸就发现有点不对劲了.因为 Mac 的高分辨率,当你进入 Launchpad 界面,应用图标的大小可能会让你怀疑:这特么是苹果的设计吗?有点 ...

  7. iOS 两行代码解决数据持久化

    在实际的iOS开发中,有些时候涉及到将程序的状态保存下来,以便下一次恢复,或者是记录用户的一些喜好和用户的登录信息等等. 这就需要涉及到数据的持久化了,所谓数据持久化就是数据的本地保存,将数据从内存中 ...

  8. 【转】Delphi+Halcon实战一:两行代码识别QR二维码

    Delphi+Halcon实战一:两行代码识别QR二维码 感谢网友:绝代双椒( QQ号应原作者要求隐藏了:xxxx6348)的支持 本文是绝代双椒的作品,因为最近在忙zw量化培训,和ziwang.co ...

  9. paip. 解决php 以及 python 连接access无效的参数量。参数不足,期待是 1”的错误

    paip. 解决php 以及 python 连接access无效的参数量.参数不足,期待是 1"的错误 作者Attilax  艾龙,  EMAIL:1466519819@qq.com  来源 ...

随机推荐

  1. FPGA小白学习之路(4)PLL中的locked信号解析(转)

    ALTPLL中的areset,locked的使用 转自:http://www.360doc.com/content/13/0509/20/9072830_284220258.shtml 今天对PLL中 ...

  2. USB小白学习之路(1) Cypress固件架构解析

    Cypress固件架构彻底解析及USB枚举 1. RAM的区别 56pin或者100pin的cy7c68013A,只有内部RAM,不支持外部RAM 128pin的cy7c68013A在pin脚EA=0 ...

  3. 使用QT绘制一个多边形

    目录 1. 概述 2. 实现 2.1. 代码 2.2. 解析 3. 结果 1. 概述 可以通过QT的重绘事件和鼠标事件来绘制多边形,最简单的办法就是在继承QWidget的窗体中重写paintEvent ...

  4. 处理 Vue 单页面 SEO 的另一种思路

    vue-meta-info 官方地址: https://github.com/monkeyWang... (设置vue 单页面meta info信息,如果需要单页面SEO,可以和 prerender- ...

  5. C++读入输出优化

    读入输出优化虽然对于小数据没有半点作用,但是对于大数据来说,可以优化几十ms. 有时就是那么几十ms,可以被卡掉大数据的点 读入优化 int read() { int x=0,sig=1; char ...

  6. PC端如何下载B站里面的视频?

    此随笔只是记录一下:   PC端下载B站的视频,在blibli前面加上一个i 然后在视频上鼠标右键,视频另存为+路径即可 PS:网上其他的方法,比如在blibli前面加上kan,后面加上jj等,这些方 ...

  7. python之嵌套 闭包 装饰器 global、nonlocal关键字

    嵌套: 在函数的内部定义函数闭包: 符合开放封闭原则:在不修改源代码与调用方式的情况下为函数添加新功能  # global 将局部变量变成全局变量 num = 100 def fn1(): globa ...

  8. @常见的远程服务器连接工具:Xshell与secureCRT的比较!!!(对于刚接触的测试小白很有帮助哦)

    现在比较受欢迎的终端模拟器软件当属xshell和securecrt了. XShell绝对首选,免费版也没什么限制,随便改字体随便改颜色随便改大小随便改字符集,多窗口,也比较小巧,而SecureCRT界 ...

  9. [译]ABP框架v2.3.0已经发布!

    在新冠病毒的日子里,我们发布了ABP框架v2.3, 这篇文章将说明本次发布新增内容和过去的两周我们做了什么. 关于新冠病毒和我们的团队 关于冠状病毒的状况我们很难过.在Volosoft的团队,我们有不 ...

  10. 当AI遇上K8S:使用Rancher安装机器学习必备工具JupyterHub

    Jupyter Notebook是用于科学数据分析的利器,JupyterHub可以在服务器环境下为多个用户托管Jupyter运行环境.本文将详细介绍如何使用Rancher安装JupyterHub来为数 ...