Kaggle实战之二分类问题

0. 前言

  • “尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手头 90% 的机器学习问题。”

1. MNIST 数据集

MNIST是最常用的用来实验分类模型的数据集,有7w多张手写0-9的白底黑字数字图像,每张图像大小 28*28,共784个像素,像素取值范围为[0-255],0表示白色背景,255表示纯黑,如下图:

2. 二分类器

数字识别是个多分类问题,首先我们从两分类问题开始入手,即判断一张图片是5或者非5。应用sklearn线性模型的SGDClassifier直接训练,损失函数默认是 hinge loss,即SVM分类器,如果想用logistic regression来分类,可以选用 log loss。SGDClassifier 分类器还支持不同的损失函数,如 perceptron等,这里就不一一列举了。由于SGD分类器对训练样本的顺序是敏感的,所以在模型训练之前需要shuffle训练集。

用SVM训练结束后,用模型预测得到测试集的预测结果,评估准确率(accuracy)大概是96%,看起来是一个不错的结果,但是如果我们把所有的测试样本都判定为非5,准确率也能有90%(十分之九都是对的)。看来光凭准确率,在这种情况下不能说明我们模型学习得很好,看来如何评价模型的学习能力不是件那么简单的事情。

“观察到的一个有意思的细节:一些喜好机器学习或者数据科学的初学工程师和有机器学习或者数据科学背景的科学家,在工作上的主要区别在于如何对待负面的实验(包括线下和线上)结果。初学者往往就开始琢磨如何改模型,加Feature,调参数;思考如何从简单模型转换到复杂模型。有经验的人往往更加去了解实验的设置有没有问题;实验的Metrics的Comparison是到底怎么计算的;到真需要去思考模型的问题的时候,有经验的人往往会先反思训练数据的收集情况,测试数据和测试评测的真实度问题。初学者有点类似程咬金的三板斧,有那么几个技能,用完了,要是还没有效果,也就完了。而有经验的数据科学家,往往是从问题出发,去看是不是对问题本质的把握(比如优化的目标是不是对;有没有Counterfactual的情况)出现了偏差,最后再讨论模型。”

—— by @洪亮劼

3. 效果评测

从源头出发,如下图,x、o分别表示label为负和正的样本,划分为上下两列,假设模型预测值是一个连续值(如为正的概率),把正负样本按照预测值从低到高分别排列好。一个好的模型,应该是左上角分布较密集,表示很多负样本预测值较小,右下角分布也很密集,表示为模型预测正样本的概率值普遍偏高。当然,一般模型也无法做到百分之百的分类准确,所以存在少量的负样本预测概率较高,正样本预测概率偏低,如图右上角和左下角。

Confusion Matrix

我们设定一个阈值,用图中蓝色的竖线表示,高于阈值的模型预测为正样本,反之则为负样本。这个阈值是我们可以自行设定的,蓝色的竖线可以左右移动。红色的横线和蓝色的竖线将整个测试集数据分成四个部分,TN(True Positive)、FP(False Positive)、FN(False Negative)、TP(True Positive)。TPR(TP rate)即recall= TP/(TP+FN),precision=TP/(TP+FP)。上面我们计算accuracy实际上是 (TN+TP)/ALL,对于一个测试集来说,底下分母是不变的,如果TN对比TP很大,TP的变动很难通过accuracy反映出来。一个好的分类器,应该TP包含大部分圆圈,FP和FN几乎为空,所以很多比赛的评测指标是precision和recall的harmonic平均值,即:

harmonic平均比直接除以2更看重较小的那个值,只有两个值都比较大,整体才会大。

PR曲线和ROC曲线

为了得到较好的F1,需要调节适当的阈值。蓝色的线从最左往右滑动时,recall= TP/(TP+FN),分母不变,分子逐渐变小,从1单调递减到0。precision=TP/(TP+FP),分子和分母同时变小,总体上,TP变小的速度慢很多,大体上是递增的,但是并不绝对单调,尤其在靠近右侧。经常可以看到TP-1,FP不变,则precision反而变小:

关于Recall和Precision的tradeoff,还可以画一条PR 曲线:

ROC曲线是另外一种衡量二分类模型的方法,y轴是recall=TP/(TP+FN),x轴是FPR=FP/(FP+TN):

PR曲线与ROC曲线的区别在于,PR曲线不关心TN(x、y计算公式都没有包含TN),所以在负样本比例很高的时候,PR曲线波动比ROC曲线明显,更能体现优化空间。另外,ROC曲线关心TN恰好也是它的优势,比如在推荐、搜索等learn to rank 任务中,我们关心的是整个数据集的排序情况,TN也是需要考虑在内的,所以经常离线计算AUC(ROC曲线下方面积)来衡量rank model的优劣。

中场休息时间。。。喝口茶~ 欢迎关注公众号:kaggle实战,或博客:http://www.cnblogs.com/daniel-D/

4. 多分类器与误差分析

多分类器是指能区别两个以上类别的分类器,比如手写数字识别这个数据集要区分0-9,像大型图像数据集可能有几万个类别。有些算法可以直接区分多类,如softmax、RF或者贝叶斯,有些算法无法直接区分,比如上面用到的线性分类器等二分类器。二分类器也可以组合形成多分类器,常见的策略有 One vs All和 One vs One。

在数字识别这个任务中,One vs All(OVA) 一共要训练10个分类器,分别是0 vs 非0,1 vs 非1……预测的时候,10个分类器依次输出为0,为1等的概率,可直接取最大概率作为预测值。One vs One则需要10*(10-1)/ 2个分类器,依次是 0 vs 1,0 vs 2……8 vs 9。OVO和OVA在实际使用不多,这里就不赘述了。

用RF模型训练后,在预测集上预测图像属于哪个类别,由于模型不是百分百准确的,会有0判定成1或者1判定成2的情况,用rowIndex表示实际的label,colIndex表示预测的label,统计预测的label落到实际label的个数,可以得到以下矩阵:

可视化之后得到下图:

可以看到对角线方块很亮,说明所有类别基本判定准确。但是“5”方块较暗,可能是由于5的图片数量较少,或者5的准确率偏低导致,要具体分析数据才能找到原因。除了对角线方块,我们还想分析其余方块的情况,可以把Confusion Matrix每个元素处理该行的总和,对角线置0,得到下图:

可以看到3和5、7和9都容易混淆,想通过RF模型要提升效果的突破口可能就在这里。

5. Kaggle 实战

实际上,3和5、7和9容易混淆的原因在于,他们形态较为相似,直接用像素作为特征,相同的数字,在图像中旋转微小角度或者平移,都会导致像素空间的巨大变化,kaggle上高分kernel普遍都用神经网络里的CNN来提取特征,准确率可以轻松超过98%。预处理流程为:

  • 把 label 从0-9的dense编码转化为 one hot encode编码
  • 分割出4w个训练集和2k个验证集

然后定义一个最常用CNN网络结构和主要的超参数如下:

  • 卷积参数:一般设置stride=1,卷积后保持原尺寸,用0填充,非线性变换采用relu;pooling大小2*2,stride=2,取maxPooling

  • 网络结构:

    • input:(40000, 28, 28, 1)
    • conv1:kernel [5, 5, 1, 32] => (40000, 28, 28, 32)
    • maxPool1:kernel [1, 2, 2, 1] => (40000, 14, 14, 32)
    • conv2:kernel [5, 5, 32, 64] => (40000, 14, 14, 64)
    • maxPool2: kernel [1, 2, 2, 1] => (40000, 7, 7, 64)
    • flat:(40000,7*7*64) => (40000, 3136)
    • FC1: (40000, 1024),非线性变换仍然可以采用relu
    • dropout: 0.5
    • FC2:(40000, 10)
    • Loss:cross-entropy
  • 由于MNIST数据集各类分布都比较均匀,用准确率就能较好评估模型了,比其他指标更加直白

详细代码可以参考这个kernel:https://www.kaggle.com/kakauandme/tensorflow-deep-nn

参考资料

附:公众号

顺便测试下赞赏码

 
00:00
   

 
 
听说,资料要多重备份,本文博客:http://www.cnblogs.com/daniel-D/ ,本人公众号:kaggle实战,欢迎关注~~
 

Kaggle实战分类问题2的更多相关文章

  1. Kaggle实战之二分类问题

    0. 前言 1. MNIST 数据集 2. 二分类器 3. 效果评测 4. 多分类器与误差分析 5. Kaggle 实战 0. 前言 "尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手 ...

  2. Kaggle实战之一回归问题

    0. 前言 1.任务描述 2.数据概览 3. 数据准备 4. 模型训练 5. kaggle实战 0. 前言 "尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手头 90% 的机器学习问题 ...

  3. 机器学习(一):记一次k一近邻算法的学习与Kaggle实战

    本篇博客是基于以Kaggle中手写数字识别实战为目标,以KNN算法学习为驱动导向来进行讲解. 写这篇博客的原因 什么是KNN kaggle实战 优缺点及其优化方法 总结 参考文献 写这篇博客的原因 写 ...

  4. 【项目实战】kaggle产品分类挑战

    多分类特征的学习 这里还是b站刘二大人的视频课代码,视频链接:https://www.bilibili.com/video/BV1Y7411d7Ys?p=9 相关注释已经标明了(就当是笔记),因此在这 ...

  5. Python机器学习实践与Kaggle实战(转)

    https://mlnote.wordpress.com/2015/12/16/python%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E5%AE%9E%E8%B7%B5 ...

  6. Kaggle实战——点击率预估

    https://blog.csdn.net/chengcheng1394/article/details/78940565 原创文章,转载请注明出处: http://blog.csdn.net/che ...

  7. kaggle实战记录 =>Digit Recognizer

    date:2016-09-13 今天开始注册了kaggle,从digit recognizer开始学习, 由于是第一个案例对于整个流程目前我还不够了解,首先了解大神是怎么运行怎么构思,然后模仿.这样的 ...

  8. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  9. 基于Colab Pro & Google Drive的Kaggle实战

    原文:https://hippocampus-garden.com/kaggle_colab/ 原文标题:How to Kaggle with Colab Pro & Google Drive ...

随机推荐

  1. Httpd 文件服务器的搭建

    服务器信息 系统: CentOS 安装操作 安装 httpd 直接通过 yum 安装: yum install httpd 安装完成之后,可以检查版本: http 查看版本 httpd -versio ...

  2. android应用开发-从设计到实现 3-9 Origami动态原型设计

    动态原型设计 动态的可交互原型产品,是产品经理和界面设计师向开发人员阐释自己设计的最高效工具. 开发人员不须要推測设计师要什么样的效果,照着原型产品做就好了. 非常多创业团队也发现了产品人的这个刚需, ...

  3. Scala学习之爬豆瓣电影

    简单使用Scala和Jsoup对豆瓣电影进行爬虫,技术比較简单易学. 写文章不易,欢迎大家採我的文章,以及给出实用的评论,当然大家也能够关注一下我的github:多谢. 1.爬虫前期准备 找好须要抓取 ...

  4. Android与webserver数据交互编程---3网络爬虫项目实现虚拟浏览器的jsp后台执行

    背景:原先的b/s设计中在一个jsp界面中实现多个复杂的工作流... 为实现移动接口的调用保证工作流的正常webproject特别给提供了该虚拟浏览器的方案 原理:通过该方案实现虚拟浏览器后台运行js ...

  5. PopupMenu-使用实例跟监听事件

    今天需要给一个控件添加弹出菜单功能.就顺便学习了下popupMenu的使用,记录下来. 它的使用其实也非常的简单,看如下代码 popupMenu = new PopupMenu(MainActivit ...

  6. 使用PHP来压缩CSS文件

    这里将介绍使用PHP以一种简便的方式来压缩你的CSS文件.这种方法不需要命名你的.css文件和.php文件. 当前有许多方法需要将.css文件重命名成.php文件,然后在所有PHP文件中放置压缩代码. ...

  7. 智课雅思词汇---三、aud和auto和bene是什么意思

    智课雅思词汇---三.aud和auto和bene是什么意思 一.总结 一句话总结:aud:听 auto:自己,self bene:good,well 1.anthropo是什么意思? anthropo ...

  8. Android Studio 函数使用方法提示 快捷键

    看到好多说用F2的,转来转去,中国社区的氛围大概如此,你抄我的,我超你的. 下面的千篇一律: "悬浮窗不出来了,各种不习惯啊.那在Android Studio究竟怎样查看函数的说明呢.选中你 ...

  9. 28.lambda表达式与多线程

    #include <iostream> #include <thread> #include <Windows.h> #include <chrono> ...

  10. 28.semaphore跨进程通信

    根据id创建Semaphore,并初始化有一个信号量可用 name类型是char *...; HANDLE hsem = CreateSemaphoreA(NULL, 1, , name); 关闭句柄 ...