ResNet网络的训练和预测

简介 Introduction

图像分类与CNN

图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测、语义分割、人脸识别等高层视觉任务的基础。

ImageNet 大规模视觉识别挑战赛(ILSVRC),常称为 ImageNet 竞赛,包括图像分类、物体定位,以及物体检测等任务,推动计算机视觉领域发展最重要的比赛之一。

在2012年的 ImageNet 竞赛中,深度卷积网络 AlexNet 横空出世。以超出第二名10%以上的top-5准确率,勇夺 ImageNet2012 比赛的冠军。从此,以 CNN(卷积神经网络) 为代表的深度学习方法开始在计算机视觉领域的应用开始大放异彩,更多的更深的CNN网络被提出,比如 ImageNet2014 比赛的冠军 VGGNet, ImageNet2015 比赛的冠军 ResNet。

ResNet

ResNet 是2015年ImageNet竞赛的冠军。目前,ResNet 相对对于传统的机器学习分类算法而言,效果已经相当的出色,之后大量的检测,分割,识别等任务也都在 ResNet 基础上完成。

OneFlow-Benchmark 仓库中,提供 ResNet50 v1.5 的 OneFlow 实现。在 ImageNet-2012 数据集上训练90轮后,验证集上的准确率能够达到:77.318%(top1),93.622%(top5)。

更详细的网络参数对齐工作,见 OneFlow-Benchmark的cnns 部分

关于 ResNet50 v1.5 的说明:

ResNet50 v1.5 是原始 ResNet50 v1 的一个改进版本,相对于原始的模型,精度稍有提升 (~0.5% top1) 。

本文就以上面的 ResNet50 为例,一步步展现如何使用 OneFlow 进行 ResNet50 网络的训练和预测。

主要内容包括:

  • 准备工作
  • 项目安装和准备工作
  • 快速开始
  • 预测/推理
  • 训练和验证
  • 评估
  • 更详细的说明
  • 分布式训练
  • 混合精度训练与预测
  • 进阶
  • 参数对齐
  • 数据集制作(ImageNet2012)
  • OneFlow 模型转 ONNX 模型

准备工作 Requirements

别担心,使用 OneFlow 非常容易,只要准备好下面三步,即可开始 OneFlow 的图像识别之旅。

git clone git@github.com:Oneflow-Inc/OneFlow-Benchmark.git

cd OneFlow-Benchmark/Classification/cnns

  • 准备数据集(可选)
  • 直接使用 synthetic 虚拟合成数据集
  • 下载制作的 Imagenet(2012) 迷你数据集 解压放入data目录
  • 或者:制作完整 OFRecord 格式的 ImageNet 数据集(见下文进阶部分)

提供了通用脚本:train.sh 和 inference.sh,它们适用于此仓库下所有cnn网络模型的训练、验证、推理。可以通过设置参数使用不同的模型、数据集来训练/推理。

关于模型的说明:

默认情况下,使用resnet50,也可以通过改动脚本中的--model参数指定其他模型,如:--model="resnet50",--model="vgg" 等。

关于数据集的说明:

1)为了使读者快速上手,提供了 synthetic 虚拟合成数据,“合成数据”是指不通过磁盘加载数据,而是直接在内存中生成一些随机数据,作为神经网络的数据输入源。

2)同时,提供了一个小的迷你示例数据集。直接下载解压至 cnn 项目的 data 目录,即可快速开始训练。读者可以在熟悉了流程后,参考数据集制作部分,制作完整的 Imagenet2012 数据集。

3)使用 OFRcord 格式的数据集可以提高数据加载效率(但这非必须,参考数据输入,OneFlow 支持直接加载 numpy 数据)。

快速开始 Quick Start

开始 OneFlow 的图像识别之旅吧!

首先,切换到目录:

cd OneFlow-Benchmark/Classification/cnns

预训练模型

resnet50

resnet50_v1.5_model (validation accuracy: 77.318% top1,93.622% top5 )

预测/推理

下载好预训练模型后,解压后放入当前目录,然后执行:

sh inference.sh

此脚本将调用模型对这张金鱼图片进行分类:

若输出下面的内容,则表示预测成功:

data/fish.jpg

0.87059885 goldfish, Carassius auratus

可见,模型判断这张图片有87.05%的概率是金鱼 goldfish。

训练和验证(Train & Validation)

  • 训练同样很简单,只需执行:

sh train.sh

即可开始模型的训练,将看到如下输出:

Loading synthetic data.

Loading synthetic data.

Saving model to ./output/snapshots/model_save-20200723124215/snapshot_initial_model.

Init model on demand.

train: epoch 0, iter 10, loss: 7.197278, top_1: 0.000000, top_k: 0.000000, samples/s: 61.569

train: epoch 0, iter 20, loss: 6.177684, top_1: 0.000000, top_k: 0.000000, samples/s: 122.555

Saving model to ./output/snapshots/model_save-20200723124215/snapshot_epoch_0.

train: epoch 0, iter 30, loss: 3.988656, top_1: 0.525000, top_k: 0.812500, samples/s: 120.337

train: epoch 1, iter 10, loss: 1.185733, top_1: 1.000000, top_k: 1.000000, samples/s: 80.705

train: epoch 1, iter 20, loss: 1.042017, top_1: 1.000000, top_k: 1.000000, samples/s: 118.478

Saving model to ./output/snapshots/model_save-20200723124215/snapshot_epoch_1.

...

为了方便运行演示,默认使用synthetic虚拟合成数据集,使可以快速看到模型运行的效果

同样,你也可以使用迷你示例数据集,下载解压后放入 cnn 项目的 data 目录即可,然后修改训练脚本如下:

rm -rf core.*

rm -rf ./output/snapshots/*

DATA_ROOT=data/imagenet/ofrecord

python3 of_cnn_train_val.py \

--train_data_dir=$DATA_ROOT/train \

--num_examples=50 \

--train_data_part_num=1 \

--val_data_dir=$DATA_ROOT/validation \

--num_val_examples=50 \

--val_data_part_num=1 \

--num_nodes=1 \

--gpu_num_per_node=1 \

--model_update="momentum" \

--learning_rate=0.001 \

--loss_print_every_n_iter=1 \

--batch_size_per_device=16 \

--val_batch_size_per_device=10 \

--num_epoch=10 \

--model="resnet50"

运行此脚本,将在仅有50张金鱼图片的迷你 ImageNet 数据集上,训练出一个分类模型,可以对金鱼图片进行分类。

不要着急,如果需要在完整的 ImageNet2012 数据集上进行训练,请参考:OneFlow-Benchmark仓库。

评估(Evaluate)

你可以使用自己训练好的模型,或者提供的 resnet50_v1.5_model (解压后放入当前目录),对resnet50模型的精度进行评估。

只需运行:

sh evaluate.sh

即可获得训练好的模型在50000张验证集上的准确率:

Time stamp: 2020-07-27-09:28:28

Restoring model from resnet_v15_of_best_model_val_top1_77318.

I0727 09:28:28.773988162    8411 ev_epoll_linux.c:82]        Use of signals is disabled. Epoll engine will not be used

Loading data from /dataset/ImageNet/ofrecord/validation

validation: epoch 0, iter 195, top_1: 0.773277, top_k: 0.936058, samples/s: 1578.325

validation: epoch 0, iter 195, top_1: 0.773237, top_k: 0.936078, samples/s: 1692.303

validation: epoch 0, iter 195, top_1: 0.773297, top_k: 0.936018, samples/s: 1686.896

执行 sh evaluate.sh 前,确保准备了 ImageNet(2012) 的验证集,验证集制作方法请参考:OneFlow-Benchmark仓库。

从3轮的评估结果来看,模型在 ImageNet(2012) 上已经达到了77.32+%的 top1 精度。

最后,恭喜你!完成了 Resnet 模型在 ImageNet 上完整的训练/验证、推理和评估!

更详细的说明 Details

分布式训练

简单而易用的分布式,是 OneFlow 的主打特色之一。

OneFlow 框架从底层设计上,就原生支持高效的分布式训练。尤其对于分布式的数据并行,用户完全不用操心算法从单机单卡扩展到多机多卡时,数据如何划分以及同步的问题。也就是说,使用 OneFlow,用户以单机单卡的视角写好的代码,自动具备多机多卡分布式数据并行的能力。

如何配置并运行分布式训练?

还是以上面"快速开始"部分演示的代码为例,在 train.sh 中,只要用 --num_nodes 指定节点(机器)个数,同时用 --node_ips 指定节点的 IP 地址,然后用 --gpu_num_per_node 指定每个节点上使用的卡数,就轻松地完成了分布式的配置。

例如,想要在2机8卡上进行分布式训练,像下面这样配置:

# train.sh

python3 of_cnn_train_val.py \

--num_nodes=2 \

--node_ips="192.168.1.1, 192.168.1.2"

--gpu_num_per_node=4 \

...

--model="resnet50"

然后分别在两台机器上,同时执行:

./train.sh

程序启动后,通过 watch -n 0.1 nvidia-smi 命令可以看到,两台机器的 GPU 都开始了工作。一段时间后,会在 --node_ips 设置中的第一台机器的屏幕上,打印输出。

混合精度训练与预测

目前,OneFlow 已经原生支持 float16/float32 的混合精度训练。训练时,模型参数(权重)使用 float16 进行训练,同时保留 float32 用作梯度更新和计算过程。由于参数的存储减半,会带来训练速度的提升。

在 OneFlow 中开启 float16/float32 的混合精度训练模式,ResNet50 的训练速度理论上能达到1.7倍的加速。

如何开启 float16 / float32 混合精度训练?

只需要在 train.sh 脚本中添加参数 --use_fp16=True 即可。

混合精度模型

为提供了一个在 ImageNet2012 完整训练了90个 epoch 的混合精度模型,Top_1:77.33%

可以直接下载使用:resnet50_v15_fp16

进阶 Advanced

参数对齐

OneFlow 的 ResNet50 实现,为了保证和英伟达的 Mxnet 版实现对齐,从 learning rate 学习率,优化器 Optimizer 的选择,数据增强的图像参数设定,到更细的每一层网络的形态,bias,weight 初始化等都做了细致且几乎完全一致的对齐工作。具体的参数对齐工作,请参考:OneFlow-Benchmark 仓库

数据集制作

用于图像分类数据集简介

用于图像分类的公开数据集有CIFAR,ImageNet 等等,这些数据集中,以 jpeg 的格式提供原始的图片。

  • CIFAR 是由Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。包括CIFAR-10和CIFAR-100。
  • ImageNet ImageNet 数据集,一般是指2010-2017年间大规模视觉识别竞赛 (ILSVRC) 的所使用的数据集的统称。ImageNet 数据从2010年来稍有变化,常用 ImageNet-2012 数据集包含1000个类别,其中训练集包含1,281,167张图片,每个类别数据732至1300张不等,验证集包含50,000张图片,平均每个类别50张图片。

完整的 ImageNet(2012)制作过程,请参考 tools 目录下的README说明

OneFlow 模型转 ONNX 模型

简介

ONNX (Open Neural Network Exchange) 是一种较为广泛使用的神经网络中间格式,通过 ONNX 格式,OneFlow 模型可以被许多部署框架(如 OpenVINO、ONNX Runtime 和移动端的 ncnn、tnn、TEngine 等)所使用。这一节介绍如何将训练好的 ResNet50 v1.5 模型转换为 ONNX 模型并验证正确性。

快速上手

提供了完整代码:resnet_to_onnx.py 帮你轻松完成模型的转换和测试的工作

步骤一: 下载预训练模型:resnet50_v1.5_model ,解压后放入当前目录

步骤二: 执行:python3 resnet_to_onnx.py

此代码将完成 OneFlow 模型 -> ONNX 模型的转化,然后使用 ONNX Runtime 加载转换后的模型对单张图片进行测试。测试图片如下:

​ 图片来源:https://en.wikipedia.org/wiki/Tiger

输出:

Convert to onnx success! >>  onnx/model/resnet_v15_of_best_model_val_top1_77318.onnx

data/tiger.jpg

Are the results equal? Yes

Class: tiger, Panthera tigris; score: 0.8112028241157532

如何生成 ONNX 模型

上面的示例代码,介绍了如何转换 OneFlow 的 ResNet 模型至 ONNX 模型,并给出了一个利用 onnx runtime 进行预测的例子,同样,可以利用下面的步骤来完成自己训练的 ResNet 或其他模型的转换。

步骤一:将模型权重保存到本地

首先指定待转换的 OneFlow 模型路径,然后指定转换后的 ONNX 模型存放路径,例如示例中:

#set up your model path

flow_weights_path = 'resnet_v15_of_best_model_val_top1_77318'

onnx_model_dir = 'onnx/model'

步骤二:新建一个用于推理的 job function

然后新建一个用于推理的 job function,它只包含网络结构本身,不包含读取 OFRecord 的算子,并且直接接受 numpy 数组形式的输入。可参考 resnet\_to\_onnx.py 中的 InferenceNet。

步骤三:调用 flow.onnx.export方法

接下来代码中会调用 oneflow_to_onnx() 方法,此方法包含了核心的模型转换方法: flow.onnx.export()

flow.onnx.export 将从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是 OneFlow 模型路径,第三个参数是(转换后)ONNX 模型的存放路径

onnx_model = oneflow_to_onnx(InferenceNet, flow_weights_path, onnx_model_dir, external_data=False)

验证 ONNX 模型的正确性

生成 ONNX 模型之后可以使用 ONNX Runtime 运行 ONNX 模型,验证 OneFlow 模型和 ONNX 模型能够在相同的输入下产生相同的结果。相应的代码在 resnet_to_onnx.py 的 check_equality。

ResNet网络的训练和预测的更多相关文章

  1. 学习笔记-ResNet网络

    ResNet网络 ResNet原理和实现 总结 一.ResNet原理和实现 神经网络第一次出现在1998年,当时用5层的全连接网络LetNet实现了手写数字识别,现在这个模型已经是神经网络界的“hel ...

  2. ResNet网络再剖析

    随着2018年秋季的到来,提前批和内推大军已经开始了,自己也成功得当了几次炮灰,不过在总结的过程中,越是了解到自己的不足,还是需要加油. 最近重新复习了resnet网络,又能发现一些新的理念,感觉很f ...

  3. tensorflow数据加载、模型训练及预测

    数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...

  4. 深度学习之ResNet网络

    介绍 Resnet分类网络是当前应用最为广泛的CNN特征提取网络. 我们的一般印象当中,深度学习愈是深(复杂,参数多)愈是有着更强的表达能力.凭着这一基本准则CNN分类网络自Alexnet的7层发展到 ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. tensorflow 训练最后预测结果为一个定值,可能的原因

    训练一个分类网络,没想到预测结果为一个定值. 找了很久发现,是因为tensor的维度的原因.  注意:我说的是我的label数据的维度. 我的输入是: y_= tf.placeholder(tf.in ...

  7. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  8. 0609-搭建ResNet网络

    0609-搭建ResNet网络 目录 一.ResNet 网络概述 二.利用 torch 实现 ResNet34 网络 三.torchvision 中的 resnet34网络调用 四.第六章总结 pyt ...

  9. 机器学习使用sklearn进行模型训练、预测和评价

    cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...

随机推荐

  1. flex 的 三个参数 flex:1 0 auto

    flex :flex-group  flex-shirk  flex-basis ①.flex-group 剩余空间索取 默认值为0,不索取 eg:父元素400,子元素A为100px,B为200px. ...

  2. dedecms发布文章排序按发布时间,不是更新时间

    织梦DEDECMS5.7这个版本存在一个问题,修改文章的同时也修改了文章的发布时间,这种情况下,如果我们调用最新文章时使用按"发布日期排序",就会打乱原来我们网站上的文章的顺序. ...

  3. 从苏宁电器到卡巴斯基第12篇:我在苏宁电器当营业员 IV

    卖iPhone首先是需要接受培训的 像iPhone这样的重点产品,并不是只要选好了人(营业员),说卖就能卖的,在正式销售之前需要接受厂家的培训.如果说人事关系或者产品源隶属于苹果,那么是由苹果中国公司 ...

  4. JSONP跨域资源共享的安全问题

    目录 关于 JSONP 一.JSON 劫持 二.Callback 可定义导致的安全问题 三.其他文件格式( Content-Type )与 JSON 四.防御 摘自:http://blog.known ...

  5. CVE-2017-11882:Microsoft office 公式编辑器 font name 字段栈溢出通杀漏洞调试分析

    \x01 漏洞简介 在 2017 年 11 月微软的例行系统补丁发布中,修复了一个 Office 远程代码执行漏洞(缓冲区溢出),编号为 CVE-2017-11882,又称为 "噩梦公式&q ...

  6. [CTF]盲文对照表

    [CTF]盲文对照表 摘自:https://wenku.baidu.com/view/28b04fd380eb6294dd886ca7.html 学点盲文 盲文又称点字,国际通用的点字由6个凸起的圆点 ...

  7. 【vue-08】vuex

    vuex的作用 简单理解,就是将多个组件共享的变量统一放到一个地方去管理,比如用户登录时的数据token. 快速上手 安装:npm install vuex 首先,我们在src文件夹下创建一个文件夹: ...

  8. thymeleaf-extras-springsecurity在Spring或SpringBoot中html页面命名空间

    xmlns:sec="http://www.thymeleaf.org/extras/spring-security"

  9. 简单使用高德地图开放平台API

    需求说明 输入经纬度,得到城市名 挑选API 使用高德逆地理编码API,点击查看文档 demo <?php /** * 根据输入的经纬度返回城市名称 * @param $longitude 终点 ...

  10. Markdown编辑器怎么用

    Markdown编辑器怎么用 1.代码块 快速创建一个代码块 // 语法: // ```+语言名称,如```java,```c++ 2.标题 语法:#+空格+标题名字,一个#表示一级标题,两个#表示二 ...