版权所有,转帖注明出处



Scikit-learn是一个开源Python库,它使用统一的接口实现了一系列机器学习、预处理、交叉验证和可视化算法。

一个基本例子

from sklearn import neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = datasets.load_iris()
X, y = iris.data[:, :2], iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=33)
scaler = preprocessing.StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
knn = neighbors.KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy_score(y_test, y_pred)

加载数据

数据类型可以是NumPy数组、SciPy稀疏矩阵,或者其他可转换为数组的类型,如panda DataFrame等。

import numpy as np
X = np.random.random((10,5))
y = np.array(['M','M','F','F','M','F','M','M','F','F','F'])
X[X < 0.7] = 0

预处理数据

标准化/Standardization

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().fit(X_train)
standardized_X = scaler.transform(X_train)
standardized_X_test = scaler.transform(X_test)

归一化/Normalization

from sklearn.preprocessing import Normalizer
scaler = Normalizer().fit(X_train)
normalized_X = scaler.transform(X_train)
normalized_X_test = scaler.transform(X_test)

二值化/Binarization

from sklearn.preprocessing import Binarizer
binarizer = Binarizer(threshold=0.0).fit(X)
binary_X = binarizer.transform(X)

类别特征编码

from sklearn.preprocessing import LabelEncoder
enc = LabelEncoder()
y = enc.fit_transform(y)

缺失值估算

>>>from sklearn.preprocessing import Imputer
>>>imp = Imputer(missing_values=0, strategy='mean', axis=0)
>>>imp.fit_transform(X_train)

生成多项式特征

from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(5)
oly.fit_transform(X)

训练与测试数据分组

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=0)

创建模型

有监督学习模型

线性回归

from sklearn.linear_model import LinearRegression
lr = LinearRegression(normalize=True)

支持向量机(SVM)

from sklearn.svm import SVC
svc = SVC(kernel='linear')

朴素贝叶斯

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

KNN

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

无监督学习模型

主成分分析(PCA)

from sklearn.decomposition import PCA
pca = PCA(n_components=0.95)

k均值/K Means

from sklearn.cluster import KMeans
k_means = KMeans(n_clusters=3, random_state=0)

模型拟合

有监督学习

lr.fit(X, y)
knn.fit(X_train, y_train)
svc.fit(X_train, y_train)

无监督学习

k_means.fit(X_train)
pca_model = pca.fit_transform(X_train)

模型预测

有监督学习

y_pred = svc.predict(np.random.random((2,5)))
y_pred = lr.predict(X_test)
y_pred = knn.predict_proba(X_test))

无监督学习

y_pred = k_means.predict(X_test)

评估模型性能

分类指标

准确度

knn.score(X_test, y_test)
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)

分类报告

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred)))

混淆矩阵

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, y_pred)))

回归指标

平均绝对误差

from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2])
mean_absolute_error(y_true, y_pred))

均方差

from sklearn.metrics import mean_squared_error
mean_squared_error(y_test, y_pred))

$R^2$分数

from sklearn.metrics import r2_score
r2_score(y_true, y_pred))

聚类指标

调整兰德系数

from sklearn.metrics import adjusted_rand_score
adjusted_rand_score(y_true, y_pred))

同质性/Homogeneity

from sklearn.metrics import homogeneity_score
homogeneity_score(y_true, y_pred))

调和平均指标/V-measure

from sklearn.metrics import v_measure_score
metrics.v_measure_score(y_true, y_pred))

交叉验证

print(cross_val_score(knn, X_train, y_train, cv=4))
print(cross_val_score(lr, X, y, cv=2))

模型调优

网格搜索

from sklearn.grid_search import GridSearchCV
params = {"n_neighbors": np.arange(1,3), "metric": ["euclidean", "cityblock"]}
grid = GridSearchCV(estimator=knn,param_grid=params)
grid.fit(X_train, y_train)
print(grid.best_score_)
print(grid.best_estimator_.n_neighbors)

随机参数优化

from sklearn.grid_search import RandomizedSearchCV
params = {"n_neighbors": range(1,5), "weights": ["uniform", "distance"]}
rsearch = RandomizedSearchCV(estimator=knn,
param_distributions=params,
cv=4,
n_iter=8,
random_state=5)
rsearch.fit(X_train, y_train)
print(rsearch.best_score_)

Sklearn 速查的更多相关文章

  1. 机器学习算法 Python&R 速查表

    sklearn实战-乳腺癌细胞数据挖掘( 博主亲自录制) https://study.163.com/course/introduction.htm?courseId=1005269003&u ...

  2. 常用的14种HTTP状态码速查手册

    分类 1xx \> Information(信息) // 接收的请求正在处理 2xx \> Success(成功) // 请求正常处理完毕 3xx \> Redirection(重定 ...

  3. jQuery 常用速查

    jQuery 速查 基础 $("css 选择器") 选择元素,创建jquery对象 $("html字符串") 创建jquery对象 $(callback) $( ...

  4. 简明 Git 命令速查表(中文版)

    原文引用地址:https://github.com/flyhigher139/Git-Cheat-Sheet/blob/master/Git%20Cheat%20Sheet-Zh.md在Github上 ...

  5. 《zw版·Halcon-delphi系列原创教程》 zw版-Halcon常用函数Top100中文速查手册

    <zw版·Halcon-delphi系列原创教程> zw版-Halcon常用函数Top100中文速查手册 Halcon函数库非常庞大,v11版有1900多个算子(函数). 这个Top版,对 ...

  6. .htaccess下Flags速查表

    Flags是可选参数,当有多个标志同时出现时,彼此间以逗号分隔. 速查表: RewirteRule 标记 含义 描述 R Redirect 发出一个HTTP重定向 F Forbidden 禁止对URL ...

  7. IL指令速查

    名称 说明 Add 将两个值相加并将结果推送到计算堆栈上. Add.Ovf 将两个整数相加,执行溢出检查,并且将结果推送到计算堆栈上. Add.Ovf.Un 将两个无符号整数值相加,执行溢出检查,并且 ...

  8. Linux命令速查手册,超详细Linux命令教程

    一.常用命令速查 ls cd pwd cat more less tail head cp scp mv mkdir rmdir touch rm ps kill top free clear tre ...

  9. 25个有用的和方便的 WordPress 速查手册

    如果你是 WordPress 开发人员,下载一些方便的 WordPress 备忘单可以在你需要的时候快速查找.下面这个列表,我们已经列出了25个有用的和方便的 WordPress 速查手册,赶紧收藏吧 ...

随机推荐

  1. Python学习第六课——基本数据类型一之tuple and dict

    元组 (tuple) tu=(11,22,(123,456),[22,55],) # 一般定义元组的时候最后面加一个, # 元组不能被修改或者删除 v = tu[0] # 也可以根据索引取值 prin ...

  2. 课堂测试用javaweb写一个注册界面,并将数据保存到后台数据库(部分完成)

    今天我到现在为止,也只完成了数据库的连接,还没有写前台的javascript的检查输入的代码,打算周四前完成. 代码如下: package Dao; import java.sql.Connectio ...

  3. Ajax接收Json数据,调用template模板循环渲染页面的方法

    一. 后台接口吐出JSON数据 后台php接口中,需要写三个部分: 1.1 开头header规定数据格式: header("content-type:application/json;cha ...

  4. 黑马客户管理系统(SSM)

    黑马客户管理系统 1系统概述 1.1系统功能介绍 本系统后台使用SSM框架编写,前台页面使用当前主流的Bootstrap和jQuery框架完成页面信息展示功能(关于Bootstrap的知识,有兴趣的读 ...

  5. delphi窗体按钮灰化禁用

    1.使最小化按钮变灰:setwindowlong(handle,gwl_style,getwindowlong(handle,gwl_style)   and   not   ws_minimizeb ...

  6. SciPy 插值

    章节 SciPy 介绍 SciPy 安装 SciPy 基础功能 SciPy 特殊函数 SciPy k均值聚类 SciPy 常量 SciPy fftpack(傅里叶变换) SciPy 积分 SciPy ...

  7. 前端学习(20)~css布局(十三)

    常见的布局属性 (1)display 确定元素的显示类型: block:块级元素. inline:行内元素. inline-block:对外的表现是行内元素(不会独占一行),对内的表现是块级元素(可以 ...

  8. Linux系统-----包管理器的演变

    每个电脑设备都使用某种形式的软件来执行其预定任务.在软件开发的早期,对产品进行了严格的bug和其他缺陷测试.在过去的十多年里,软件通过互联网发布,目的是通过应用新版本的软件来修复任何错误.在某些情况下 ...

  9. Centos 7 安装与卸载MYSQL5.7

    先介绍卸载防止重装 yum方式 查看yum是否安装过mysqlyum list installed mysql*如或显示了列表,说明系统中有MySQL yum卸载 根据列表上的名字 yum remov ...

  10. 在 .net 中释放嵌入的资源

        private static void ExtractResourceToFile(string resourceName, string filename) {     if (!Syste ...