基于minsit数据集的图像分类任务|CNN简单应用项目
Github地址
README
摘要
本次实验报告用两种方式完成了基于minst数据集完成了图像的分类任务
第一种方式采用课件所讲述的差值法对训练集里的每一张图片进行了预测,并最后得出总体的测试acc,由于只是简单采用差值法对图片进行预测,没有作其他的操作,因此acc只达到了16.8%
第二种方式采用了深度学习,2d卷积神经网络的方式进行图像分类。acc达到了0.98267
实验内容及目的
实验内容为通过差值法和2dCNN的方法,对每一类1000张,共10类的minsit数据集作分类任务。实验目的是为:掌握Matlab图片导入、分析和操作的方式。
实验相关原理描述
差值法的原理:
本质上是矩阵之间的相似性,相似度最大的即为预测类别,其中公式如所示
该模型不需要进行训练,在测试的时候,每一张图片所对应的矩阵与哪一个带有标签的矩阵相似度最大即可,则该图片的预测结果即位对应的标签值。
2dCNN法的原理:
通过2d卷积神经网络,让学习器学习minst数据集图像特征,最后根据学习到的特征进行分类的预测。
CNN网络模型图如图所示:

实验过程
差值法
在数据处理方面,在使用差值法进行图像分类时,我利用了Python语言对所有图片进行打标签工作,以便于Matlab程序读取每一张图片及其相对应的标签,代码如下:
import pandas as pd
import os
import numpy as np
csv_path = r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.csv'
df = pd.read_csv(csv_path)
arr = np.array(df)
arr = arr.tolist()
f = open(r'/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/Data.txt', 'w')
for cur_label in range(0, 10):
for root, dirs, files in os.walk(fr"/Users/demac/我的文件/SYSU/4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/{cur_label}"):
for file in files:
# 获取文件路径
path = os.path.join(root, file)
# f.write(path + '\n')
new_row = [path, cur_label]
arr.append(new_row)
arr = np.array(arr)
df = pd.DataFrame(arr)
df.to_csv(csv_path)
print()
分类代码如下:
clear all;
opts = delimitedTextImportOptions("NumVariables", 3);
% 指定范围和分隔符
opts.DataLines = [2, Inf];
opts.Delimiter = ",";
% 指定列名称和类型
opts.VariableNames = ["VarName1", "VarName2", "VarName3"];
opts.VariableTypes = ["double", "string", "double"];
% 指定文件级属性
opts.ExtraColumnsRule = "ignore";
opts.EmptyLineRule = "read";
% 指定变量属性
opts = setvaropts(opts, "VarName2", "WhitespaceRule", "preserve");
opts = setvaropts(opts, "VarName2", "EmptyFieldRule", "auto");
opts = setvaropts(opts, ["VarName1", "VarName3"], "ThousandsSeparator", ",");
% 导入数据
Data = readtable("Data.csv", opts)
% 数据预处理
clear opts
array = table2array(Data);
% 此时array里面存的就是路径和标签
% 在所有数据中取出9张作为label比对图片
img_model = {};
idx = 1;
for i = 1:9
img_model{i} = imread(array(idx,2));
idx = idx+1000;
end
% img_model已经处理好了
开始test
correct = 0;
for i = 1:length(array)
true = array(i,3);
img = imread(array(i,2));
pred_idx = -1;
min_err = 100000;
for j = 1:length(img_model)
error = count_err(img_model{j},img);
if(error < min_err)
min_err = error;
pred_idx = j - 1;
end
end
if pred_idx == -1
disp("error");
end
if pred_idx == str2num(true)
% 代表预测正确
correct = correct + 1;
end
end
disp(correct);
disp("最终的准确率为: " + num2str(correct/length(array)));
function error = count_err(img1,img2)
error = norm(double(img1) - double(img2));
end
2dCNN
clear all;
DatasetPath = fullfile(['/Users/demac/我的文件/SYSU/' ...
'4. 2022第二学期/图像处理实验/实验一/图像实验一/number_recognize/train_dataset/']);
imds = imageDatastore(DatasetPath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
每个类别有1000张图片,取700张进行train,300张进行test
train_size = 700;
[imdsTrain,imdsValidation] = splitEachLabel(imds,train_size,'randomized');
% 定义神经网络的forward
inplane = [28,28,1]; % 图像输入大小
numClasses = 10; %10分类任务
layers = [
imageInputLayer(inplane)
convolution2dLayer(5,20) % 卷积层
batchNormalizationLayer % 归一层
reluLayer % 激活函数
fullyConnectedLayer(numClasses) % 全链接层
softmaxLayer
classificationLayer];
% train
options = trainingOptions("sgdm", ...
"MaxEpochs",5, ...
"ValidationData",imdsValidation, ...
"ValidationFrequency",30, ...
"Verbose",false, ...
"Plots",'training-progress'); % 最后输出训练过程的趋势
net = trainNetwork(imdsTrain,layers,options); % 构建网络
% test
Pred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
acc = mean(Pred == YValidation);
disp("acc: " + num2str(acc));
实验结果
差值法
差值法分类最后的acc为:0.1676
2dCNN法
2dCNN法最后得到的acc为:0.98267
训练收敛过程如下图所示:
总结
通过两个处理方法的实验,我们发现,差值法并不能很好的完成minist数据集的10分类任务。与此同时,卷积神经网络是一种很好的分类方法,对于98.27%的准确率,我们还可以通过调整网络前向传播,如增加注意力机制等模块等方式继续提高分类的准确度。
附件
main.mlx 差值法分类任务代码源文件
main2.mlx 卷积神经网络分类任务代码源文件
main.pdf和main2.pdf 实时脚本输出pdf文件
label.py 打标签Python源文件
final.jpg 神经网络模型结构图
train.png 神经网络训练过程图
基于minsit数据集的图像分类任务|CNN简单应用项目的更多相关文章
- 基于MNIST数据的卷积神经网络CNN
基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...
- MNIST数据集上卷积神经网络的简单实现(使用PyTorch)
设计的CNN模型包括一个输入层,输入的是MNIST数据集中28*28*1的灰度图 两个卷积层, 第一层卷积层使用6个3*3的kernel进行filter,步长为1,填充1.这样得到的尺寸是(28+1* ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- 机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类
机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类 项目链接参考:https://www.heywhale.com/home/column/64141d6b1c8c8 ...
- 器学习算法(六)基于天气数据集的XGBoost分类预测
1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...
- 基于MNIST数据集使用TensorFlow训练一个没有隐含层的浅层神经网络
基础 在参考①中我们详细介绍了没有隐含层的神经网络结构,该神经网络只有输入层和输出层,并且输入层和输出层是通过全连接方式进行连接的.具体结构如下: 我们用此网络结构基于MNIST数据集(参考②)进行训 ...
- 基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用
基于Xilinx Zynq Z7045 SoC的CNN的视觉识别应用 由 judyzhong 于 星期三, 08/16/2017 - 14:56 发表 作者:stark 近些年来随着科学技术的不断进步 ...
- SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建
SpringBoot整合Shiro实现基于角色的权限访问控制(RBAC)系统简单设计从零搭建 技术栈 : SpringBoot + shiro + jpa + freemark ,因为篇幅原因,这里只 ...
- 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架 - LinFx
LinFx 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架,遵循领域驱动设计(DDD)规范约束,提供实现事件驱动.事件回溯.响应式等特性的基础设施.让开发者享受到正真意义的面向对象 ...
- 基于C++11的100行实现简单线程池
基于C++11的100行实现简单线程池 1 线程池原理 线程池是一种多线程处理形式,处理过程中将任务添加到队列,然后在创建线程后自动启动这些任务.线程池线程都是后台线程.每个线程都使用默认的堆栈大小, ...
随机推荐
- AtCoder Beginner Contest 218 A~D
比赛链接:Here A - Weather Forecas 水题,判断 \(s[n - 1] = o\) 的话输出 YES B - qwerty 题意:给出 \((1,2,...,26)\) 的某个全 ...
- 机器学习-线性分类-支持向量机SVM-合页损失-SVM输出概率值-16
目录 1. SVM概率化输出 2. 合页损失 1. SVM概率化输出 标准的SVM进行预测 输出的结果是: 是无法输出0-1之间的 正样本 发生的概率值 sigmoid-fitting 方法: 将标准 ...
- MySQL的SQL优化常用30种方法[转]
MySQL的SQL优化常用30种方法 1.对查询进行优化,应尽量避免全表扫描,首先应考虑在 where 及 order by 涉及的列上建立索引. 2.应尽量避免在 where 子句中使用!=或< ...
- Laravel路由匹配
Route常规用法如下,特别是最后一个传参之后可以进行正则匹配,非常好用. //@后面内容为所要访问的方法 Route::get('foo', 'Photos\AdminController@meth ...
- [转帖]Oracle优化案例:vfs_cache_pressure和min_free_kbytes解决RMAN挂起问题
https://www.modb.pro/db/34028 环境: Oracle 11gr2 + dataguard 512GB内存 + 128核cpu + 高性能存储服务器 uname -an Li ...
- TiKV 服务部署的注意事项
TiKV 服务部署的注意事项 背景 最近发现tikv总是会掉线 不知道是哪里触发了啥样子的bug. 所以想着使用systemd 管理一下, 至少在tikv宕机的时候能够拉起来服务. 二进制文件 pd- ...
- [转帖]Nginx优化与防盗链
目录 一.配置Nginx隐藏版本号 1.第一种方法修改配置文件 2.第二种方法修改源码文件,重新编译安装 二.修改Nginx用户与组 三.配置Nginx网页缓存时间 四.实现Nginx的日志分割 五. ...
- [转帖]实战瓶颈定位-我的MySQL为什么压不上去
https://plantegg.github.io/2023/06/20/%E5%AE%9E%E6%88%98%E7%93%B6%E9%A2%88%E5%AE%9A%E4%BD%8D-%E6%88% ...
- [转帖]Linux—编写shell脚本操作数据库执行sql
Linux-编写shell脚本操作数据库执行sql Hughman关注IP属地: 北京 0.0762020.03.20 09:02:13字数 295阅读 1,036 修改数据库数据 在升级应用时, ...
- [转帖]InnoDB表聚集索引层高什么时候发生变化
导读 本文略长,主要解决以下几个疑问 1.聚集索引里都存储了什么宝贝 2.什么时候索引层高会发生变化 3.预留的1/16空闲空间做什么用的 4.记录被删除后的空间能回收重复利用吗 1.背景信息 1.1 ...