1. 目的:根据人口普查数据来预测收入(预测每个个体年收入是否超过$50,000)

2. 数据来源:1994年美国人口普查数据,数据中共含31978个观测值,每个观测值代表一个个体

3. 变量介绍:

(1)age: 年龄(以年表示)

(2)workclass: 工作类别/性质 (e.g., 国家机关工作人员、当地政府工作人员、无收入人员等)

(3)education: 受教育水平 (e.g., 小学、初中、高中、本科、硕士、博士等)

(4)maritalstatus: 婚姻状态(e.g., 未婚、离异等)

(5)occupation: 工作类型 (e.g., 行政/文员、农业养殖人员、销售人员等)

(6)relationship: 家庭身份 (e.g., 丈夫、妻子、孩子等)

(7)race: 种族

(8)sex: 性别

(9)capitalgain: 1994年的资本收入 (买卖股票、债券等)

(10)capitalloss: 1994年的资本支出 (买卖股票、债券等)

(11)hoursperweek: 每周工作时长

(12)nativecountry: 国籍

(13)over50k: 1994年全年工资是否超过$50,000

4. 应用及分析

census <- read.csv("census.csv") #读取文件

  

library(caTools) # 加载caTools包
# 将数据分为测试集和训练集
set.seed(2000)
spl <- sample.split(census$over50k, SplitRatio = 0.6)
census.train <- subset(census, spl == T) # 测试集
census.test <- subset(census, spl == F) # 训练集

  

# 构建逻辑回归模型
census.logistic <- glm(over50k ~ ., data = census.train, family = 'binomial')
summary(census.logistic) # 查看模型拟合结果

# 在临界值为0.5的情况下,逻辑回归模型应用到测试集的准确性
## method1
census.logistic.pred <- predict(census.logistic, newdata = census.test, type = 'response')
library(caret)
confusionMatrix(as.factor(ifelse(census.logistic.pred >= 0.5, " >50K", " <=50K")), as.factor(census.test$over50k)) ## method2
table(census.test$over50k, census.logistic.pred>= 0.5)
sum(diag(table(census.test$over50k, census.logistic.pred>= 0.5)))/nrow(census.test) #0.8552 # 测试集的基础准确性
table(census.test$over50k)/nrow(census.test) #0.759

  

# ROC 以及 AUC
library(ROCR)
census.pred <- prediction(census.logistic.pred, census.test$over50k)
census.perf <- performance(census.pred, 'tpr', 'fpr')
plot(census.perf, colorize = T) #ROC curve
as.numeric(performance(census.pred, 'auc')@y.values) #AUC value is 0.9061598

虽然逻辑回归模型准确率高达0.8572,且变量的显著性有助于我们判断个体的收入情况;但是在自变量中的分类变量类别太多的情况下,我们无法判断哪些变量更重要。

因此,接下来构建CART模型。

# 默认的CART模型
library(rpart)
library(rpart.plot)
census.cart <- rpart(over50k ~ ., data = census.train, method = 'class')
prp(census.cart) # 作图

# 模型准确性
census.cart.pred <- predict(census.cart, newdata = census.test, type = 'class')
## method1
table(census.test$over50k, census.cart.pred)
sum(diag(table(census.test$over50k, census.cart.pred)))/nrow(census.test)
## method2
confusionMatrix(census.cart.pred, as.factor(census.test$over50k)) # 模型准确性为0.8474
# ROC 以及 AUC
census.cart.pred2 <- predict(census.cart, newdata = census.test)
census.cart.pred2
census.cart.pred3 <- prediction(census.cart.pred2[,2], census.test$over50k)
census.cart.perf <- performance(census.cart.pred3, 'tpr', 'fpr')
plot(census.cart.perf, colorize = T) # ROC as.numeric(performance(census.cart.pred3, 'auc')@y.values) #AUC value is 0.8470256
# 随机森林模型
set.seed(1)
census.train.small <- census.train[sample(nrow(census.train), 2000),]
## 构建随机森林模型之前先减小训练集样本数量。
## 因为随机森林过程中包含大量运算过程,小样本更益于模型的建立 library(randomForest)
census.train.small.rf <- randomForest(over50k ~ ., data = census.train.small) # 模型预测
census.train.small.rf.pred <- predict(census.train.small.rf, newdata = census.test) # 模型准确性
confusionMatrix(census.train.small.rf.pred, as.factor(census.test$over50k)) # 0.8533

  

因为随机森林模型是一系列分类决策树的集合,因此与分类决策树相比,随机森林模型的解释性稍差,但仍可用一些方法来衡量变量的重要性

# 方法一:统计随机过程中每个变量出现的次数
vu <- varUsed(census.train.small.rf, count=TRUE)
vusrted <- sort(vu, decreasing = FALSE, index.return = TRUE)
# draw a Cleveland dot plot
dotchart(vusorted$x, names(census.train.small.rf$forest$xlevels[vusorted$ix]))

其中,age出现次数最多,sex出现次数最少。

# 方法二:比较平均Gini指数的下降程度
varImpPlot(census.train.small.rf)

其中,occupation、education、age的平均Gini指数减少的最多,sex的平均Gini指数减少的最少

# 改进的CART模型(考虑cp值)
library(caret)
library(lattice)
library(ggplot2)
library(e1071) # 找出使得准确率最高的cp值
set.seed(2)
numFolds <- trainControl(method = 'cv', number = 10)
cpGrid <- expand.grid(.cp = seq(0.002,0.1,0.002))
train(over50k ~ ., data = census.train,
method = 'rpart', trControl = numFolds, tuneGrid = cpGrid) # cp = 0.002时模型准确度最高 # 构建新的CART模型(cp=0.002)
census.bestTree <- rpart(over50k ~ ., data = census.train, method = 'class', cp = 0.002)
prp(census.bestTree) # 作图 # 模型预测
predCV <- predict(census.bestTree, newdata = census.test, type = 'class') # 计算新模型的准确率
## method1
table(census.test$over50k, predCV)
sum(diag(table(census.test$over50k, predCV)))/nrow(census.test)
## method2
confusionMatrix(predCV, as.factor(census.test$over50k)) # 0.8612

考虑cp值以后的CART模型的准确性比默认模型高了1%左右,但是模型明显复杂了更多,因此需要在模型简洁性及准确性之间做出权衡。

本案例中,默认模型足够简洁且准确度也很高,所以倾向使用默认模型。

【R语言学习笔记】 Day1 CART 逻辑回归、分类树以及随机森林的应用及对比的更多相关文章

  1. R语言学习笔记:基础知识

    1.数据分析金字塔 2.[文件]-[改变工作目录] 3.[程序包]-[设定CRAN镜像] [程序包]-[安装程序包] 4.向量 c() 例:x=c(2,5,8,3,5,9) 例:x=c(1:100) ...

  2. R语言学习笔记—决策树分类

    一.简介 决策树分类算法(decision tree)通过树状结构对具有某特征属性的样本进行分类.其典型算法包括ID3算法.C4.5算法.C5.0算法.CART算法等.每一个决策树包括根节点(root ...

  3. R语言学习笔记之: 论如何正确把EXCEL文件喂给R处理

    博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html ---- 前言: 应用背景兼吐槽 继续延续之前每个月至少一次更新博客,归纳总结学习心得好习惯. ...

  4. R语言学习笔记(二)

    今天主要学习了两个统计学的基本概念:峰度和偏度,并且用R语言来描述. > vars<-c("mpg","hp","wt") &g ...

  5. R语言学习笔记(一)

    1.不同的行业对数据集(即表格)的行和列称谓不同,统计学家称其为观测(observation)和变量(variable): 2.R语言存储数据的结构: ①向量:类似于C语言里的一位数组,执行组合功能的 ...

  6. R语言学习笔记:字符串处理

    想在R语言中生成一个图形文件的文件名,前缀是fitbit,后面跟上月份,再加上".jpg",先不百度,试了试其它语言的类似语法,没一个可行的: C#中:"fitbit&q ...

  7. R语言学习笔记:小试R环境

    买了三本R语言的书,同时使用来学习R语言,粗略翻下来感觉第一本最好: <R语言编程艺术>The Art of R Programming <R语言初学者使用>A Beginne ...

  8. R语言学习笔记 (入门知识)

    R免费使用:统计工具:# 注释,行注释块注释:anything="这是注释的内容"常用R语言编辑器:Rsutdio,Tinn-R,Eclipse+StatET:中文会有乱码帮助:? ...

  9. R语言学习笔记—K近邻算法

    K近邻算法(KNN)是指一个样本如果在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性.即每个样本都可以用它最接近的k个邻居来代表.KNN算法适 ...

随机推荐

  1. 算法学习之剑指offer(六)

    题目1 题目描述 输入n个整数,找出其中最小的K个数.例如输入4,5,1,6,2,7,3,8这8个数字,则最小的4个数字是1,2,3,4,. import java.util.*; public cl ...

  2. 一张图一个题帮你迅速理解RLU算法

    下面是某年的软考题: 某进程页面访问序列为4,2,3,1,2,4,5,3,1,2,3,5,且开始执行时内存中没有页面,分配给该进程的物理块数是3,则采用RLU页面置换算法时的缺页率是多少? 对于这个问 ...

  3. 基于AHB总线的master读写设计(Verilog)

    一.AHB总线学习 1. AHB总线结构 如图所示,AHB总线系统利用中央多路选择机制实现主机与从机的互联问题.从图中可以看出,AHB总线结构主要可分为三部分:主机.从机.控制部分.控制部分由仲裁器. ...

  4. Web安全 --Wfuzz 使用大全

    前言:  做web渗透大多数时候bp来fuzz   偶尔会有觉得要求达不到的时候 wfuzz就很有用了这时候 用了很久了这点来整理一次 wfuzz 是一款Python开发的Web安全模糊测试工具. 下 ...

  5. 设计糟糕的 RESTful API 就是在浪费时间!

    现在微服务真是火的一塌糊涂.大街小巷,逢人必谈微服务,各路大神纷纷忙着把自家的单体服务拆解成多个Web微小服务.而作为微服务之间通信的桥梁,Web API的设计就显得非常重要. HTTP是目前互联网使 ...

  6. 解决连接oracle报错 尝试加载Oracle客户端库时引发BadImageFomatException。如果在安装64位Oracle客户端组件的情况下以32位模式运行,将出现此问题的报错。

    最近遇到一个.NET连接Oracle的一个错误,其主要原因是换了一台电脑,在新电脑上运行以前的项目出现了的一个错误,工作环境为vs2017+Oracle 64位,win10系统 这个错误头疼了一天,找 ...

  7. [Abp vNext 源码分析] - 11. 用户的自定义参数与配置

    一.简要说明 文章信息: 基于的 ABP vNext 版本:1.0.0 创作日期:2019 年 10 月 23 日晚 更新日期:暂无 ABP vNext 针对用户可编辑的配置,提供了单独的 Volo. ...

  8. Mybatis源码阅读 之 玩转Executor

    承接上篇博客, 本文探究MyBatis中的Executor, 如下图: 是Executor体系图 本片博客的目的就是探究如上图中从顶级接口Executor中拓展出来的各个子执行器的功能,以及进一步了解 ...

  9. 四、pymysql模块、索引和慢查询

    目录 一.pymysql模块 (一)如何使用 (二)sql注入问题 二.索引 (一)主键索引 (二)唯一索引 (三)普通索引 (四)联合索引 (五)不会命中索引的情况 (六)explain (七)索引 ...

  10. SpringCloud之Zuul配置问题

    当通过网关去调用服务的时候,尤其是服务里面配置了熔断,会发现拿不到熔断返回的信息 hystrix: command: default: execution: isolation: thread: ti ...