基于minsit数据集的图像分类任务|CNN简单应用项目
Github地址
README
摘要
本次实验报告用两种方式完成了基于minst数据集完成了图像的分类任务
第一种方式采用课件所讲述的差值法对训练集里的每一张图片进行了预测,并最后得出总体的测试acc,由于只是简单采用差值法对图片进行预测,没有作其他的操作,因此acc只达到了16.8%
第二种方式采用了深度学习,2d卷积神经网络的方式进行图像分类。acc达到了0.98267
实验内容及目的
实验内容为通过差值法和2dCNN的方法,对每一类1000张,共10类的minsit数据集作分类任务。实验目的是为:掌握Matlab图片导入、分析和操作的方式。
实验相关原理描述
差值法的原理:
本质上是矩阵之间的相似性,相似度最大的即为预测类别,其中公式如所示
该模型不需要进行训练,在测试的时候,每一张图片所对应的矩阵与哪一个带有标签的矩阵相似度最大即可,则该图片的预测结果即位对应的标签值。
2dCNN法的原理:
通过2d卷积神经网络,让学习器学习minst数据集图像特征,最后根据学习到的特征进行分类的预测。
CNN网络模型图如图所示:

实验过程
差值法
在数据处理方面,在使用差值法进行图像分类时,我利用了Python语言对所有图片进行打标签工作,以便于Matlab程序读取每一张图片及其相对应的标签,代码如下:
import pandas as pd
import os
import numpy as np
csv_path = r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.csv'
df = pd.read_csv(csv_path)
arr = np.array(df)
arr = arr.tolist()
f = open(r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.txt', 'w')
for cur_label in range(0, 10):
for root, dirs, files in os.walk(fr"/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/{cur_label}"):
for file in files:
# 获取文件路径
path = os.path.join(root, file)
# f.write(path + '\n')
new_row = [path, cur_label]
arr.append(new_row)
arr = np.array(arr)
df = pd.DataFrame(arr)
df.to_csv(csv_path)
print()
分类代码如下:
clear all;
opts = delimitedTextImportOptions("NumVariables", 3);
% 指定范围和分隔符
opts.DataLines = [2, Inf];
opts.Delimiter = ",";
% 指定列名称和类型
opts.VariableNames = ["VarName1", "VarName2", "VarName3"];
opts.VariableTypes = ["double", "string", "double"];
% 指定文件级属性
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% 指定变量属性
opts = setvaropts(opts, "VarName2", "WhitespaceRule", "preserve");
opts = setvaropts(opts, "VarName2", "EmptyFieldRule", "auto");
opts = setvaropts(opts, ["VarName1", "VarName3"], "ThousandsSeparator", ",");
% 导入数据
Data = readtable("Data.csv", opts)
% 数据预处理
clear opts
array = table2array(Data);
% 此时array里面存的就是路径和标签
% 在所有数据中取出9张作为label比对图片
img_model = {};
idx = 1;
for i = 1:9
img_model{i} = imread(array(idx,2));
idx = idx+1000;
end
% img_model已经处理好了
开始test
correct = 0;
for i = 1:length(array)
true = array(i,3);
img = imread(array(i,2));
pred_idx = -1;
min_err = 100000;
for j = 1:length(img_model)
error = count_err(img_model{j},img);
if(error < min_err)
min_err = error;
pred_idx = j - 1;
end
end
if pred_idx == -1
disp("error");
end
if pred_idx == str2num(true)
% 代表预测正确
correct = correct + 1;
end
end
disp(correct);
disp("最终的准确率为: " + num2str(correct/length(array)));
function error = count_err(img1,img2)
error = norm(double(img1) - double(img2));
end
2dCNN
clear all;
DatasetPath = fullfile(['/Users/demac/我的文件/SYSU/' ...
'4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/']);
imds = imageDatastore(DatasetPath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
每个类别有1000张图片,取700张进行train,300张进行test
train_size = 700;
[imdsTrain,imdsValidation] = splitEachLabel(imds,train_size,'randomized');
% 定义神经网络的forward
inplane = [28,28,1]; % 图像输入大小
numClasses = 10; %10分类任务
layers = [
imageInputLayer(inplane)
convolution2dLayer(5,20) % 卷积层
batchNormalizationLayer % 归一层
reluLayer % 激活函数
fullyConnectedLayer(numClasses) % 全链接层
softmaxLayer
classificationLayer];
% train
options = trainingOptions("sgdm", ...
"MaxEpochs",5, ...
"ValidationData",imdsValidation, ...
"ValidationFrequency",30, ...
"Verbose",false, ...
"Plots",'training-progress'); % 最后输出训练过程的趋势
net = trainNetwork(imdsTrain,layers,options); % 构建网络
% test
Pred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
acc = mean(Pred == YValidation);
disp("acc: " + num2str(acc));
实验结果
差值法
差值法分类最后的acc为:0.1676
2dCNN法
2dCNN法最后得到的acc为:0.98267
训练收敛过程如下图所示:
总结
通过两个处理方法的实验,我们发现,差值法并不能很好的完成minist数据集的10分类任务。与此同时,卷积神经网络是一种很好的分类方法,对于98.27%的准确率,我们还可以通过调整网络前向传播,如增加注意力机制等模块等方式继续提高分类的准确度。
附件
main.mlx 差值法分类任务代码源文件
main2.mlx 卷积神经网络分类任务代码源文件
main.pdf和main2.pdf 实时脚本输出pdf文件
label.py 打标签Python源文件
final.jpg 神经网络模型结构图
train.png 神经网络训练过程图
基于minsit数据集的图像分类任务|CNN简单应用项目的更多相关文章
- 基于MNIST数据的卷积神经网络CNN
基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...
- MNIST数据集上卷积神经网络的简单实现(使用PyTorch)
设计的CNN模型包括一个输入层,输入的是MNIST数据集中28*28*1的灰度图 两个卷积层, 第一层卷积层使用6个3*3的kernel进行filter,步长为1,填充1.这样得到的尺寸是(28+1* ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- 机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类
机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类 项目链接参考:https://www.heywhale.com/home/column/64141d6b1c8c8 ...
- 器学习算法(六)基于天气数据集的XGBoost分类预测
1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...
- 基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用
基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用 由 judyzhong 于 星期三, 08/16/2017 - 14:56 发表 作者:stark 近些年来随着科学技术的不断进步 ...
- SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建
SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建 技术栈 : SpringBoot + shiro + jpa + freemark ,因为篇幅原因,这里只 ...
- 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架 - LinFx
LinFx 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架,遵循领域驱动设计(DDD)规范约束,提供实现事件驱动.事件回溯.响应式等特性的基础设施.让开发者享受到正真意义的面向对象 ...
- 基于C++11的100行实现简单线程池
基于C++11的100行实现简单线程池 1 线程池原理 线程池是一种多线程处理形式,处理过程中将任务添加到队列,然后在创建线程后自动启动这些任务.线程池线程都是后台线程.每个线程都使用默认的堆栈大小, ...
随机推荐
- Codeforces Round #739 (Div. 3) 个人题解(A~F2)
比赛链接:Here 1560A. Dislike of Threes Description 找出第 $k$ 大的不可被 $3$ 整除以及非 $3$ 结尾的整数 直接枚举出前 1000 个符合条件的数 ...
- Codeforces Round #727 (Div. 2) A~D题题解
比赛链接:Here 1539A. Contest Start 让我们找出哪些参与者会干扰参与者i.这些是数字在 \(i+1\) 和 \(i+min(t/x,n)\)之间的参与者.所以第一个最大值 \( ...
- 编写Java代码时应该避免的6个坑
通常情况下,我们都希望我们的代码是高效和兼容的,但是实际情况下代码中常常含有一些隐藏的坑,只有等出现异常时我们才会去解决它.本文是一篇比较简短的文章,列出了开发人员在编写 Java 程序时常犯的错误, ...
- Liunx常用操作(11)-VI编辑器-末行模式命令
vI编辑器三种模式 分别为命令模式.输入模式.末行模式.
- Go 疑难杂症汇总
1. revision v0.0.0: unknown revision v0.0.0 go get -u github.com/uudashr/gopkgs/cmd/gopkgs 报错: [root ...
- The project description file (.project) for XXX is missing
在STS中切换项目分支的时候,出现一个项目打不开了,提示:The project description file (.project) for XXX is missing 试了下网上的方法都没有解 ...
- 文心一言 VS 讯飞星火 VS chatgpt (185)-- 算法导论14.1 2题
二.用go语言,对于图 14-1中的红黑树 T 和关键字 x.key 为35的结点x,说明执行 OS-RANK(T,x) 的过程. 文心一言: 在红黑树中,OS-RANK(T, x) 是一个操作,用于 ...
- Go-单元测试-Test
单元测试 文件名以 _test.go 结尾 函数名以 Test 开头 函数参数固定 t *testing.T 运行单元测试 go test Demo 源文件 package unit import & ...
- [转帖]聊聊redis的slowlog与latency monitor
https://www.jianshu.com/p/95a9ce63ddb2 序 本文主要研究一下redis的slowlog与latency monitor slowlog redis在2.2.12版 ...
- [转帖]Linux内存之Cache
一. Linux内存之Cache 1.1.Cache 1.1.1.什么是Cache? Cache存储器,是位于CPU和主存储器DRAM之间的一块高速缓冲存储器,规模较小,但是速度很快,通常由SRAM( ...