接着统计学习中knn算法实验(1)的内容

Problem:

  1. Explore the data before classification using summary statistics or visualization
  2. Pre-process the data (such as denoising, normalization, feature selection, …)
  3. Try other distance metrics or distance-based voting
  4. Try other dimensionality reduction methods
  5. How to set the k value, if not using cross validation? Verify your idea
问题:
  1. 在对数据分类之前使用对数据进行可视化处理
  2. 预处理数据(去噪,归一化,数据选择)
  3. 在knn算法中使用不同的距离计算方法
  4. 使用其他的降维算法
  5. 如何在不使用交叉验证的情况下设置k值

使用Parallel coordinates plot做数据可视化,首先对数据进行归一化处理,数据的动态范围控制在[0,1]。注意归一化的处理针对的是每一个fearture。






通过对图的仔细观察,我们挑选出重叠度比较低的feature来进行fearture selection,feature selection实际上是对数据挑选出更易区分的类型作为下一步分类算法的数据。我们挑选出feature序号为(1)、(2)、(5)、(6)、(7)、(10)的feature。个人认为,feature selection是一种简单而粗暴的降维和去噪的操作,但是可能效果会很好。

根据上一步的操作,从Parallel coordinates上可以看出,序号为(1)、(2)、(5)、(6)、(7)、(10)这几个feature比较适合作为classify的feature。我们选取以上几个feature作knn,得到的结果如下:

当K=1 的时候,Accuracy达到了85.38%,并且相比于简单的使用knn或者PCA+knn的方式,Normalization、Featrure Selection的方法使得准确率大大提升。我们也可以使用不同的feature搭配,通过实验得到更好的结果。

MaxAccuracy= 0.8834 when k=17 (Normalization+FeartureSelection+KNN)

试验中,我们使用了两种不同的Feature Selection 策略,选用较少fearture的策略对分类的准确率还是有影响的,对于那些从平行坐标看出的不那么好的fearture,对分类还是有一定的帮助的。
在较小的k值下,Feature Selection的结果要比直接采用全部Feature的结果要好。这也体现了在相对纯净的数据下,较小的k值能够获得较好的结果,这和直观感觉出来的一致。
我们再尝试对数据进行进一步的预处理操作,比如denoising。
数据去噪的方法利用对Trainning数据进行一个去处最大最小边缘值的操作,我们认为,对于一个合适的feature,它的数据应该处于一个合理的范围中,过大或者过小的数据都将是异常的。

Denoising的代码如下:

  1. function[DNData]=DataDenoising(InputData,KillRange)
  2. DNData=InputData;
  3. %MedianData=median(DNData);
  4. for i=2:size(InputData,2)
  5. [temp,DNIndex]=sort(DNData(:,i));
  6. DNData=DNData(DNIndex(1+KillRange:end-KillRange),:);
  7. end

采用LLE作为降维的手段,通过和以上的几种方案作对比,如下:



MaxAccuracy= 0.9376 when K=23 (LLE dimensionality reduction to 2)

关于LLE算法,参见这篇论文

  • Nonlinear dimensionality reduction by locally linear embedding.Sam Roweis & Lawrence Saul.Science, v.290 no.5500 , Dec.22, 2000. pp.2323--2326.
以及项目主页:

源代码:

StatLearnProj.m

  1. clear;
  2. data=load('wine.data.txt');
  3. %calc 5-folder knn
  4. Accuracy=[];
  5. for i=1:5
  6. Test=data(i:5:end,:);
  7. TestData=Test(:,2:end);
  8. TestLabel=Test(:,1);
  9. Trainning=setdiff(data,Test,'rows');
  10. TrainningData=Trainning(:,2:end);
  11. TrainningLabel=Trainning(:,1);
  12. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  13. end
  14. AccuracyKNN=mean(Accuracy,1);
  15.  
  16. %calc PCA
  17. Accuracy=[];
  18. %PCA
  19. [Coeff,Score,Latent]=princomp(data(:,2:end));
  20. dataPCA=[data(:,1),Score(:,1:6)];
  21. Latent
  22. for i=1:5
  23. Test=dataPCA(i:5:end,:);
  24. TestData=Test(:,2:end);
  25. TestLabel=Test(:,1);
  26. Trainning=setdiff(dataPCA,Test,'rows');
  27. TrainningData=Trainning(:,2:end);
  28. TrainningLabel=Trainning(:,1);
  29. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  30. end
  31. AccuracyPCA=mean(Accuracy,1);
  32. BarData=[AccuracyKNN;AccuracyPCA];
  33. bar(1:2:51,BarData');
  34.  
  35. [D,I]=sort(AccuracyKNN,'descend');
  36. D(1)
  37. I(1)
  38. [D,I]=sort(AccuracyPCA,'descend');
  39. D(1)
  40. I(1)
  41.  
  42. %pre-processing data
  43. %Normalization
  44. labs1={'1)Alcohol','(2)Malic acid','3)Ash','4)Alcalinity of ash'};
  45. labs2={'5)Magnesium','6)Total phenols','7)Flavanoids','8)Nonflavanoid phenols'};
  46. labs3={'9)Proanthocyanins','10)Color intensity','11)Hue','12)OD280/OD315','13)Proline'};
  47. uniData=[];
  48. for i=2:size(data,2)
  49. uniData=cat(2,uniData,(data(:,i)-min(data(:,i)))/(max(data(:,i))-min(data(:,i))));
  50. end
  51. figure();
  52. parallelcoords(uniData(:,1:4),'group',data(:,1),'labels',labs1);
  53. figure();
  54. parallelcoords(uniData(:,5:8),'group',data(:,1),'labels',labs2);
  55. figure();
  56. parallelcoords(uniData(:,9:13),'group',data(:,1),'labels',labs3);
  57.  
  58. %denoising
  59.  
  60. %Normalization && Feature Selection
  61. uniData=[data(:,1),uniData];
  62. %Normalization all feature
  63.  
  64. for i=1:5
  65. Test=uniData(i:5:end,:);
  66. TestData=Test(:,2:end);
  67. TestLabel=Test(:,1);
  68. Trainning=setdiff(uniData,Test,'rows');
  69. TrainningData=Trainning(:,2:end);
  70. TrainningLabel=Trainning(:,1);
  71. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  72. end
  73. AccuracyNorm=mean(Accuracy,1);
  74.  
  75. %KNN PCA Normalization
  76. BarData=[AccuracyKNN;AccuracyPCA;AccuracyNorm];
  77. bar(1:2:51,BarData');
  78.  
  79. %Normalization& FS 1 2 5 6 7 10 we select 1 2 5 6 7 10 feature
  80. FSData=uniData(:,[1 2 3 6 7 8 11]);
  81. size(FSData)
  82. for i=1:5
  83. Test=FSData(i:5:end,:);
  84. Trainning=setdiff(FSData,Test,'rows');
  85. TestData=Test(:,2:end);
  86. TestLabel=Test(:,1);
  87. TrainningData=Trainning(:,2:end);
  88. TrainningLabel=Trainning(:,1);
  89. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  90. end
  91. AccuracyNormFS1=mean(Accuracy,1);
  92.  
  93. %Normalization& FS 1 6 7
  94. FSData=uniData(:,[1 2 7 8]);
  95. for i=1:5
  96. Test=FSData(i:5:end,:);
  97. Trainning=setdiff(FSData,Test,'rows');
  98. TestData=Test(:,2:end);
  99. TestLabel=Test(:,1);
  100. TrainningData=Trainning(:,2:end);
  101. TrainningLabel=Trainning(:,1);
  102. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  103. end
  104. AccuracyNormFS2=mean(Accuracy,1);
  105. figure();
  106. BarData=[AccuracyNorm;AccuracyNormFS1;AccuracyNormFS2];
  107. bar(1:2:51,BarData');
  108.  
  109. [D,I]=sort(AccuracyNorm,'descend');
  110. D(1)
  111. I(1)
  112. [D,I]=sort(AccuracyNormFS1,'descend');
  113. D(1)
  114. I(1)
  115. [D,I]=sort(AccuracyNormFS2,'descend');
  116. D(1)
  117. I(1)
  118. %denoiding
  119. %Normalization& FS 1 6 7
  120. FSData=uniData(:,[1 2 7 8]);
  121. for i=1:5
  122. Test=FSData(i:5:end,:);
  123. Trainning=setdiff(FSData,Test,'rows');
  124. Trainning=DataDenoising(Trainning,2);
  125. TestData=Test(:,2:end);
  126. TestLabel=Test(:,1);
  127. TrainningData=Trainning(:,2:end);
  128. TrainningLabel=Trainning(:,1);
  129. Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
  130. end
  131. AccuracyNormFSDN=mean(Accuracy,1);
  132. figure();
  133. hold on
  134. plot(1:2:51,AccuracyNormFSDN);
  135. plot(1:2:51,AccuracyNormFS2,'r');
  136.  
  137. %other distance metrics
  138.  
  139. Dist='cityblock';
  140. for i=1:5
  141. Test=uniData(i:5:end,:);
  142. TestData=Test(:,2:end);
  143. TestLabel=Test(:,1);
  144. Trainning=setdiff(uniData,Test,'rows');
  145. TrainningData=Trainning(:,2:end);
  146. TrainningLabel=Trainning(:,1);
  147. Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist));
  148. end
  149. AccuracyNormCity=mean(Accuracy,1);
  150.  
  151. BarData=[AccuracyNorm;AccuracyNormCity];
  152. figure();
  153. bar(1:2:51,BarData');
  154.  
  155. [D,I]=sort(AccuracyNormCity,'descend');
  156. D(1)
  157. I(1)
  158.  
  159. %denoising
  160. FSData=uniData(:,[1 2 7 8]);
  161. Dist='cityblock';
  162. for i=1:5
  163. Test=FSData(i:5:end,:);
  164. TestData=Test(:,2:end);
  165. TestLabel=Test(:,1);
  166. Trainning=setdiff(FSData,Test,'rows');
  167. Trainning=DataDenoising(Trainning,3);
  168. TrainningData=Trainning(:,2:end);
  169. TrainningLabel=Trainning(:,1);
  170. Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist));
  171. end
  172. AccuracyNormCityDN=mean(Accuracy,1);
  173. figure();
  174. hold on
  175. plot(1:2:51,AccuracyNormCityDN);
  176. plot(1:2:51,AccuracyNormCity,'r');
  177.  
  178. %call lle
  179.  
  180. data=load('wine.data.txt');
  181. uniData=[];
  182. for i=2:size(data,2)
  183. uniData=cat(2,uniData,(data(:,i)-min(data(:,i)))/(max(data(:,i))-min(data(:,i))));
  184. end
  185. uniData=[data(:,1),uniData];
  186. LLEData=lle(uniData(:,2:end)',5,2);
  187. %size(LLEData)
  188. LLEData=LLEData';
  189. LLEData=[data(:,1),LLEData];
  190.  
  191. Accuracy=[];
  192. for i=1:5
  193. Test=LLEData(i:5:end,:);
  194. TestData=Test(:,2:end);
  195. TestLabel=Test(:,1);
  196. Trainning=setdiff(LLEData,Test,'rows');
  197. Trainning=DataDenoising(Trainning,2);
  198. TrainningData=Trainning(:,2:end);
  199. TrainningLabel=Trainning(:,1);
  200. Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,'cityblock'));
  201. end
  202. AccuracyLLE=mean(Accuracy,1);
  203. [D,I]=sort(AccuracyLLE,'descend');
  204. D(1)
  205. I(1)
  206.  
  207. BarData=[AccuracyNorm;AccuracyNormFS2;AccuracyNormFSDN;AccuracyLLE];
  208. figure();
  209. bar(1:2:51,BarData');
  210.  
  211. save('ProcessingData.mat');

CalcAccuracy.m

  1. function Accuracy=CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel)
  2. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  3. %calculate the accuracy of classify
  4. %TestData:M*D matrix D stand for dimension,M is sample
  5. %TrainningData:T*D matrix
  6. %TestLabel:Label of TestData
  7. %TrainningLabel:Label of Trainning Data
  8. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  9. CompareResult=[];
  10. for k=1:2:51
  11. ClassResult=knnclassify(TestData,TrainningData,TrainningLabel,k);
  12. CompareResult=cat(2,CompareResult,(ClassResult==TestLabel));
  13. end
  14. SumCompareResult=sum(CompareResult,1);
  15. Accuracy=SumCompareResult/length(CompareResult(:,1));

CalcAccuracyPlus.m

  1. function Accuracy=CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist)
  2. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  3. %just as CalcAccuracy,but add distance metrics
  4. %calculate the accuracy of classify
  5. %TestData:M*D matrix D stand for dimension,M is sample
  6. %TrainningData:T*D matrix
  7. %TestLabel:Label of TestData
  8. %TrainningLabel:Label of Trainning Data
  9. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  10. CompareResult=[];
  11. for k=1:2:51
  12. ClassResult=knnclassify(TestData,TrainningData,TrainningLabel,k,Dist);
  13. CompareResult=cat(2,CompareResult,(ClassResult==TestLabel));
  14. end
  15. SumCompareResult=sum(CompareResult,1);
  16. Accuracy=SumCompareResult/length(CompareResult(:,1));

【StatLearn】统计学习中knn算法实验(2)的更多相关文章

  1. 【StatLearn】统计学习中knn算法的实验(1)

    Problem: Develop a k-NN classifier with Euclidean distance and simple voting Perform 5-fold cross va ...

  2. 学习OpenCV——KNN算法

    转自:http://blog.csdn.net/lyflower/article/details/1728642 文本分类中KNN算法,该方法的思路非常简单直观:如果一个样本在特征空间中的k个最相似( ...

  3. Machine Learning In Action 第二章学习笔记: kNN算法

    本文主要记录<Machine Learning In Action>中第二章的内容.书中以两个具体实例来介绍kNN(k nearest neighbors),分别是: 约会对象预测 手写数 ...

  4. 统计学习中感知机的C++代码

    感知机是古老的统计学习方法,主要应用于二类线性可分数据,策略是在给定的超平面上对误差点进行纠正,从而保证所有的点都是正确可分的. 用到的方法是随机梯度下降法,由于是线性可分的,可保证最终在有限步内收敛 ...

  5. [译]针对科学数据处理的统计学习教程(scikit-learn教程2)

    翻译:Tacey Wong 统计学习: 随着科学实验数据的迅速增长,机器学习成了一种越来越重要的技术.问题从构建一个预测函数将不同的观察数据联系起来,到将观测数据分类,或者从未标记数据中学习到一些结构 ...

  6. KNN算法的R语言实现

    近邻分类 简言之,就是将未标记的案例归类为与它们最近相似的.带有标记的案例所在的类. 应用领域: 1.计算机视觉:包含字符和面部识别等 2.推荐系统:推荐受众喜欢电影.美食和娱乐等 3.基因工程:识别 ...

  7. 机器学习(一)之KNN算法

    knn算法原理 ①.计算机将计算所有的点和该点的距离 ②.选出最近的k个点 ③.比较在选择的几个点中那个类的个数多就将该点分到那个类中 KNN算法的特点: knn算法的优点:精度高,对异常值不敏感,无 ...

  8. 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验

    实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...

  9. SVM(支持向量机)与统计机器学习 & 也说一下KNN算法

    因为SVM和统计机器学习内容很多,所以从 http://www.cnblogs.com/charlesblc/p/6188562.html 这篇文章里面分出来,单独写. 为什么说SVM和统计学关系很大 ...

随机推荐

  1. 阿里云Linux服务器,配置JDK,MySQL

    云服务器配置:低配 Linux CentOS 7.4 64位 选择空白镜像: 安装图形界面 yum groups install "MATE Desktop" yum groups ...

  2. Java死锁排查和Java CPU 100% 排查的步骤整理

    ================================================= 人工智能教程.零基础!通俗易懂!风趣幽默!大家可以看看是否对自己有帮助! 点击查看高清无码教程 == ...

  3. java中的stream的Map收集器操作

    package test9; import java.util.Collections; import java.util.HashSet; import java.util.Map; import ...

  4. flask 中的request

    request.args                    从URL地址中的参数request.form                   POST请求时 从FormData中获取参数reque ...

  5. 【BZOJ-3681】Arietta 网络流 + 线段树合并

    3681: Arietta Time Limit: 20 Sec  Memory Limit: 64 MBSubmit: 182  Solved: 70[Submit][Status][Discuss ...

  6. Gym 100646 You’ll be Working on the Railroad dfs

    You'll be Working on the Railroad 题目连接: http://codeforces.com/gym/100646/attachments Description Con ...

  7. Technical Information ARM-related JTAG / SWD / SWV / ETM Target Interfaces

    https://www.computex.co.jp/eg/products/pdf/technical_pdf/arm_if01_gijutsu_eng.pdf

  8. 【Go命令教程】7. go run

    Go 源码文件包括:命令源码文件.库源码文件 和 测试源码文件.其中,命令源码文件 总应该属于 main 代码包,且在其中有无参数声明.无结果声明的 main 函数.单个命令源码文件可以被单独编译,也 ...

  9. Android WebView加载Html右边空白问题的解决方案

    用WebView显示Html时,右边会出现一条空白区,如下图所示: 最开始的时候,认为是网页本身的空白. 后来发现网页本身无问题,且这个空白区是跟Scroll Bar 的位置和粗细比较相符,于是去控制 ...

  10. [置顶] Linux下发布QT程序

    Linux下发布QT程序 概述 无论在windows下还是在linux下,可执行程序的运行都依赖于相关的运行库,我们需要将依赖的库找到放到特定的位置,让可执行文件能够找到.在不知道可执行文件依赖哪些库 ...