版权所有,转帖注明出处



Scikit-learn是一个开源Python库,它使用统一的接口实现了一系列机器学习、预处理、交叉验证和可视化算法。

一个基本例子

from sklearn import neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = datasets.load_iris()
X, y = iris.data[:, :2], iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=33)
scaler = preprocessing.StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
knn = neighbors.KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy_score(y_test, y_pred)

加载数据

数据类型可以是NumPy数组、SciPy稀疏矩阵,或者其他可转换为数组的类型,如panda DataFrame等。

import numpy as np
X = np.random.random((10,5))
y = np.array(['M','M','F','F','M','F','M','M','F','F','F'])
X[X < 0.7] = 0

预处理数据

标准化/Standardization

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().fit(X_train)
standardized_X = scaler.transform(X_train)
standardized_X_test = scaler.transform(X_test)

归一化/Normalization

from sklearn.preprocessing import Normalizer
scaler = Normalizer().fit(X_train)
normalized_X = scaler.transform(X_train)
normalized_X_test = scaler.transform(X_test)

二值化/Binarization

from sklearn.preprocessing import Binarizer
binarizer = Binarizer(threshold=0.0).fit(X)
binary_X = binarizer.transform(X)

类别特征编码

from sklearn.preprocessing import LabelEncoder
enc = LabelEncoder()
y = enc.fit_transform(y)

缺失值估算

>>>from sklearn.preprocessing import Imputer
>>>imp = Imputer(missing_values=0, strategy='mean', axis=0)
>>>imp.fit_transform(X_train)

生成多项式特征

from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(5)
oly.fit_transform(X)

训练与测试数据分组

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=0)

创建模型

有监督学习模型

线性回归

from sklearn.linear_model import LinearRegression
lr = LinearRegression(normalize=True)

支持向量机(SVM)

from sklearn.svm import SVC
svc = SVC(kernel='linear')

朴素贝叶斯

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

KNN

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

无监督学习模型

主成分分析(PCA)

from sklearn.decomposition import PCA
pca = PCA(n_components=0.95)

k均值/K Means

from sklearn.cluster import KMeans
k_means = KMeans(n_clusters=3, random_state=0)

模型拟合

有监督学习

lr.fit(X, y)
knn.fit(X_train, y_train)
svc.fit(X_train, y_train)

无监督学习

k_means.fit(X_train)
pca_model = pca.fit_transform(X_train)

模型预测

有监督学习

y_pred = svc.predict(np.random.random((2,5)))
y_pred = lr.predict(X_test)
y_pred = knn.predict_proba(X_test))

无监督学习

y_pred = k_means.predict(X_test)

评估模型性能

分类指标

准确度

knn.score(X_test, y_test)
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)

分类报告

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred)))

混淆矩阵

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, y_pred)))

回归指标

平均绝对误差

from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2])
mean_absolute_error(y_true, y_pred))

均方差

from sklearn.metrics import mean_squared_error
mean_squared_error(y_test, y_pred))

$R^2$分数

from sklearn.metrics import r2_score
r2_score(y_true, y_pred))

聚类指标

调整兰德系数

from sklearn.metrics import adjusted_rand_score
adjusted_rand_score(y_true, y_pred))

同质性/Homogeneity

from sklearn.metrics import homogeneity_score
homogeneity_score(y_true, y_pred))

调和平均指标/V-measure

from sklearn.metrics import v_measure_score
metrics.v_measure_score(y_true, y_pred))

交叉验证

print(cross_val_score(knn, X_train, y_train, cv=4))
print(cross_val_score(lr, X, y, cv=2))

模型调优

网格搜索

from sklearn.grid_search import GridSearchCV
params = {"n_neighbors": np.arange(1,3), "metric": ["euclidean", "cityblock"]}
grid = GridSearchCV(estimator=knn,param_grid=params)
grid.fit(X_train, y_train)
print(grid.best_score_)
print(grid.best_estimator_.n_neighbors)

随机参数优化

from sklearn.grid_search import RandomizedSearchCV
params = {"n_neighbors": range(1,5), "weights": ["uniform", "distance"]}
rsearch = RandomizedSearchCV(estimator=knn,
param_distributions=params,
cv=4,
n_iter=8,
random_state=5)
rsearch.fit(X_train, y_train)
print(rsearch.best_score_)

Sklearn 速查的更多相关文章

  1. 机器学习算法 Python&R 速查表

    sklearn实战-乳腺癌细胞数据挖掘( 博主亲自录制) https://study.163.com/course/introduction.htm?courseId=1005269003&u ...

  2. 常用的14种HTTP状态码速查手册

    分类 1xx \> Information(信息) // 接收的请求正在处理 2xx \> Success(成功) // 请求正常处理完毕 3xx \> Redirection(重定 ...

  3. jQuery 常用速查

    jQuery 速查 基础 $("css 选择器") 选择元素,创建jquery对象 $("html字符串") 创建jquery对象 $(callback) $( ...

  4. 简明 Git 命令速查表(中文版)

    原文引用地址:https://github.com/flyhigher139/Git-Cheat-Sheet/blob/master/Git%20Cheat%20Sheet-Zh.md在Github上 ...

  5. 《zw版·Halcon-delphi系列原创教程》 zw版-Halcon常用函数Top100中文速查手册

    <zw版·Halcon-delphi系列原创教程> zw版-Halcon常用函数Top100中文速查手册 Halcon函数库非常庞大,v11版有1900多个算子(函数). 这个Top版,对 ...

  6. .htaccess下Flags速查表

    Flags是可选参数,当有多个标志同时出现时,彼此间以逗号分隔. 速查表: RewirteRule 标记 含义 描述 R Redirect 发出一个HTTP重定向 F Forbidden 禁止对URL ...

  7. IL指令速查

    名称 说明 Add 将两个值相加并将结果推送到计算堆栈上. Add.Ovf 将两个整数相加,执行溢出检查,并且将结果推送到计算堆栈上. Add.Ovf.Un 将两个无符号整数值相加,执行溢出检查,并且 ...

  8. Linux命令速查手册,超详细Linux命令教程

    一.常用命令速查 ls cd pwd cat more less tail head cp scp mv mkdir rmdir touch rm ps kill top free clear tre ...

  9. 25个有用的和方便的 WordPress 速查手册

    如果你是 WordPress 开发人员,下载一些方便的 WordPress 备忘单可以在你需要的时候快速查找.下面这个列表,我们已经列出了25个有用的和方便的 WordPress 速查手册,赶紧收藏吧 ...

随机推荐

  1. 初识IntPtr------转载

    初识IntPtr 一:什么是IntPtr 先来看看MSDN上说的:用于表示指针或句柄的平台特定类型.这个其实说出了这样两个事实,IntPtr 可以用来表示指针或句柄.它是一个平台特定类型.对于它的解释 ...

  2. Linux 创建网卡子接口

    创建网卡子接口,添加IP别名 ifconfig eth0:0  2.2.2.2/24 或 ip addr add 2.2.2.2/24 dev eth0 label eth0:0 清除网卡子接口,删除 ...

  3. 关于Java构造类与对象的思考

    简单记录一下Java构造类与对象时的流程以及this和super对于特殊例子的分析. 首先,接着昨天的问题,我做出了几个变形: Pic1.原版: Pic2.去掉了T.foo方法中的this关键字: P ...

  4. CH15 面向对象程序设计

    面向对象程序设计是基于三个基本概念的:数据抽象.继承和多态. 第7章介绍了数据抽象的知识,简单来说,C++通过定义自己的数据类型来实现数据抽象. 数据抽象是一种依赖于接口和实现分离的编程技术:类的设计 ...

  5. luogu P3358 最长k可重区间集问题

    网络流建图好难,这题居然是网络流(雾,一般分析来说,有限制的情况最大流情况可以拆点通过capacity来限制,比如只使用一次,把一个点拆成入点出点,capacity为1即可,这题是限制最大k重复,可以 ...

  6. Day3-R-Aggressive cows POJ2456

    Farmer John has built a new long barn, with N (2 <= N <= 100,000) stalls. The stalls are locat ...

  7. Matplotlib 多个图形

    章节 Matplotlib 安装 Matplotlib 入门 Matplotlib 基本概念 Matplotlib 图形绘制 Matplotlib 多个图形 Matplotlib 其他类型图形 Mat ...

  8. Fiddler抓取HTTP请求。

    参考链接:http://blog.csdn.net/ohmygirl/article/details/17849983/ http://www.cnblogs.com/kingwolf_JavaScr ...

  9. Day6 - 牛客203E

    https://ac.nowcoder.com/acm/contest/203/E 埋坑不会做

  10. leetcode349 350 Intersection of Two Arrays & II

    """ Intersection of Two Arrays Given two arrays, write a function to compute their in ...