C# TorchSharp 图像分类实战:VGG大规模图像识别的超深度卷积网络
教程名称:使用 C# 入门深度学习
作者:痴者工良
教程地址:
电子书仓库:https://github.com/whuanle/cs_pytorch
Maomi.Torch 项目仓库:https://github.com/whuanle/Maomi.Torch
图像分类 | VGG大规模图像识别的超深度卷积网络
本文主要讲解用于大规模图像识别的超深度卷积网络 VGG,通过 VGG 实现自有数据集进行图像分类训练模型和识别,VGG 有 vgg11、vgg11_bn、vgg13、vgg13_bn、vgg16、vgg16_bn、vgg19、vgg19_bn 等变种,VGG 架构的实现可参考论文:https://arxiv.org/abs/1409.1556
论文中文版地址:
数据集
本文主要使用经典图像分类数据集 CIFAR-10 进行训练,CIFAR-10 数据集中有 10 个分类,每个类别均有 60000 张图像,50000 张训练图像和 10000 张测试图像,每个图像都经过了预处理,生成 32x32 彩色图像。
CIFAR-10 的 10 个分类分别是:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
下面给出几种数据集的本地化导入方式。
直接下载
由于 CIFAR-10 是经典数据集,因此 TorchSharp 默认支持下载该数据集,但是由于网络问题,国内下载数据库需要开飞机,数据集自动下载和导入:
// 加载训练和验证数据
var train_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: true, download: true, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: false, download: true, target_transform: transform);
opendatalab 数据集社区
opendatalab 是一个开源数据集社区仓库,里面有大量免费下载的数据集,借此机会给读者讲解一下如何从 opendatalab 下载数据集,这对读者学习非常有帮助。
CIFAR-10 数据集仓库地址:
https://opendatalab.com/OpenDataLab/CIFAR-10/cli/main
打开 https://opendatalab.com 注册账号,然后在个人信息中心添加密钥。
然后下载 openxlab 提供的 cli 工具:
pip install openxlab #安装
安装 openxlab 后,会要求添加路径到环境变量,环境变量地址是 Scripts 地址,示例:
C:\Users\%USER%\AppData\Roaming\Python\Python312\Scripts
接着进行登录,输入命令后按照提示输入 key 和 secret:
openxlab login # 进行登录,输入对应的AK/SK,可在个人中心查看AK/SK
然后打开空目录下载数据集,数据集仓库会被下载到 OpenDataLab___CIFAR-10
目录中:
openxlab dataset info --dataset-repo OpenDataLab/CIFAR-10 # 数据集信息及文件列表查看
openxlab dataset get --dataset-repo OpenDataLab/CIFAR-10 #数据集下载
数据集信息及文件列表查看
openxlab dataset info --dataset-repo OpenDataLab/CIFAR-10
下载的文件比较多,但是我们只需要用到 cifar-10-binary.tar.gz
,直接解压 cifar-10-binary.tar.gz
到目录中(也可以不解压)。
然后导入数据:
// 加载训练和验证数据
var train_dataset = datasets.CIFAR10(root: "E:/datasets/OpenDataLab___CIFAR-10", train: true, download: false, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/OpenDataLab___CIFAR-10", train: false, download: false, target_transform: transform);
自定义数据集
Maomi.Torch 提供了自定义数据集导入方式,降低了开发者制作数据集的难度。自定义数据集也要区分训练数据集和测试数据集,训练数据集用于特征识别和训练,而测试数据集用于验证模型训练的准确率和损失值。
测试数据集和训练数据集可以放到不同的目录中,具体名称没有要求,然后每个分类单独一个目录,目录名称就是分类名称,按照目录名称的排序从 0 生成标签值。
├─test
│ ├─airplane
│ ├─automobile
│ ├─bird
│ ├─cat
│ ├─deer
│ ├─dog
│ ├─frog
│ ├─horse
│ ├─ship
│ └─truck
└─train
│ ├─airplane
│ ├─automobile
│ ├─bird
│ ├─cat
│ ├─deer
│ ├─dog
│ ├─frog
│ ├─horse
│ ├─ship
│ └─truck
读者可以参考 exportdataset
项目,将 CIFAR-10 数据集生成导出到目录中。
通过自定义目录导入数据集的代码为:
var train_dataset = MM.Datasets.ImageFolder(root: "E:/datasets/t1/train", target_transform: transform);
var val_dataset = MM.Datasets.ImageFolder(root: "E:/datasets/t1/test", target_transform: transform);
模型训练
定义图像预处理转换代码,代码如下所示:
Device defaultDevice = MM.GetOpTimalDevice();
torch.set_default_device(defaultDevice);
Console.WriteLine("当前正在使用 {defaultDevice}");
// 数据预处理
var transform = transforms.Compose([
transforms.Resize(32, 32),
transforms.ConvertImageDtype( ScalarType.Float32),
MM.transforms.ReshapeTransform(new long[]{ 1,3,32,32}),
transforms.Normalize(means: new double[] { 0.485, 0.456, 0.406 }, stdevs: new double[] { 0.229, 0.224, 0.225 }),
MM.transforms.ReshapeTransform(new long[]{ 3,32,32})
]);
因为 TorchSharp 对图像维度处理的兼容性不好,没有 Pytorch 的自动处理,因此导入的图片维度和批处理维度、transforms 处理的维度兼容性不好,容易报错,因此这里需要使用 Maomi.Torch 的转换函数,以便在导入图片和进行图像批处理的时候,保障 shape 符合要求。
分批加载数据集:
// 加载训练和验证数据
var train_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: true, download: true, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: false, download: true, target_transform: transform);
var train_loader = new DataLoader(train_dataset, batchSize: 1024, shuffle: true, device: defaultDevice, num_worker: 10);
var val_loader = new DataLoader(val_dataset, batchSize: 1024, shuffle: false, device: defaultDevice, num_worker: 10);
初始化 vgg16 网络:
var model = torchvision.models.vgg16(num_classes: 10);
model.to(device: defaultDevice);
设置损失函数和优化器:
var criterion = nn.CrossEntropyLoss();
var optimizer = optim.SGD(model.parameters(), learningRate: 0.001, momentum: 0.9);
训练模型并保存:
int num_epochs = 150;
for (int epoch = 0; epoch < num_epochs; epoch++)
{
model.train();
double running_loss = 0.0;
int i = 0;
foreach (var item in train_loader)
{
var (inputs, labels) = (item["data"], item["label"]);
var inputs_device = inputs.to(defaultDevice);
var labels_device = labels.to(defaultDevice);
optimizer.zero_grad();
var outputs = model.call(inputs_device);
var loss = criterion.call(outputs, labels_device);
loss.backward();
optimizer.step();
running_loss += loss.item<float>() * inputs.size(0);
Console.WriteLine($"[{epoch}/{num_epochs}][{i % train_loader.Count}/{train_loader.Count}]");
i++;
}
double epoch_loss = running_loss / train_dataset.Count;
Console.WriteLine($"Train Loss: {epoch_loss:F4}");
model.eval();
long correct = 0;
int total = 0;
using (torch.no_grad())
{
foreach (var item in val_loader)
{
var (inputs, labels) = (item["data"], item["label"]);
var inputs_device = inputs.to(defaultDevice);
var labels_device = labels.to(defaultDevice);
var outputs = model.call(inputs_device);
var predicted = outputs.argmax(1);
total += (int)labels.size(0);
correct += (predicted == labels_device).sum().item<long>();
}
}
double val_accuracy = 100.0 * correct / total;
Console.WriteLine($"Validation Accuracy: {val_accuracy:F2}%");
}
model.save("model.dat");
启动项目后可以直接执行训练,训练一百多轮后,准确率在 70% 左右,损失值在 0.0010
左右,继续训练已经提高不了准确率了。
导出的模型还是比较大的:
513M model.dat
下面来编写图像识别测试,在示例项目 vggdemo
中自带了三张图片,读者可以直接导入使用。
model.load("model.dat");
model.to(device: defaultDevice);
model.eval();
var classes = new string[] {
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"
};
List<Tensor> imgs = new();
imgs.Add(transform.call(MM.LoadImage("airplane.jpg").to(defaultDevice)).view(1, 3, 32, 32));
imgs.Add(transform.call(MM.LoadImage("cat.jpg").to(defaultDevice)).view(1, 3, 32, 32));
imgs.Add(transform.call(MM.LoadImage("dog.jpg").to(defaultDevice)).view(1, 3, 32, 32));
using (torch.no_grad())
{
foreach (var data in imgs)
{
var outputs = model.call(data);
var index = outputs[0].argmax(0).ToInt32();
// 转换为归一化的概率
// outputs.shape = [1,10],所以取 [dim:1]
var array = torch.nn.functional.softmax(outputs, dim: 1);
var max = array[0].ToFloat32Array();
var predicted1 = classes[index];
Console.WriteLine($"识别结果 {predicted1},准确率:{max[index] * 100}%");
}
}
识别结果:
当前正在使用 cuda:0
识别结果 airplane,准确率:99.99983%
识别结果 cat,准确率:99.83113%
识别结果 dog,准确率:100%
用到的三张图片均从网络上搜索而来:
C# TorchSharp 图像分类实战:VGG大规模图像识别的超深度卷积网络的更多相关文章
- 全卷积网络Fully Convolutional Networks (FCN)实战
全卷积网络Fully Convolutional Networks (FCN)实战 使用图像中的每个像素进行类别预测的语义分割.全卷积网络(FCN)使用卷积神经网络将图像像素转换为像素类别.与之前介绍 ...
- 全卷积网络(FCN)实战:使用FCN实现语义分割
摘要:FCN对图像进行像素级的分类,从而解决了语义级别的图像分割问题. 本文分享自华为云社区<全卷积网络(FCN)实战:使用FCN实现语义分割>,作者: AI浩. FCN对图像进行像素级的 ...
- 【Android自己定义View实战】之自己定义超简单SearchView搜索框
[Android自己定义View实战]之自己定义超简单SearchView搜索框 这篇文章是对之前文章的翻新,至于为什么我要又一次改动这篇文章?原因例如以下 1.有人举报我抄袭,原文链接:http:/ ...
- GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)
深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...
- [caffe]深度学习之图像分类模型VGG解读
一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...
- 深度学习原理与框架-卷积网络细节-图像分类与图像位置回归任务 1.模型加载 2.串接新的全连接层 3.使用SGD梯度对参数更新 4.模型结果测试 5.各个模型效果对比
对于图像的目标检测任务:通常分为目标的类别检测和目标的位置检测 目标的类别检测使用的指标:准确率, 预测的结果是类别值,即cat 目标的位置检测使用的指标:欧式距离,预测的结果是(x, y, w, h ...
- 【Android实战】----从Retrofit源代码分析到Java网络编程以及HTTP权威指南想到的
一.简单介绍 接上一篇[Android实战]----基于Retrofit实现多图片/文件.图文上传中曾说非常想搞明确为什么Retrofit那么屌. 近期也看了一些其源代码分析的文章以及亲自查看了源代码 ...
- 经典卷积网络VGG,GoodLeNet,Inception
目录 ImageNet LeNet-5 LeNet-5 Demo AlexNet VGG 1*1 Convolution GoogLeNet Stack more layers? ImageNet L ...
- 深度学习原理与框架-卷积网络细节-经典网络架构 1.AlexNet 2.VGG
1.AlexNet是2012年最早的第一代神经网络,整个神经网络的构架是8层的网络结构.网络刚开始使用11*11获得较大的感受野,随后使用5*5和3*3做特征的提取,最后使用3个全连接层做得分值得运算 ...
- CV3——学习笔记-实战项目(上):如何搭建和训练一个深度学习网络
http://www.mooc.ai/course/353/learn?lessonid=2289&groupId=0#lesson/2289 1.AlexNet, VGGNet, Googl ...
随机推荐
- C#/.NET/.NET Core技术前沿周刊 | 第 15 期(2024年11.25-11.30)
前言 C#/.NET/.NET Core技术前沿周刊,你的每周技术指南针!记录.追踪C#/.NET/.NET Core领域.生态的每周最新.最实用.最有价值的技术文章.社区动态.优质项目和学习资源等. ...
- H5C3时钟实例(rem适配)
1.原理分析和效果图 先上效果图: 屏幕适配上使用rem适配,假设用户的手机屏幕最下宽度是375px,而谷歌浏览器最小的字体大小为12px,所以我以375px为标准尺寸进行rem适配,即375px的屏 ...
- HTML5 进度条
1. <progress>标签 进度条 value属性:规定进程的当前值.默认为0 max属性:规定需要完成的值. PS:这里没有最小值设置,或者说最小值一律为0 <progress ...
- EasyExcel => EasyExcel-Plus => FastExcel
目录 什么是 FastExcel 主要特性 适用场景 结论 导航 快速开始 EasyExcel 与 FastExcel 的区别 EasyExcel 如何升级到 FastExcel 1. 修改依赖 2. ...
- 构建你的.NET Aspire解决方案
.NET Aspire 是一组功能强大的工具.模板和包,用于构建可观察的生产就绪应用程序..NET Aspire 通过处理特定云原生问题的 NuGet 包集合提供.云原生应用程序通常由小型互连部分或微 ...
- ThreeJs-06详解灯光与阴影
一.gsap动画库 1.1 基本使用和原理 首先直接npm安装然后导入 比如让一个物体,x轴时间为5s 旋转同理 动画的速度曲线,可以在官网的文档找到 1.2 控制动画属性与方法 当然这里面也有一些方 ...
- 用 16G 内存存放 30亿数据(Java Map)转载
在讨论怎么去重,提出用 direct buffer 建 btree,想到应该有现成方案,于是找到一个好东西: MapDB - MapDB : http://www.mapdb.org/ 以下来自:ko ...
- 哪里有 class 告诉我?
说明 本文中的 JVM 参数和代码在 JDK 8 版本生效. 哪里有用户类? 用户类是由开发者和第三方定义的类,它是由应用程序类加载器加载的. Java 程序可以通过CLASSPATH 环境变量,JV ...
- [转]MySQL和MySQL驱动mysql-connector-java升级到8.0.X版本
原文链接:MySQL和MySQL驱动mysql-connector-java升级到8.0.X版本
- 基于开源IM即时通讯框架MobileIMSDK:RainbowChat-iOS端v6.1版已发布
关于MobileIMSDK MobileIMSDK 是一套专门为移动端开发的开源IM即时通讯框架,超轻量级.高度提炼,一套API优雅支持UDP .TCP .WebSocket 三种协议,支持iOS.A ...