Matlab 分类算法
一、分类算法核心概念
分类是监督学习任务,目标是将数据分配到预定义的类别中。关键步骤包括:
- 特征工程:提取/选择区分性强的特征
- 模型训练:学习特征与类别的映射关系
- 评估指标:准确率、精确率、召回率、F1分数、混淆矩阵
二、常用分类算法解析
1. K近邻(KNN)
原理:基于距离度量,将样本分配给其k个最近邻中最常见的类别
公式:
欧氏距离: \(d(\mathbf{x}_i, \mathbf{x}_j) = \sqrt{\sum_{k=1}^{n}(x_{ik} - x_{jk})^2}\)
优点:简单、无需训练、适用于多分类
缺点:计算开销大、对高维数据敏感
MATLAB代码:
% 生成示例数据
rng(1);
data = [randn(100,2)*0.5+1; randn(100,2)*0.5-1];
labels = [ones(100,1); 2*ones(100,1)];
% 划分训练测试集
cv = cvpartition(length(labels), 'HoldOut', 0.3);
trainData = data(training(cv),:);
trainLabels = labels(training(cv));
testData = data(test(cv),:);
testLabels = labels(test(cv));
% KNN模型训练与预测
Mdl = fitcknn(trainData, trainLabels, 'NumNeighbors', 5);
predicted = predict(Mdl, testData);
% 评估
accuracy = sum(predicted == testLabels)/numel(testLabels);
confMat = confusionmat(testLabels, predicted);
disp(['Accuracy: ', num2str(accuracy*100), '%']);
disp('Confusion Matrix:');
disp(confMat);
2. 支持向量机(SVM)
原理:寻找最优超平面最大化类别间隔
公式:
优化目标: \(\min_{\mathbf{w},b} \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{i=1}^{n}\xi_i\)
约束: \(y_i(\mathbf{w}\cdot\mathbf{x}_i + b) \geq 1 - \xi_i\)
优点:高维有效、泛化能力强
缺点:计算复杂、需参数调优
MATLAB代码:
% 使用相同数据
% 训练SVM模型(线性核)
SVMModel = fitcsvm(trainData, trainLabels, 'KernelFunction', 'linear', ...
'BoxConstraint', 1, 'Standardize', true);
% 预测与评估
predicted = predict(SVMModel, testData);
accuracy = sum(predicted == testLabels)/numel(testLabels);
disp(['SVM Accuracy: ', num2str(accuracy*100), '%']);
% 可视化决策边界
figure;
hgscatter = gscatter(trainData(:,1), trainData(:,2), trainLabels);
hold on;
h = gca;
lim = [h.XLim h.YLim];
[xx,yy] = meshgrid(linspace(lim(1),lim(2),100), linspace(lim(3),lim(4),100));
XGrid = [xx(:), yy(:)];
predGrid = predict(SVMModel, XGrid);
gscatter(xx(:), yy(:), predGrid, [0 0.5 0.1; 0.1 0.5 0]);
title('SVM Decision Boundary');
hold off;
3.决策树算法(Decision Tree)
核心原理
通过递归分割构建树状结构,每个节点根据特征阈值进行二元决策:
- 分裂准则:基尼不纯度(Gini Index)或信息增益(Information Gain)
% 基尼系数公式 (MATLAB实现)
function gini = gini_index(labels)
classes = unique(labels);
prob = histcounts(labels, [classes; max(classes)+1])/length(labels);
gini = 1 - sum(prob.^2);
end
- 停止条件:最大深度/最小样本数/纯度阈值
MATLAB 代码实现
% 训练决策树模型
treeModel = fitctree(irisInputs, irisTargets, ...
'MaxDepth', 5, ...
'MinLeafSize', 10, ...
'SplitCriterion', 'gdi'); % 基尼系数
% 可视化决策树
view(treeModel, 'Mode', 'graph');
% 预测与评估
predicted = predict(treeModel, testInputs);
accuracy = sum(predicted == testTargets)/numel(testTargets);
disp(['决策树准确率: ', num2str(accuracy*100), '%']);
% 特征重要性分析
imp = predictorImportance(treeModel);
bar(imp);
xlabel('特征');
ylabel('重要性得分');
title('特征重要性排序');
完整示例:
% 设置随机种子保证可重复性
rng(42);
% 1. 加载鸢尾花数据集
load fisheriris; % 数据集存储在变量meas(150x4)和species(150x1)中
% 将类别标签转换为类别数组(便于后续处理)
species = categorical(species);
% 2. 划分训练集和测试集(70%训练,30%测试)
cv = cvpartition(species, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);
XTrain = meas(idxTrain, :);
yTrain = species(idxTrain);
XTest = meas(idxTest, :);
yTest = species(idxTest);
% 3. 训练决策树模型(使用基尼指数作为分裂准则)
treeModel = fitctree(XTrain, yTrain, 'SplitCriterion', 'gdi');
% 4. 可视化决策树(生成图形化树结构)
view(treeModel, 'Mode', 'graph');
% 5. 预测与评估
yPred = predict(treeModel, XTest);
accuracy = sum(yPred == yTest) / numel(yTest);
fprintf('测试准确率: %.2f%%\n', accuracy * 100);
% 输出混淆矩阵
C = confusionmat(yTest, yPred);
% 注意:confusionchart需要深度学习工具箱(Deep Learning Toolbox)
if exist('confusionchart', 'file')
figure;
confusionchart(yTest, yPred);
title('决策树混淆矩阵');
else
disp('混淆矩阵:');
disp(C);
end
% 6. 特征重要性(基于节点分裂时特征被选择的次数加权计算)
imp = predictorImportance(treeModel);
featureNames = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
% 绘制特征重要性条形图
figure;
bar(imp);
title('决策树特征重要性');
set(gca, 'XTickLabel', featureNames, 'XTick', 1:numel(featureNames));
ylabel('重要性得分');
% 重要提示:决策树可能会过拟合,可通过剪枝优化
% 计算剪枝水平(交叉验证)
[~, ~, ~, bestLevel] = cvLoss(treeModel, 'SubTrees', 'all', 'TreeSize', 'min');
% 剪枝到最佳水平
prunedTree = prune(treeModel, 'Level', bestLevel);
% 评估剪枝后的树
yPredPruned = predict(prunedTree, XTest);
accuracyPruned = sum(yPredPruned == yTest) / numel(yTest);
fprintf('剪枝后测试准确率: %.2f%%\n', accuracyPruned * 100);
% 可视化剪枝后的树(可选)
% view(prunedTree, 'Mode', 'graph');
决策树特点
优势 | 劣势 |
---|---|
模型可解释性强 | 容易过拟合 |
处理混合类型特征 | 对数据波动敏感 |
无需特征缩放 | 边界只能是轴对齐 |
4.神经网络算法(Neural Network)
核心原理(以多层感知机MLP为例)
- 前向传播:\(z^{(l)} = W^{(l)}a^{(l-1)} + b^{(l)}\)
\(a^{(l)} = \sigma(z^{(l)})\) - 激活函数:ReLU \(\sigma(x) = \max(0,x)\) (隐藏层),Softmax(输出层)
- 损失函数:交叉熵 \(L = -\sum y_i \log(\hat{y}_i)\)
MATLAB深度学习工具箱实现
% 数据准备
[XTrain, YTrain, XTest, YTest] = prepareData(); % 自定义数据预处理
% 构建网络结构
layers = [
featureInputLayer(size(XTrain,2)) % 输入层
fullyConnectedLayer(128) % 全连接层
batchNormalizationLayer % 批标准化
reluLayer % ReLU激活
dropoutLayer(0.3) % Dropout正则化
fullyConnectedLayer(64)
reluLayer
fullyConnectedLayer(numClasses) % 输出层
softmaxLayer
classificationLayer];
% 训练配置
options = trainingOptions('adam', ...
'MaxEpochs', 100, ...
'MiniBatchSize', 64, ...
'ValidationData', {XTest, YTest}, ...
'Plots', 'training-progress', ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropFactor', 0.5, ...
'LearnRateDropPeriod', 20);
% 训练网络
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% 测试评估
predicted = classify(net, XTest);
accuracy = sum(predicted == categorical(YTest))/numel(YTest);
confusionchart(YTest, double(predicted));
三、算法对比矩阵
特性 | 决策树 | 神经网络 | KNN | SVM |
---|---|---|---|---|
训练速度 | ️️️ | ️ | ️️ | ️️ |
预测速度 | ️️️ | ️️ | ️ | ️️️ |
可解释性 | ||||
处理高维 | ||||
抗噪声 | ||||
特征工程 | 无需 | 自动提取 | 需缩放 | 需缩放 |
四、关键问题解决方案
决策树过拟合处理
% 后剪枝策略
prunedTree = prune(treeModel, 'Level', 5); % 层级剪枝
cvTree = crossval(treeModel, 'KFold', 5); % 交叉验证剪枝
loss = kfoldLoss(cvTree);
神经网络梯度消失
- 使用ReLU代替Sigmoid
- 添加Batch Normalization层
- 残差连接(ResNet结构)
% 残差块示例
function lgraph = addResBlock(lgraph, blockName, numFilters)
layers = [
convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv1'])
batchNormalizationLayer('Name',[blockName '_bn1'])
reluLayer('Name',[blockName '_relu1'])
convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv2'])
batchNormalizationLayer('Name',[blockName '_bn2'])
additionLayer(2,'Name',[blockName '_add'])];
lgraph = addLayers(lgraph, layers);
lgraph = connectLayers(lgraph, 'input', [blockName '_conv1']);
lgraph = connectLayers(lgraph, [blockName '_relu1'], [blockName '_conv2']);
lgraph = connectLayers(lgraph, 'input', [blockName '_add/in2']);
end
五、算法选择指南
- 中小型结构化数据 → 决策树(可解释性优先)或SVM(精度优先)
- 图像/语音/文本数据 → 神经网络(CNN/RNN)
- 实时预测场景 → 决策树(毫秒级响应)
- 缺乏ML经验 → KNN(参数简单)或预训练神经网络
A[数据类型] --> B{结构化?}
B -->|是| C{需要解释模型?}
C -->|是| D[决策树]
C -->|否| E[SVM/神经网络]
B -->|否| F{时序/空间特征?}
F -->|是| G[CNN/RNN]
F -->|否| H[全连接网络]
实践经验:从决策树基准开始,逐渐尝试更复杂模型。对于表格数据,LightGBM/XGBoost(基于决策树的集成方法)通常优于单一模型,MATLAB可通过调用Python库实现:
pyrun('import lightgbm as lgb')
model = pyrun('lgb.LGBMClassifier', [], boosting_type='gbdt', num_leaves=31);
六、分类算法性能对比表
算法 | 训练速度 | 预测速度 | 内存需求 | 适用场景 |
---|---|---|---|---|
KNN | 快 | 慢 | 高 | 小规模数据 |
SVM | 慢 | 快 | 低 | 高维数据 |
决策树 | 快 | 快 | 低 | 可解释性要求高 |
神经网络 | 很慢 | 中等 | 高 | 复杂模式识别 |
七、关键注意事项
- 数据预处理:
% 标准化处理
[trainData, mu, sigma] = zscore(trainData);
testData = (testData - mu)./sigma;
- 类别不平衡处理:
% 使用代价敏感学习
SVMModel = fitcsvm(trainData, trainLabels, 'Cost', [0 2; 1 0]);
- 参数调优(以SVM为例):
% 交叉验证选择最佳参数
opts = struct('Optimizer','bayesopt', 'ShowPlots', true, ...
'CVPartition', cvpartition(trainLabels,'KFold',5));
params = hyperparameters('fitcsvm', trainData, trainLabels);
SVMModel = fitcsvm(trainData, trainLabels, 'OptimizeHyperparameters','auto', ...
'HyperparameterOptimizationOptions', opts);
八、进阶技巧
- 多分类问题:
- 使用
fitcecoc
进行错误校正输出编码
ECOMModel = fitcecoc(trainData, trainLabels);
- 使用
- 特征选择:
% 使用最小冗余最大相关算法
idx = fscmrmr(trainData, trainLabels);
selectedData = trainData(:, idx(1:10));
- 模型融合:
% 创建投票分类器
knnModel = fitcknn(trainData, trainLabels);
treeModel = fitctree(trainData, trainLabels);
ensemble = fitcensemble(trainData, trainLabels, 'Method', 'Subspace');
九、完整工作流示例
% 1. 数据准备
data = readtable('classification_data.csv');
predictors = data(:,1:end-1);
response = data(:,end);
% 2. 特征工程
predictors = fillmissing(predictors, 'constant', 0);
predictors = normalize(predictors);
% 3. 训练/验证集划分
cv = cvpartition(height(response), 'HoldOut', 0.2);
trainPredictors = predictors(training(cv),:);
trainResponse = response(training(cv),:);
valPredictors = predictors(test(cv),:);
valResponse = response(test(cv),:);
% 4. 模型训练与调优
ensemble = fitcensemble(trainPredictors, trainResponse, ...
'OptimizeHyperparameters','all', ...
'HyperparameterOptimizationOptions', ...
struct('AcquisitionFunctionName','expected-improvement-plus'));
% 5. 模型评估
predicted = predict(ensemble, valPredictors);
confusionchart(table2array(valResponse), predicted);
fprintf('F1 Score: %.2f\n', f1score(table2array(valResponse), predicted));
重要提示:实际应用中需根据数据特性选择算法。对于大型数据集推荐使用SVM或集成方法,对于需要解释性的场景可选择决策树,实时系统可考虑KNN或朴素贝叶斯。
Matlab 分类算法的更多相关文章
- 数据挖掘之分类算法---knn算法(有matlab例子)
knn算法(k-Nearest Neighbor algorithm).是一种经典的分类算法.注意,不是聚类算法.所以这种分类算法 必然包括了训练过程. 然而和一般性的分类算法不同,knn算法是一种懒 ...
- 数据挖掘之分类算法---knn算法(有matlab样例)
knn算法(k-Nearest Neighbor algorithm).是一种经典的分类算法. 注意,不是聚类算法.所以这样的分类算法必定包含了训练过程. 然而和一般性的分类算法不同,knn算法是一种 ...
- K近邻分类算法实现 in Python
K近邻(KNN):分类算法 * KNN是non-parametric分类器(不做分布形式的假设,直接从数据估计概率密度),是memory-based learning. * KNN不适用于高维数据(c ...
- 神经网络、logistic回归等分类算法简单实现
最近在github上看到一个很有趣的项目,通过文本训练可以让计算机写出特定风格的文章,有人就专门写了一个小项目生成汪峰风格的歌词.看完后有一些自己的小想法,也想做一个玩儿一玩儿.用到的原理是深度学习里 ...
- Logistic回归分类算法原理分析与代码实现
前言 本文将介绍机器学习分类算法中的Logistic回归分类算法并给出伪代码,Python代码实现. (说明:从本文开始,将接触到最优化算法相关的学习.旨在将这些最优化的算法用于训练出一个非线性的函数 ...
- [分类算法] :SVM支持向量机
Support vector machines 支持向量机,简称SVM 分类算法的目的是学会一个分类函数或者分类模型(分类器),能够把数据库中的数据项映射给定类别中的某一个,从而可以预测未知类别. S ...
- 算法杂货铺——分类算法之朴素贝叶斯分类(Naive Bayesian classification)
算法杂货铺——分类算法之朴素贝叶斯分类(Naive Bayesian classification) 0.写在前面的话 我个人一直很喜欢算法一类的东西,在我看来算法是人类智慧的精华,其中蕴含着无与伦比 ...
- 分类算法之贝叶斯(Bayes)分类器
摘要:旁听了清华大学王建勇老师的 数据挖掘:理论与算法 的课,讲的还是挺细的,好记性不如烂笔头,在此记录自己的学习内容,方便以后复习. 一:贝叶斯分类器简介 1)贝叶斯分类器是一种基于统计的分类器 ...
- Netflix工程总监眼中的分类算法:深度学习优先级最低
Netflix工程总监眼中的分类算法:深度学习优先级最低 摘要:不同分类算法的优势是什么?Netflix公司工程总监Xavier Amatriain根据奥卡姆剃刀原理依次推荐了逻辑回归.SVM.决策树 ...
- 第二篇:基于K-近邻分类算法的约会对象智能匹配系统
前言 假如你想到某个在线约会网站寻找约会对象,那么你很可能将该约会网站的所有用户归为三类: 1. 不喜欢的 2. 有点魅力的 3. 很有魅力的 你如何决定某个用户属于上述的哪一类呢?想必你会分析用户的 ...
随机推荐
- 把 PySide6 移植到安卓上去!
官方教程在此:https://www.qt.io/blog/taking-qt-for-python-to-android 寥寥几句,其实不少坑.凭回忆写的,可能不是很全(无招胜有招) 仅支持 Lin ...
- 论文解读:Locating and Editing Factual Associations in GPT(ROME)
本文发表在人工智能顶会NeurIPS上(原文链接),研究了GPT(Generative Pre-trained Transformer)中事实关联的存储和回忆,发现这些关联与局部化.可直接编辑的计 ...
- `.NC`文件的读取与使用
.NC文件的读取与使用 前言 NetCDF(network Common Data Form)网络通用数据格式是一种面向数组型并适于网络共享的数据的描述和编码标准.目前,NetCDF广泛应用于大气科学 ...
- PRIMPERM - Prime Permutations
将题目分解成两个部分: 判断素数 如果用暴力筛因子的方法,在 $t \le 10^4,n \le 10^7$ 下肯定是要超时的,所以用了时间和空间都比较廉价的埃氏筛法. 代码: bool f[1000 ...
- 关于PHP 函数性能优化的技巧
本文由 ChatMoney团队出品 本文将详细介绍 PHP 函数性能优化的技巧.通过分析 PHP 函数的执行过程和性能瓶颈,提供一系列实用的优化方法,并结合代码示例,帮助读者提升 PHP 代码的执行效 ...
- Redis五-哨兵
目录 哨兵 导读 基本概念 主从复制问题 Redis Sentinel的高可用性 安装和部署 部署数据节点 部署Sentinel节点 Seninel配置优化 sentinel API 实现原理 三个定 ...
- Kafka入门实战教程(3).NET Core操作Kafka
1 可用的Kafka .NET客户端 作为一个.NET Developer,自然想要在.NET项目中集成Kafka实现发布订阅功能.那么,目前可用的Kafka客户端有哪些呢? 目前.NET圈子主流使用 ...
- electron 热更新以及对 ts 的支持
前言 虽然 Electron 官方宣布支持 TypeScript,但它只是支持了类型定义文件,而不是真正的 TS 开箱即用. 比如你的入口文件是 ts,当你运行 electron .启动项目的时候,依 ...
- batocera添加游戏
进入batocera系统之后,会发现就只有几个不懂什么东西的模拟器,或者FC和其它,好象没有发现可以用NS或者其它的模拟器,在此说明 一下: 整个系统至少支持73个以上模拟器,主界面没有仅因为你没有放 ...
- 【7*】期望DP学习笔记
前言 由于马上就要把同学的<概率论与数理统计>还回去了,所以赶快看一点,并做一点笔记. 感觉网上很多文章讲期望 DP 都讲得不够透彻啊,就写一些自己的理解造福后人吧,自我感觉讲得很透彻. ...