ResNet-50模型图像分类示例

概述

计算机视觉是当前深度学习研究最广泛、落地最成熟的技术领域,在手机拍照、智能安防、自动驾驶等场景有广泛应用。从2012年AlexNet在ImageNet比赛夺冠以来,深度学习深刻推动了计算机视觉领域的发展,当前最先进的计算机视觉算法几乎都是深度学习相关的。深度神经网络可以逐层提取图像特征,并保持局部不变性,被广泛应用于分类、检测、分割、检索、识别、提升、重建等视觉任务中。

本文结合图像分类任务,介绍MindSpore如何应用于计算机视觉场景。

图像分类

图像分类是最基础的计算机视觉应用,属于有监督学习类别。给定一张数字图像,判断图像所属的类别,如猫、狗、飞机、汽车等等。用函数来表示这个过程如下:

def classify(image):
   label = model(image)
   return label

选择合适的model是关键。这里的model一般指的是深度卷积神经网络,如AlexNet、VGG、GoogLeNet、ResNet等等。

MindSpore实现了典型的卷积神经网络,开发者可以参考model_zoo

MindSpore当前支持的图像分类网络包括:典型网络LeNet、AlexNet、ResNet。

任务描述及准备

图1:CIFAR-10数据集[1]

如图1所示,CIFAR-10数据集共包含10类、共60000张图片。其中,每类图片6000张,50000张是训练集,10000张是测试集。每张图片大小为32*32。

图像分类的训练指标通常是精度(Accuracy),即正确预测的样本数占总预测样本数的比值。

接下来介绍利用MindSpore解决图片分类任务,整体流程如下:

  1. 下载CIFAR-10数据集
  2. 数据加载和预处理
  3. 定义卷积神经网络,本例采用ResNet-50网络
  4. 定义损失函数和优化器
  5. 调用Model高阶API进行训练和保存模型文件
  6. 加载保存的模型进行推理

本例面向Ascend 910 AI处理器硬件平台,你可以在这里下载完整的样例代码:https://gitee.com/mindspore/docs/tree/r1.1/tutorials/tutorial_code/resnet

下面对任务流程中各个环节及代码关键片段进行解释说明。

下载CIFAR-10数据集

先从CIFAR-10数据集官网上下载CIFAR-10数据集。本例中采用binary格式的数据,Linux环境可以通过下面的命令下载:

wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

接下来需要解压数据集,解压命令如下:

tar -zvxf cifar-10-binary.tar.gz

数据预加载和预处理

  1. 加载数据集

数据加载可以通过内置数据集格式Cifar10Dataset接口完成。

Cifar10Dataset,读取类型为随机读取,内置CIFAR-10数据集,包含图像和标签,图像格式默认为uint8,标签数据格式默认为uint32。更多说明请查看API中Cifar10Dataset接口说明。

数据加载代码如下,其中data_home为数据存储位置:

cifar_ds = ds.Cifar10Dataset(data_home)
  1. 数据增强

数据增强主要是对数据进行归一化和丰富数据样本数量。常见的数据增强方式包括裁剪、翻转、色彩变化等等。MindSpore通过调用map方法在图片上执行增强操作:

resize_height = 224
resize_width = 224
rescale = 1.0 / 255.0
shift = 0.0
 
# define map operations
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = C.HWC2CHW()
type_cast_op = C2.TypeCast(mstype.int32)
 
c_trans = []
if training:
    c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
 
# apply map operations on images
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")
  1. 数据混洗和批处理

最后通过数据混洗(shuffle)随机打乱数据的顺序,并按batch读取数据,进行模型训练:

# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
 
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
 
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)

定义卷积神经网络

卷积神经网络已经是图像分类任务的标准算法了。卷积神经网络采用分层的结构对图片进行特征提取,由一系列的网络层堆叠而成,比如卷积层、池化层、激活层等等。

ResNet通常是较好的选择。首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。此外,ResNet网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。

MindSpore Model Zoo中已经实现了ResNet模型,可以采用ResNet-50。调用方法如下:

network = resnet50(class_num=10)

定义损失函数和优化器

接下来需要定义损失函数(Loss)和优化器(Optimizer)。损失函数是深度学习的训练目标,也叫目标函数,可以理解为神经网络的输出(Logits)和标签(Labels)之间的距离,是一个标量数据。

常见的损失函数包括均方误差、L2损失、Hinge损失、交叉熵等等。图像分类应用通常采用交叉熵损失(CrossEntropy)。

优化器用于神经网络求解(训练)。由于神经网络参数规模庞大,无法直接求解,因而深度学习中采用随机梯度下降算法(SGD)及其改进算法进行求解。MindSpore封装了常见的优化器,如SGD、ADAM、Momemtum等等。本例采用Momentum优化器,通常需要设定两个参数,动量(moment)和权重衰减项(weight decay)。

MindSpore中定义损失函数和优化器的代码样例如下:

# loss function definition
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
 
# optimization definition
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

调用Model高阶API进行训练和保存模型文件

完成数据预处理、网络定义、损失函数和优化器定义之后,就可以进行模型训练了。模型训练包含两层迭代,数据集的多轮迭代(epoch)和一轮数据集内按分组(batch)大小进行的单步迭代。其中,单步迭代指的是按分组从数据集中抽取数据,输入到网络中计算得到损失函数,然后通过优化器计算和更新训练参数的梯度。

为了简化训练过程,MindSpore封装了Model高阶接口。用户输入网络、损失函数和优化器完成Model的初始化,然后调用train接口进行训练,train接口参数包括迭代次数(epoch)和数据集(dataset)。

模型保存是对训练参数进行持久化的过程。Model类中通过回调函数(callback)的方式进行模型保存,如下面代码所示。用户通过CheckpointConfig设置回调函数的参数,其中,save_checkpoint_steps指每经过固定的单步迭代次数保存一次模型,keep_checkpoint_max指最多保存的模型个数。

'''
network, loss, optimizer are defined before.
batch_num, epoch_size are training parameters.
'''
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
 
# CheckPoint CallBack definition
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
 
# LossMonitor is used to print loss value on screen
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

加载保存的模型,并进行验证

训练得到的模型文件(如resnet.ckpt)可以用来预测新图像的类别。首先通过load_checkpoint加载模型文件。然后调用Model的eval接口预测新图像类别。

param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
eval_dataset = create_dataset(training=False)
res = model.eval(eval_dataset)
print("result: ", res)

参考文献

[1] https://www.cs.toronto.edu/~kriz/cifar.html

ResNet-50模型图像分类示例的更多相关文章

  1. Windows Phone 8初学者开发—第12部分:改进视图模型和示例数据

    原文 Windows Phone 8初学者开发—第12部分:改进视图模型和示例数据 第12部分:改进视图模型和示例数据 原文地址:http://channel9.msdn.com/Series/Win ...

  2. SharePoint 2010 -- Silverlight托管客户端模型简单示例

    Silverlight托管客户端模型,是SharePoint2010推出的三种客户端模型".NET托管"."ECMAScript"."Sliverli ...

  3. SharePoint 2010 -- .Net托管客户端模型简单示例

    .Net托管客户端模型,是SharePoint2010推出的三种客户端模型".NET托管"."ECMAScript"."Sliverlight&quo ...

  4. UML和模式应用4:初始阶段(4)--需求制品之用例模型模板示例

    1. 前言 UP开发包括四个阶段:初始阶段.细化阶段.构建阶段.移交阶段: UP每个阶段包括 业务建模.需求.设计等科目: 其中需求科目对应的需求制品包括:设想.业务规则.用例模型.补充性规格说明.词 ...

  5. UDP通讯模型简单示例

    1. UDP通讯模型 2. 服务器端 ① 创建一个socket,用函数socket() ② 绑定IP地址.端口等信息到socket上,用函数bind() ③ 循环接收数据,用函数recvfrom() ...

  6. TCP通讯模型简单示例

    1. TCP通讯模型 2. 服务器端 ① 创建socket,用函数socket() ② 绑定IP地址.端口号等信息到socket上,用函数bind() ③ 设置允许的最大连接数,用函数listen() ...

  7. MapReduce 编程模型 & WordCount 示例

    学习大数据接触到的第一个编程思想 MapReduce.   前言 之前在学习大数据的时候,很多东西很零散的做了一些笔记,但是都没有好好去整理它们,这篇文章也是对之前的笔记的整理,或者叫输出吧.一来是加 ...

  8. SharePoint2010 -- ECMAScript客户端模型简单示例

    ECMAScript客户端模型,是SharePoint2010推出的三种客户端模型".NET托管"."ECMAScript"."Sliverlight ...

  9. c++ 网络编程(十) LINUX/windows 异步通知I/O模型与重叠I/O模型 附带示例代码

    原文作者:aircraft 原文链接:https://www.cnblogs.com/DOMLX/p/9662931.html 一.异步IO模型(asynchronous IO) (1)什么是异步I/ ...

随机推荐

  1. 11- APP性能测试GT工具的使用

    对性能测试来说有服务端的性能与客户端(APP)的性能. GT简介 1.GT(随身调)是APP的随身调测平台,它是直接运行在手机上的"集成调试环境"(IDTE) 2.利用GT,仅凭一 ...

  2. WordPress 函数do_action()详解和应用举例

      do_action()函数: 我们经常能看到在一些WordPress函数中调用了do_action()函数,例如get_header(), get_footer()等调用模板的函数中经常调用do_ ...

  3. SpringBoot + Dubbo + Zookper 整合

    经过2个小时的调试终于弄完了,过程如下, 环境: JDK1.8 .Springboot2.2.6. Windows10系统 如果不看Dubbo 管理页面的话就不用下载 Dubbo-domain了,这个 ...

  4. POJ2406简单KMP

    题意:      给一个字符串,求最大的前缀循环周期,就是最小的循环节对应的最大的那个周期. 思路:      KMP的简单应用,求完next数组后有这样的应用:next[i] :是最大循环节的第几位 ...

  5. FCKeditor编辑器漏洞

    目录 FCKeditor asp网页 aspx网页 php网页 jsp网页 FCKeditor FCKeditor是一个功能强大支持所见即所得功能的文本编辑器,可以为用户提供微软office软件一样的 ...

  6. 重新封装了一下NODE-MONGO 使其成为一个独立的服务.可以直接通过get/post来操作

    # 重新封装了一下NODE-MONGO 使其成为一个独立的服务.可以直接通过get/post来操作 # consts.js 配置用的数据,用于全局参数配置 # log.js 自己写的一个简单的存储本地 ...

  7. Day003 变量、常量、作用域

    变量 变量:就是可以变化的量 Java是一种强类型语言,每个变量都必须声明其类型. Java变量是程序中最基本的存储单元,其要素包括变量名,变量类型和作用域 变量的定义 数据类型 变量名 = 值:可以 ...

  8. VMware 15 虚拟机黑屏问题

    方法一:关闭加速3D图形 点击虚拟机,右键设置,取消勾选后,再进行重启 方法二:用管理员运行cmd 输入如下命令,要使用管理员运行,然后重启电脑 netsh winsock reset 方法三:换成V ...

  9. Spring Boot集成sharding-jdbc实现分库分表

    一.水平分割 1.水平分库 1).概念:以字段为依据,按照一定策略,将一个库中的数据拆分到多个库中.2).结果每个库的结构都一样:数据都不一样:所有库的并集是全量数据: 2.水平分表 1).概念以字段 ...

  10. [源码解析] 并行分布式框架 Celery 之 容错机制

    [源码解析] 并行分布式框架 Celery 之 容错机制 目录 [源码解析] 并行分布式框架 Celery 之 容错机制 0x00 摘要 0x01 概述 1.1 错误种类 1.2 失败维度 1.3 应 ...