ONNXRuntime学习笔记(二)
继上一篇计划的实践项目,这篇记录我训练模型相关的工作。
- 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%;部署后推理效率达到50ms/张, 部署平台为window10+3050Ti+RX5800h.
- 训练模型的话,最好是有一套完备的代码,像谷歌的models,FB的detectron2,商汤的mm系列等等框架,这些是建立在深度学习框架tf或pth基础上的进一步封装,提供一些更高级的写好的模块可以调用,如Resnet、FPN、、proposal、NMS等等。但凡事都有两面,封装度越高意味着稳定性更好但修改的灵活性越差。只调用API对我们理解底层实现是不利的。之前我写过一个基于Pytorch的图像分类训练推理代码,现在又可以拿出来用一用了,地址:https://github.com/lee-zq/CNN-Backbone ,我在之前训练CIFAR-10的基础上又添加了CIFAR-100数据集的Dataloader创建代码。
- 首先,我尝试了CIFAR10+DenseNet,最后测试效果Acc=85%;然后尝试了CIFAR10+ResNet18,收敛较慢,但最终Acc=91.02%;基于此,。我尝试了CIFAR100+ResNet18,收敛很慢,大概到73Epoch稳定下来,但最终训练集Acc能达到90.62%,但测试集Acc为65.67%。大概率原因是模型拟合能力够用但是训练集多样性太差。模型结构如下:
ResNet(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(layer1): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer2): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer3): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer4): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(fc): Linear(in_features=512, out_features=10, bias=True)
)
Total number of parameters: 11173962
总参数量约11M,既然CIFAR-100效果太差,那就暂且还是用CIFAR-10做后面的训练测试吧,我又在之前的数据增强基础上加了RandomGrayscale和RandomAffine,最终的数据增强如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomGrayscale(0.15),
transforms.RandomAffine((-30,30)),
transforms.RandomRotation(20),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing(),
])
- 然后微调继续训练,测试集Acc进一步提升到92.28%,可见数据多样性的重要性。进一步的,torchvision提供了AutoAugment数据增强方法的接口,可以直接调用,最终数据增强代码如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.autoaugment.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing()
])
- 训练epoch数为80,优化器Adam,初始学习率0.01,每20epoch衰减,衰减因子gamma为0.1,目前还在训练ing,要花两个小时。完整重头训练估计要花4个小时,在之前的基础上微调会快很多,最终测试集Acc达到94.83%,达到预期。下一篇记录利用onnxruntime推理进行测试的过程。
ONNXRuntime学习笔记(二)的更多相关文章
- WPF的Binding学习笔记(二)
原文: http://www.cnblogs.com/pasoraku/archive/2012/10/25/2738428.htmlWPF的Binding学习笔记(二) 上次学了点点Binding的 ...
- AJax 学习笔记二(onreadystatechange的作用)
AJax 学习笔记二(onreadystatechange的作用) 当发送一个请求后,客户端无法确定什么时候会完成这个请求,所以需要用事件机制来捕获请求的状态XMLHttpRequest对象提供了on ...
- [Firefly引擎][学习笔记二][已完结]卡牌游戏开发模型的设计
源地址:http://bbs.9miao.com/thread-44603-1-1.html 在此补充一下Socket的验证机制:socket登陆验证.会采用session会话超时的机制做心跳接口验证 ...
- JMX学习笔记(二)-Notification
Notification通知,也可理解为消息,有通知,必然有发送通知的广播,JMX这里采用了一种订阅的方式,类似于观察者模式,注册一个观察者到广播里,当有通知时,广播通过调用观察者,逐一通知. 这里写 ...
- java之jvm学习笔记二(类装载器的体系结构)
java的class只在需要的时候才内转载入内存,并由java虚拟机的执行引擎来执行,而执行引擎从总的来说主要的执行方式分为四种, 第一种,一次性解释代码,也就是当字节码转载到内存后,每次需要都会重新 ...
- Java IO学习笔记二
Java IO学习笔记二 流的概念 在程序中所有的数据都是以流的方式进行传输或保存的,程序需要数据的时候要使用输入流读取数据,而当程序需要将一些数据保存起来的时候,就要使用输出流完成. 程序中的输入输 ...
- 《SQL必知必会》学习笔记二)
<SQL必知必会>学习笔记(二) 咱们接着上一篇的内容继续.这一篇主要回顾子查询,联合查询,复制表这三类内容. 上一部分基本上都是简单的Select查询,即从单个数据库表中检索数据的单条语 ...
- NumPy学习笔记 二
NumPy学习笔记 二 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.<数学分 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(二) indigo tools
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...
随机推荐
- SpringBoot和SpringCloud?
SpringBoot是Spring推出用于解决传统框架配置文件冗余,装配组件繁杂的基于Maven的解决方案,旨在快速搭建单个微服务而SpringCloud专注于解决各个微服务之间的协调与配置,服务之间 ...
- 解释内存中的栈(stack)、堆(heap)和方法区(method area) 的用法?
通常我们定义一个基本数据类型的变量,一个对象的引用,还有就是函数调用的 现场保存都使用 JVM 中的栈空间:而通过 new 关键字和构造器创建的对象则放在 堆空间,堆是垃圾收集器管理的主要区域,由于现 ...
- 1.0缓存:Login.aspx?
所有的manifest资源配置文件以CACHE MANIFEST声明开头. #(哈希标签)有助于提供缓存文件的版本. CACHE命令指定哪些文件需要被缓存. manifest资源配置文件的内容类型应是 ...
- idea推送项目到github
参考: https://blog.csdn.net/SoWhatWorld/article/details/103817786?depth_1-utm_source=distribute.pc_rel ...
- 二十三、原理图和PCB交互式布局
上图 在原理图里面直接选择在PCB里面就可以移动了,大功告成
- ZEGO音视频服务的高可用架构设计与运营
前言: ZEGO 即构科技作为一家实时音视频的提供商,系统稳定性直接影响用户的主观体验,如何保障服务高可用且用户体验最优是行业面临的挑战,本文结合实际业务场景进行思考,介绍 ZEGO 即构在高可用架构 ...
- 现在做 Web 全景合适吗?
Web 全景在以前带宽有限的条件下常常用来作为街景和 360° 全景图片的查看.它可以给用户一种 self-immersive 的体验,通过简单的操作,自由的查看周围的物体.随着一些运营商推出大王卡等 ...
- 我试试这个昵称好使不队项目NABCD指路
我试试这个昵称好使不队项目NABCD指路:https://www.cnblogs.com/team-development/p/14617203.html
- 微信jssdk分享(附代码)
老规矩---demo图: (注释:微信分享只支持右上角三个点触发) ======> 老规矩上菜: 1. 这边我封装了 share.js function allSharefun(param) ...
- python的字典及相关操作
一.什么是字典 字典是Python中最强大的数据类型之一,也是Python语言中唯一的映射类型.映射类型对象里哈希值(键,key)和指向的对象(值,value)是一对多的的关系,通常被认为是可变的哈希 ...