基于minsit数据集的图像分类任务|CNN简单应用项目
Github地址
README
摘要
本次实验报告用两种方式完成了基于minst数据集完成了图像的分类任务
第一种方式采用课件所讲述的差值法对训练集里的每一张图片进行了预测,并最后得出总体的测试acc,由于只是简单采用差值法对图片进行预测,没有作其他的操作,因此acc只达到了16.8%
第二种方式采用了深度学习,2d卷积神经网络的方式进行图像分类。acc达到了0.98267
实验内容及目的
实验内容为通过差值法和2dCNN的方法,对每一类1000张,共10类的minsit数据集作分类任务。实验目的是为:掌握Matlab图片导入、分析和操作的方式。
实验相关原理描述
差值法的原理:
本质上是矩阵之间的相似性,相似度最大的即为预测类别,其中公式如所示
该模型不需要进行训练,在测试的时候,每一张图片所对应的矩阵与哪一个带有标签的矩阵相似度最大即可,则该图片的预测结果即位对应的标签值。
2dCNN法的原理:
通过2d卷积神经网络,让学习器学习minst数据集图像特征,最后根据学习到的特征进行分类的预测。
CNN网络模型图如图所示:

实验过程
差值法
在数据处理方面,在使用差值法进行图像分类时,我利用了Python语言对所有图片进行打标签工作,以便于Matlab程序读取每一张图片及其相对应的标签,代码如下:
import pandas as pd
import os
import numpy as np
csv_path = r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.csv'
df = pd.read_csv(csv_path)
arr = np.array(df)
arr = arr.tolist()
f = open(r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.txt', 'w')
for cur_label in range(0, 10):
for root, dirs, files in os.walk(fr"/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/{cur_label}"):
for file in files:
# 获取文件路径
path = os.path.join(root, file)
# f.write(path + '\n')
new_row = [path, cur_label]
arr.append(new_row)
arr = np.array(arr)
df = pd.DataFrame(arr)
df.to_csv(csv_path)
print()
分类代码如下:
clear all;
opts = delimitedTextImportOptions("NumVariables", 3);
% 指定范围和分隔符
opts.DataLines = [2, Inf];
opts.Delimiter = ",";
% 指定列名称和类型
opts.VariableNames = ["VarName1", "VarName2", "VarName3"];
opts.VariableTypes = ["double", "string", "double"];
% 指定文件级属性
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% 指定变量属性
opts = setvaropts(opts, "VarName2", "WhitespaceRule", "preserve");
opts = setvaropts(opts, "VarName2", "EmptyFieldRule", "auto");
opts = setvaropts(opts, ["VarName1", "VarName3"], "ThousandsSeparator", ",");
% 导入数据
Data = readtable("Data.csv", opts)
% 数据预处理
clear opts
array = table2array(Data);
% 此时array里面存的就是路径和标签
% 在所有数据中取出9张作为label比对图片
img_model = {};
idx = 1;
for i = 1:9
img_model{i} = imread(array(idx,2));
idx = idx+1000;
end
% img_model已经处理好了
开始test
correct = 0;
for i = 1:length(array)
true = array(i,3);
img = imread(array(i,2));
pred_idx = -1;
min_err = 100000;
for j = 1:length(img_model)
error = count_err(img_model{j},img);
if(error < min_err)
min_err = error;
pred_idx = j - 1;
end
end
if pred_idx == -1
disp("error");
end
if pred_idx == str2num(true)
% 代表预测正确
correct = correct + 1;
end
end
disp(correct);
disp("最终的准确率为: " + num2str(correct/length(array)));
function error = count_err(img1,img2)
error = norm(double(img1) - double(img2));
end
2dCNN
clear all;
DatasetPath = fullfile(['/Users/demac/我的文件/SYSU/' ...
'4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/']);
imds = imageDatastore(DatasetPath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
每个类别有1000张图片,取700张进行train,300张进行test
train_size = 700;
[imdsTrain,imdsValidation] = splitEachLabel(imds,train_size,'randomized');
% 定义神经网络的forward
inplane = [28,28,1]; % 图像输入大小
numClasses = 10; %10分类任务
layers = [
imageInputLayer(inplane)
convolution2dLayer(5,20) % 卷积层
batchNormalizationLayer % 归一层
reluLayer % 激活函数
fullyConnectedLayer(numClasses) % 全链接层
softmaxLayer
classificationLayer];
% train
options = trainingOptions("sgdm", ...
"MaxEpochs",5, ...
"ValidationData",imdsValidation, ...
"ValidationFrequency",30, ...
"Verbose",false, ...
"Plots",'training-progress'); % 最后输出训练过程的趋势
net = trainNetwork(imdsTrain,layers,options); % 构建网络
% test
Pred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
acc = mean(Pred == YValidation);
disp("acc: " + num2str(acc));
实验结果
差值法
差值法分类最后的acc为:0.1676
2dCNN法
2dCNN法最后得到的acc为:0.98267
训练收敛过程如下图所示:
总结
通过两个处理方法的实验,我们发现,差值法并不能很好的完成minist数据集的10分类任务。与此同时,卷积神经网络是一种很好的分类方法,对于98.27%的准确率,我们还可以通过调整网络前向传播,如增加注意力机制等模块等方式继续提高分类的准确度。
附件
main.mlx 差值法分类任务代码源文件
main2.mlx 卷积神经网络分类任务代码源文件
main.pdf和main2.pdf 实时脚本输出pdf文件
label.py 打标签Python源文件
final.jpg 神经网络模型结构图
train.png 神经网络训练过程图
基于minsit数据集的图像分类任务|CNN简单应用项目的更多相关文章
- 基于MNIST数据的卷积神经网络CNN
基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...
- MNIST数据集上卷积神经网络的简单实现(使用PyTorch)
设计的CNN模型包括一个输入层,输入的是MNIST数据集中28*28*1的灰度图 两个卷积层, 第一层卷积层使用6个3*3的kernel进行filter,步长为1,填充1.这样得到的尺寸是(28+1* ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- 机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类
机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类 项目链接参考:https://www.heywhale.com/home/column/64141d6b1c8c8 ...
- 器学习算法(六)基于天气数据集的XGBoost分类预测
1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...
- 基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用
基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用 由 judyzhong 于 星期三, 08/16/2017 - 14:56 发表 作者:stark 近些年来随着科学技术的不断进步 ...
- SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建
SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建 技术栈 : SpringBoot + shiro + jpa + freemark ,因为篇幅原因,这里只 ...
- 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架 - LinFx
LinFx 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架,遵循领域驱动设计(DDD)规范约束,提供实现事件驱动.事件回溯.响应式等特性的基础设施.让开发者享受到正真意义的面向对象 ...
- 基于C++11的100行实现简单线程池
基于C++11的100行实现简单线程池 1 线程池原理 线程池是一种多线程处理形式,处理过程中将任务添加到队列,然后在创建线程后自动启动这些任务.线程池线程都是后台线程.每个线程都使用默认的堆栈大小, ...
随机推荐
- Python | BitMap算法及其实现
BitMap概述 本文介绍 BitMap 算法的应用背景,算法思想和相关实现细节. 概括而言,BitMap 主要用来解决海量数据中元素查询,去重.以及排序等问题.这里对海量数据场景的强调,似乎暗示了这 ...
- Codeforces Round #721 (Div. 2) AB思维,B2博弈,C题map
补题链接:Here 1527A. And Then There Were K 题目大意: 给一个正整数n,求最大的k,使得 \(n \& (n−1) \& (n−2) \& ( ...
- L2-012 关于堆的判断 (25分) (字符串處理)
将一系列给定数字顺序插入一个初始为空的小顶堆H[].随后判断一系列相关命题是否为真.命题分下列几种: x is the root:x是根结点: x and y are siblings:x和y是兄弟结 ...
- spring管理实务有几种方式
一:事务认识 大家所了解的事务Transaction,它是一些列严密操作动作,要么都操作完成,要么都回滚撤销.Spring事务管理基于底层数据库本身的事务处理机制.数据库事务的基础,是掌握Spring ...
- vue-asome-swiper
- d3条形图案例
- Mysql 开启慢日志查询及查看慢日志 sql
本文为博主原创,转载请注明出处: 目录: 1.Mysql 开启慢日志配置的查询 2. 通过sql 设置Mysql 的慢日志开启 3. 通过慢 sql 日志文件查看慢 sql 1.M ...
- Windows 下 Outlook 点击关闭最小化和开机自动运行
.markdown-body { line-height: 1.75; font-weight: 400; font-size: 16px; overflow-x: hidden; color: rg ...
- [转帖]使用 TiUP 扩容缩容 TiDB 集群
https://docs.pingcap.com/zh/tidb/stable/scale-tidb-using-tiup TiDB 集群可以在不中断线上服务的情况下进行扩容和缩容. 本文介绍如何使用 ...
- [转帖]armv6、armv7、armv7s、armv8、armv64及其i386、x86_64区别
ARM处理器指令集 一. 苹果模拟器指令集: 指令集 分析 i386 针对intel通用微处理器32架构的 x86_64 针对x86架构的64位处理器 i386|x86_64 是Mac处理器的指令集, ...