这一篇我们将开始使用scikit-learn的API来实现模型并进行训练,这个包大大方便了我们的学习过程,其中包含了对常用算法的实现,并进行高度优化,以及含有数据预处理、调参和模型评估的很多方法。

我们来看一个之前看过的实例,不过这次我们使用sklearn来训练一个感知器模型,数据集还是Iris,使用其中两维度的特征,样本数据使用三个类别的全部150个样本

%matplotlib inline
import numpy as np
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, [2, 3]]
y = iris.target
np.unique(y)
array([0, 1, 2])

为了评估训练好的模型对新数据的预测能力,我们这里将Iris数据集分为训练集和测试集,这里我们通过调用trian_test_split方法来将数据集分为两部分,其中测试集占30%,训练集占70%

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

我们再将特征进行缩放操作,这里调用StandardScaler来对特征进行标准化:

from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)
X_test_std = sc.transform(X_test)

新对象sc使用fit方法对数据集中每一维的特征计算出样本平均值和标准差,然后调用transform方法对数据集进行标准化,我们这里使用相同的标准化参数对待训练集和测试集。接下来我们训练一个感知器模型

from sklearn.linear_model import Perceptron
ppn = Perceptron(max_iter=40, eta0=0.1, random_state=0)
ppn.fit(X_train_std, y_train)
Perceptron(alpha=0.0001, class_weight=None, early_stopping=False, eta0=0.1,
           fit_intercept=True, max_iter=40, n_iter_no_change=5, n_jobs=None,
           penalty=None, random_state=0, shuffle=True, tol=0.001,
           validation_fraction=0.1, verbose=0, warm_start=False)
y_pred = ppn.predict(X_test_std)
print('Misclassified samples: %d' % (y_test != y_pred).sum())
Misclassified samples: 5

可以看出测试集中有5个样本被分错类了,因此错误分类率是0.11,则分类准确率为1-0.11=0.89,我们也可以直接计算分类准确率:

from sklearn.metrics import accuracy_score
print('Accuracy: %.2f' % accuracy_score(y_test, y_pred))
Accuracy: 0.89

最后我们画出分界区域,这里我们将plot_decision_regions函数进行一些修改,使我们可以区分训练集和测试集的样本

from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
def plot_decision_regions(X, y, classifier, test_idx=None, resolution=0.02):
    # setup marker generator and color map
    markers = ('s', 'x', 'o', '^', 'v')
    colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')
    cmap = ListedColormap(colors[:len(np.unique(y))])

    # plot the decision surface
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                          np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    plt.contourf(xx1, xx2, Z, slpha=0.4, cmap=cmap)
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())

    # plot all samples
    X_test, y_test = X[test_idx, :], y[test_idx]
    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(x=X[y == cl, 0], y=X[y == cl, 1], alpha=0.8,
                   c=cmap(idx), marker=markers[idx], label=cl)

    # highlight test samples
    if test_idx:
        X_test, y_test = X[test_idx, :], y[test_idx]
        plt.scatter(X_test[:, 0], X_test[:, 1], c='',alpha=1.0,
                   linewidth=1, marker='o', s=55, label='test set')
X_combined_std = np.vstack((X_train_std, X_test_std))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(X=X_combined_std, y=y_combined, classifier=ppn,
                     test_idx=range(105, 150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.show()

可以看出三个类别并没有被完美分类,这是由于这三类花并不是线性可分的数据。

python机器学习——使用scikit-learn训练感知机模型的更多相关文章

  1. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  2. (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探

    一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...

  3. 机器学习框架Scikit Learn的学习

    一   安装 安装pip 代码如下:# wget "https://pypi.python.org/packages/source/p/pip/pip-1.5.4.tar.gz#md5=83 ...

  4. 使用SKlearn(Sci-Kit Learn)进行SVR模型学习

    今天了解到sklearn这个库,简直太酷炫,一行代码完成机器学习. 贴一个自动生成数据,SVR进行数据拟合的代码,附带网格搜索(GridSearch, 帮助你选择合适的参数)以及模型保存.读取以及结果 ...

  5. 吴裕雄 python 机器学习——集成学习AdaBoost算法回归模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets,ensemble from sklear ...

  6. 吴裕雄 python 机器学习——集成学习AdaBoost算法分类模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets,ensemble from sklear ...

  7. 吴裕雄 python 机器学习——支持向量机SVM非线性分类SVC模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets, linear_model,svm fr ...

  8. 吴裕雄 python 机器学习——等度量映射Isomap降维模型

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

  9. 吴裕雄 python 机器学习——多维缩放降维MDS模型

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

随机推荐

  1. badboy录制脚本

    第一步:介绍badboy工具 1.1: 页面功能分析: 1. 界面视图,模拟浏览器,能够进行操作 2. 需要录制脚本的URL 3. 点击运行URL 4. Summary:运行的各指标,响应时间,成功事 ...

  2. .NET Core 3.0 ,WTM 2.3.9发布

    .Net Core 3.0已经来了,WTM怎么可以落后呢.最新发布的WTM2.3.9版本已经支持.Net Core 3.0啦,现在在线生成项目的时候可以选择2.2和3.0两个版本.小伙伴们快来体验吧. ...

  3. Detours 劫持

    在使用 Detours 劫持之前必须得拥有这两个东西:detours.h 和 detours.lib. 为了这两个东西我真的是弄了大半天,本着自己动手丰衣足食的思想: 我去 GitHub 克隆了一份来 ...

  4. powershell 基础

    目录 本教程概述 用到的工具 标签 简介 0x01使用简介 0x02脚本编写 0x03实例讲解 本教程概述 本课我们学习powershell使用. 用到的工具 cmd.exe   powershell ...

  5. shark恒破解笔记6-BC++假自效验

    这小节介绍了查壳(peid) 查软件编写语言(die)以及用esp定律脱aspack壳,最后是破解bc++的自校验部分 目标: 首先查看软件 peid查壳 有壳 ,但是不知道是什么语言写的,这里使用D ...

  6. Kubernetes的Ingress简单入门

    目录 一.什么是Ingress 二.部署Nginx Ingress Controller 三.部署一个Service将Nginx服务暴露出去 四.部署一个我们自己的服务Cafe 五.部署ingress ...

  7. python方法是什么?

    python方法是什么? 方法用来描述对象所具有的行为. 在类中定义的方法可以粗略分为四大类:公有方法.私有方法.静态方法.类方法. 公有方法.私有方法一般所指属于对象的实例方法, 私有方法的名字以两 ...

  8. Spring Boot提供RESTful接口时的错误处理实践

    使用Spring Boot开发微服务的过程中,我们会使用别人提供的接口,也会设计接口给别人使用,这时候微服务应用之间的协作就需要有一定的规范. 基于rpc协议,我们一般有两种思路:(1)提供服务的应用 ...

  9. Mybaits 源码解析 (二)----- 根据配置文件创建SqlSessionFactory(Configuration的创建过程)

    我们使用mybatis操作数据库都是通过SqlSession的API调用,而创建SqlSession是通过SqlSessionFactory.下面我们就看看SqlSessionFactory的创建过程 ...

  10. 百万年薪python之路 -- 软件的开发规范

    一. 软件的开发规范 什么是开发规范?为什么要有开发规范呢? 你现在包括之前写的一些程序,所谓的'项目',都是在一个py文件下完成的,代码量撑死也就几百行,你认为没问题,挺好.但是真正的后端开发的项目 ...