1. 模型定义与导出代码

import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(num_features=16)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
self.bn2 = nn.BatchNorm2d(num_features=64)
self.act2 = nn.ReLU()
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(in_features=64, out_features=10) def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.act2(x) # flatten: B×C×H×W → B×C×L(L=H×W)
x = torch.flatten(x, 2, 3) # 平均池化:B×C×L → B×C×1
x = self.avgpool(x) # 再次 flatten:B×C×1 → B×C
x = torch.flatten(x, 1) # 全连接层分类:B×C → B×10
x = self.head(x)
return x

2. 导出为 ONNX 并简化

def export_norm_onnx():
input = torch.rand(1, 3, 64, 64) # 输入:B×3×64×64
model = Model()
file = "./sample-reshape.onnx" # 导出 ONNX 模型
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 15
)
print("Finished normal onnx export") # 检查模型结构合法性
model_onnx = onnx.load(file)
onnx.checker.check_model(model_onnx) # 使用 onnx-simplifier 进行图结构简化
print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, "assert check failed"
onnx.save(model_onnx, file)

小提示:为什么 flatten 会生成多个节点?

x = torch.flatten(x, 2, 3)

ONNX 中不支持 flatten(x, start_dim=2) 这样的高维展开直接表示,因此 PyTorch 导出时会转换为:

  • Shape:获取张量形状
  • Slice:提取要 flatten 的维度
  • Concat:拼接新 shape
  • Reshape:完成 flatten 动作

使用 onnxsim 简化后,这些操作通常会被合并为一个简单的 FlattenReshape


3. 主函数执行导出流程

if __name__ == "__main__":
export_norm_onnx()

代码结构整体说明:

模型结构(Model 类):

x -> conv1 -> bn1 -> relu1
-> conv2 -> bn2 -> relu2
-> flatten -> avgpool -> flatten -> linear -> output

其中重点在于:

第一段 flatten:

x = torch.flatten(x, 2, 3)  # B, C, H, W -> B, C, L

这个操作会导致导出的 ONNX 图中生成:

  • Shape
  • Slice
  • Concat
  • Reshape

等一系列辅助节点。为什么?


为什么 flatten(x, 2, 3) 会变成这么多 ONNX 节点?

PyTorch 的 torch.flatten(x, 2, 3) 表示:

  • x 从第 2 维(H)到第 3 维(W)展平为一维
  • 举个例子:输入 x[B, C, H, W],flatten 后变成 [B, C, H*W]

但是在 ONNX 中:

  • ONNX 不支持 “动态切片 + flatten” 作为单一原始操作
  • 所以需要分解为多个步骤来实现:

1. Shape:先获取 x 的形状

2. Slice:抽取你需要的维度值(这里是 HW

3. Concat:拼接出新 shape,例如 [B, C, H*W]

4. Reshape:应用这个新 shape

这就是你看到的:

Shape -> Slice -> Slice -> Mul -> Concat -> Reshape

的由来。


为什么导出前后图不一样?

你有两个版本:

原始导出图:

  • 有上述所有细化节点(Slice/Shape/Reshape 等)
  • 这对于 动态输入尺寸 很重要,但会让图复杂

简化后的 ONNX(使用 onnxsim.simplify):

  • 会自动识别这部分是一个 flatten 动作
  • 用更简洁的方式重新表达(甚至直接用一个 Flatten 节点)

这是为什么你写了:

# onnx中其实会有一些constant value,以及不需要计算图跟踪的节点
# 大家可以一起从netron中看看这些节点都在干什么

平铺流程:flatten + avgpool + flatten + fc

你原始的网络有这几步转换:

步骤 输入维度 输出维度 说明
flatten(x, 2, 3) [B,C,H,W] [B,C,L] H × W 展平为 L
AdaptiveAvgPool1d(1) [B,C,L] [B,C,1] 类似全局平均池化
flatten(x, 1) [B,C,1] [B,C] 去掉最后一维
Linear [B,C] [B,10] 最终全连接层分类

建议你动手做以下实验理解更深:

  1. 注释掉 onnxsim.simplify(),用 Netron 打开 .onnx 文件,看看 flatten 变成了哪些低层操作?
  2. 然后再运行一次 simplify,看看有没有把它们合并成一个 Flatten 或更简洁的结构?
  3. torch.flatten(x, 2, 3) 换成 .view(b, c, -1).reshape(...),看看导出的结构是否更简洁?

总结重点

内容
flatten 操作为什么变复杂? 因为 ONNX 中 flatten 只支持从第 1 个维度开始,如果你指定的是 2~3,会生成 shape/slice/reshape
onnxsim.simplify 作用? 自动识别复杂逻辑并简化(合并 slice、reshape 等)
推荐做法? 导出前先理解动态维度怎么计算,导出后建议简化以减小模型体积、提升兼容性
哪些操作最容易生成冗余图? flatten、transpose、reshape、permute、expand 等涉及动态 shape 的操作

随机推荐

  1. Java HashMap和ConcurrentHashMap知识点梳理

    jdk 8 HashMap 扩容之后旧元素存放位置是? java 在扩容的时候会创建一个新的 Node<K,V>[],用于存放扩容之后的值,并将旧的Node数组(其大小记作n)置空:至于旧 ...

  2. 大数据存储计算平台EasyMR:多集群统一管理助力企业高效运维

    随着全球企业进入数字化转型的快车道,数据已成为企业运营.决策和增长的核心驱动力.为了处理海量数据,同时应对数据处理的复杂性和确保系统的高可用性,企业往往选择部署多个Hadoop集群,这样的策略可以将生 ...

  3. oracle 存储过程 for loop 定时任务

    记录. 是这么个事,要实现一个需求,当人员表里的数据有更新后需要告知其他系统更新他们自己的人员数据. 我想了一下,表里是有时间戳字段的,那我只要监听这个时间就行,拿到数据后用存储过程把数据插入到中间表 ...

  4. 2023 syzx 春季训练 1

    得找个时间把 zr 题补补.. A 考虑 \(f_{i}\) 只能拆为 \(f_{i-1}+f_{i-2}\),考虑拆 \(f_{i-1}=f_{i-2}+f_{i-3}\) 时,这条 \(f_{i- ...

  5. AWTK项目编译问题整理(1)

    三方库组织 公司的项目初步三方库路径组织是这样,awtk-widget开头的是awtk的自定义控件,无源码的二进制库放在sourceless这个文件夹: ./3rd   ├── awtk-widget ...

  6. Winform 工具栏 ToolStripMenuItem下拉选择项选中对勾不居中

    问题描述 工具栏ToolStrip --> ToolStripDropDownButton --> ToolStripMenuItem Checked = true 选中后,前面的对勾不居 ...

  7. Asp.Net Core MVC 记住密码

    https://www.cnblogs.com/Hmd528/p/10695156.html if (lm.RememberMe)                    {               ...

  8. 卓岚物联小程序通过4G串口服务器采集供热站流量计数据

    1.概述 远程监控供热站流量计数据,可以随时随地查看流量计瞬时流量.累积流量,平均流量,运行时间等相关参数,可以远程修改流量计相关参数. 图一 4G串口服务器与流量计链接示意图 ZLAN8305作为4 ...

  9. MCP 核心架构解析

    引言 Model Context Protocol (MCP) 是一种为连接大型语言模型(LLM)应用而设计的通信协议,它建立在灵活.可扩展的架构基础上,旨在实现LLM应用程序与各类集成之间的无缝交互 ...

  10. macOS Monterey系统安装 CocoaPods详细教程

    更新ruby 系统默认的应该是老旧的 v2.6.10,我们要更新到3.x以上,不然可能会和其它较新插件(如3.x的gem)冲突. 安装rbenv brew install rbenv # 对于 mac ...