继上一篇计划的实践项目,这篇记录我训练模型相关的工作。

  • 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%;部署后推理效率达到50ms/张, 部署平台为window10+3050Ti+RX5800h.
  • 训练模型的话,最好是有一套完备的代码,像谷歌的models,FB的detectron2,商汤的mm系列等等框架,这些是建立在深度学习框架tf或pth基础上的进一步封装,提供一些更高级的写好的模块可以调用,如Resnet、FPN、、proposal、NMS等等。但凡事都有两面,封装度越高意味着稳定性更好但修改的灵活性越差。只调用API对我们理解底层实现是不利的。之前我写过一个基于Pytorch的图像分类训练推理代码,现在又可以拿出来用一用了,地址:https://github.com/lee-zq/CNN-Backbone ,我在之前训练CIFAR-10的基础上又添加了CIFAR-100数据集的Dataloader创建代码。
  1. 首先,我尝试了CIFAR10+DenseNet,最后测试效果Acc=85%;然后尝试了CIFAR10+ResNet18,收敛较慢,但最终Acc=91.02%;基于此,。我尝试了CIFAR100+ResNet18,收敛很慢,大概到73Epoch稳定下来,但最终训练集Acc能达到90.62%,但测试集Acc为65.67%。大概率原因是模型拟合能力够用但是训练集多样性太差。模型结构如下:
ResNet(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(layer1): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer2): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer3): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer4): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(fc): Linear(in_features=512, out_features=10, bias=True)
)
Total number of parameters: 11173962

总参数量约11M,既然CIFAR-100效果太差,那就暂且还是用CIFAR-10做后面的训练测试吧,我又在之前的数据增强基础上加了RandomGrayscale和RandomAffine,最终的数据增强如下:

        self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomGrayscale(0.15),
transforms.RandomAffine((-30,30)),
transforms.RandomRotation(20),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing(),
])
  1. 然后微调继续训练,测试集Acc进一步提升到92.28%,可见数据多样性的重要性。进一步的,torchvision提供了AutoAugment数据增强方法的接口,可以直接调用,最终数据增强代码如下:
        self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.autoaugment.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing()
])
  1. 训练epoch数为80,优化器Adam,初始学习率0.01,每20epoch衰减,衰减因子gamma为0.1,目前还在训练ing,要花两个小时。完整重头训练估计要花4个小时,在之前的基础上微调会快很多,最终测试集Acc达到94.83%,达到预期。下一篇记录利用onnxruntime推理进行测试的过程。

ONNXRuntime学习笔记(二)的更多相关文章

  1. WPF的Binding学习笔记(二)

    原文: http://www.cnblogs.com/pasoraku/archive/2012/10/25/2738428.htmlWPF的Binding学习笔记(二) 上次学了点点Binding的 ...

  2. AJax 学习笔记二(onreadystatechange的作用)

    AJax 学习笔记二(onreadystatechange的作用) 当发送一个请求后,客户端无法确定什么时候会完成这个请求,所以需要用事件机制来捕获请求的状态XMLHttpRequest对象提供了on ...

  3. [Firefly引擎][学习笔记二][已完结]卡牌游戏开发模型的设计

    源地址:http://bbs.9miao.com/thread-44603-1-1.html 在此补充一下Socket的验证机制:socket登陆验证.会采用session会话超时的机制做心跳接口验证 ...

  4. JMX学习笔记(二)-Notification

    Notification通知,也可理解为消息,有通知,必然有发送通知的广播,JMX这里采用了一种订阅的方式,类似于观察者模式,注册一个观察者到广播里,当有通知时,广播通过调用观察者,逐一通知. 这里写 ...

  5. java之jvm学习笔记二(类装载器的体系结构)

    java的class只在需要的时候才内转载入内存,并由java虚拟机的执行引擎来执行,而执行引擎从总的来说主要的执行方式分为四种, 第一种,一次性解释代码,也就是当字节码转载到内存后,每次需要都会重新 ...

  6. Java IO学习笔记二

    Java IO学习笔记二 流的概念 在程序中所有的数据都是以流的方式进行传输或保存的,程序需要数据的时候要使用输入流读取数据,而当程序需要将一些数据保存起来的时候,就要使用输出流完成. 程序中的输入输 ...

  7. 《SQL必知必会》学习笔记二)

    <SQL必知必会>学习笔记(二) 咱们接着上一篇的内容继续.这一篇主要回顾子查询,联合查询,复制表这三类内容. 上一部分基本上都是简单的Select查询,即从单个数据库表中检索数据的单条语 ...

  8. NumPy学习笔记 二

    NumPy学习笔记 二 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.<数学分 ...

  9. Learning ROS for Robotics Programming Second Edition学习笔记(二) indigo tools

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

随机推荐

  1. 为什么要用 Dubbo?

    随着服务化的进一步发展,服务越来越多,服务之间的调用和依赖关系也越来越 复杂,诞生了面向服务的架构体系(SOA), 也因此衍生出了一系列相应的技术,如对服务提供.服务调用.连接处理.通信 协议.序列化 ...

  2. spring-boot-learning-REST风格网站

    什么是REST风格: Representational State Transfer :表现层状态转换,实际上是一种风格.标准,约定 首先需要有资源才能表现, 所以第一个名词是" 资源&qu ...

  3. JPA、JTA、XA相关索引

    JPA和分布式事务简介:https://www.cnblogs.com/onlywujun/p/4784233.html JPA.JTA与JMS区别:https://www.cnblogs.com/y ...

  4. JVM原理与深度调优

    什么是jvm jvm是java虚拟机 运行在用户态.通过应用程序实现java代码跨平台.与平台无关.实际上是"一次编译,到处执行" 1.从微观来说编译出来的是字节码!去到哪个平台都 ...

  5. (HTML基础系列1)网页基本结构

    元素名 描述 header 标记头部区域内容 footer 标记脚部区域内容 section Web页面中独立的一块区域 article 独立的文章内容 aside 相关文章应用或应用(常用于侧边栏) ...

  6. 面试BAT,你凭什么说你掌握了CSS

    介绍 项目已经开源:https://github.com/nanhupatar... 欢迎PR 推荐 关注我们的公众号 display: none; 与 visibility: hidden; 的区别 ...

  7. html5 canvas基础10点

    本文主要讲解下一些canvas的基础 1.<canvas id="canvas">若此浏览器不支持canvas会显示该文字</canvas> //创建个ht ...

  8. node-webkit文档翻译#package.json

    title: node-webkit文档翻译#package.json date: 2013-12-07 21:38:25 tags: node-webkit 基本示例 { "main&qu ...

  9. python-统计字符个数

    输入一个字符串,统计其中数字字符及小写字符的个数 输入格式: 输入一行字符串 输出格式: 共有?个数字,?个小写字符 输入样例: helo134ss12 输出样例: 共有5个数字,6个小写字符 代码: ...

  10. CentOS7 DHCP自动获取IP地址

    # 设置auto自动获取IP[root@localhost ~]# nmcli c modify eth0 ipv4.method auto //etho0 网卡名称     # 重新启动网卡并重新加 ...