Source: https://databricks.com/blog/2016/02/08/auto-scaling-scikit-learn-with-apache-spark.html

Data scientists often spend hours or days tuning models to get the highest accuracy. This tuning typically involves running a large number of independent machine learning (ML) tasks coded in Python or R. Following work presented at Spark Summit Europe 2015, we are excited to release a scikit-learn integration package for Apache Spark that dramatically simplifies the life of data scientists using Python. The package automatically distributes the most repetitive tasks of model tuning across a Spark cluster, without impacting the data scientist's workflow:

  • When used on a single machine, Spark can serve as a substitute for the default multithreading framework used by scikit-learn (joblib).
  • If the work needs to spread across multiple machines, no change is required in the code between the single-machine case and the cluster case (see the sketch below).
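As a minimal illustration of that second point, the only thing that differs between the two cases is the SparkContext you create; the master URL below is a placeholder of ours, not something from the post.

from pyspark import SparkContext

# Single machine: use all local cores in place of joblib's workers.
sc = SparkContext("local[*]", "sklearn-tuning")
# Cluster: point the same code at a cluster master instead
# (placeholder URL):
# sc = SparkContext("spark://master-host:7077", "sklearn-tuning")

This sc handle is what the package's grid search consumes, as shown in the examples below.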

Scale data science effortlessly

Python is one of the most popular programming languages for data exploration and data science, due in no small part to high-quality libraries such as pandas for data exploration and scikit-learn for machine learning. Scikit-learn provides fast and robust implementations of standard ML algorithms such as clustering, classification, and regression.

Historically, though, scikit-learn's strength has been computing on a single node. Yet for some common scenarios, such as parameter tuning, a large number of small, independent tasks can run in parallel. These scenarios are perfect use cases for Spark.

We explored how to integrate Spark with scikit-learn, and the result is the Scikit-learn integration package for Spark. It combines the strengths of Spark and scikit-learn with no changes to users’ code. It re-implements some components of scikit-learn that benefit the most from distributed computing. Users will find a Spark-based cross-validator class that is fully compatible with scikit-learn’s cross-validation tools. By swapping out a single class import, users can distribute cross-validation for their existing scikit-learn workflows.

Distribute tuning of Random Forests

Consider the classical example of identifying digits in images. The popular digits dataset pairs small images of handwritten digits with their labels.

We are going to train a random forest classifier to recognize the digits. This classifier has a number of parameters to adjust, and there is no easy way to know which settings work best other than trying out many combinations. Scikit-learn provides GridSearchCV, a search algorithm that explores many parameter settings automatically and selects among them by cross-validation: each parameter setting produces one model, and the best-performing model is kept.
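For intuition, grid search with cross-validation amounts to the following loop. This is a simplified sketch of the idea, not the library's actual implementation; the module paths match the pre-0.18 scikit-learn used in the post's code.

from sklearn.base import clone
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import ParameterGrid

def naive_grid_search(estimator, param_grid, X, y, cv=3):
    # Try every parameter setting, score each by k-fold
    # cross-validation, and keep the best one.
    best_score, best_params = -float("inf"), None
    for params in ParameterGrid(param_grid):
        model = clone(estimator).set_params(**params)
        score = cross_val_score(model, X, y, cv=cv).mean()
        if score > best_score:
            best_score, best_params = score, params
    return best_params, best_score

Every fit in the inner loop is independent of all the others, which is exactly what makes this workload easy to parallelize.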

The original code, using only scikit-learn, is as follows:

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.grid_search import GridSearchCV

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Grid of random forest hyperparameters to explore exhaustively.
param_grid = {"max_depth": [3, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [1, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [10, 20, 40, 80]}

gs = GridSearchCV(RandomForestClassifier(), param_grid=param_grid)
gs.fit(X, y)
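To see why this takes minutes on a single core (the arithmetic is ours, not the post's): the grid above defines 2 × 3 × 3 × 3 × 2 × 2 × 4 = 864 parameter settings, and the default 3-fold cross-validation turns that into 2,592 independent model fits.

from sklearn.grid_search import ParameterGrid

n_settings = len(ParameterGrid(param_grid))  # 864 combinations
n_fits = n_settings * 3                      # x 3-fold CV = 2592 fits
print(n_settings, n_fits)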

The dataset is small (in the hundreds of kilobytes), but exploring all the combinations takes about 5 minutes on a single core. The scikit-learn package for Spark provides an alternative implementation of the cross-validation algorithm that distributes the workload across a Spark cluster. Each node runs the training algorithm on a local copy of the scikit-learn library and reports the best model back to the master.

The code is the same as before, except that the import is swapped and the grid search is handed a SparkContext:

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
# Use spark_sklearn's grid search instead:
from spark_sklearn import GridSearchCV

digits = datasets.load_digits()
X, y = digits.data, digits.target

param_grid = {"max_depth": [3, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [1, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [10, 20, 40, 80]}

# spark_sklearn's class takes the SparkContext as its first argument
# (sc is created automatically in Databricks notebooks and pyspark).
gs = GridSearchCV(sc, RandomForestClassifier(), param_grid=param_grid)
gs.fit(X, y)
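Because the class is fully compatible with scikit-learn's cross-validation tools, inspecting the result should work the same way after fitting; a minimal sketch assuming the standard scikit-learn attribute names:

print(gs.best_params_)            # the winning parameter setting
print(gs.best_score_)             # its mean cross-validated score
best_model = gs.best_estimator_   # estimator refit with those parameters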

This example runs in under 30 seconds on a 4-node cluster (16 CPUs in total). For larger datasets and more parameter settings, the difference is even more dramatic.

Get started

If you would like to try out this package yourself, it is available as a Spark package and as a PyPI library. To get started, check out this example notebook on Databricks.

In addition to distributing ML tasks in Python across a cluster, the scikit-learn integration package for Spark provides tools to export data from Spark to Python and vice versa, including methods to convert Spark DataFrames to pandas DataFrames and NumPy arrays. More details can be found in this Spark Summit Europe presentation and in the API documentation.
