1. 梯度计算式导出

我们在博客《统计学习:逻辑回归与交叉熵损失(Pytorch实现)》中提到,设\(w\)为权值(最后一维为偏置),样本总数为\(N\),\(\{(x_i, y_i)\}_{i=1}^N\)为训练样本集。样本维度为\(D\),\(x_i\in \mathbb{R}^{D+1}\)(最后一维扩充),\(y_i\in\{0, 1\}\)。则逻辑回归的损失函数为:

\[\mathcal{l}(w) = \sum_{i=1}^{N}\left[y_{i} \log \pi_{w}\left(x_{i}\right)+\left(1-y_{i}\right) \log \left(1-\pi_w\left(x_{i}\right)\right)\right]
\]

这里

\[\begin{aligned}
\pi_w(x) = p(y=1 \mid x; w) =\frac{1}{1+\exp \left(-w^{T} x\right)}
\end{aligned}
\]

写成这个形式就已经可以用诸如Pytorch这类工具来进行自动求导然后采用梯度下降法求解了。不过若需要用表达式直接计算出梯度,我们还需要将损失函数继续化简为:

\[\mathcal{l}(w) = -\sum_{i=1}^N(y_i w^T x_i - \log(1 + \exp(w^T x_i)))
\]

可将梯度表示如下:

\[\nabla_w{\mathcal{l}(w)} = -\sum_{i=1}^N(y_i - \frac{1}{\exp(-w^Tx)+1})x_i
\]

2. 基于Spark的并行化实现

逻辑回归的目标函数常采用梯度下降法求解,该算法的并行化可以采用如下的Map-Reduce架构:

先将第\(t\)轮迭代的权重广播到各worker,各worker计算一个局部梯度(map过程),然后再将每个节点的梯度聚合(reduce过程),最终对参数进行更新。

在Spark中每个task对应一个分区,决定了计算的并行度(分区的概念详间我们上一篇博客Spark: 单词计数(Word Count)的MapReduce实现(Java/Python))。在Spark的实现过程如下:

  • map阶段: 各task运行map()函数对每个样本\((x_i, y_i)\)计算梯度\(g_i\), 然后对每个样本对应的梯度运行进行本地聚合,以减少后面的数据传输量。如第1个task执行reduce()操作得到\(\widetilde{g}_1 = \sum_{i=1}^3 g_i\) 如下图所示:

  • reduce阶段:使用reduce()将所有task的计算结果收集到Driver端进行聚合,然后进行参数更新。

在上图中,训练数据用points:PrallelCollectionRDD来表示,参数向量用\(w\)来表示,注意参数向量不是RDD,只是一个单独的参与运算的变量。

此外需要注意一点,虽然每个task在本地进行了局部聚合,但如果task过多且每个task本地聚合后的结果(单个gradient)过大那么统一传递到Driver端仍然会造成单点的网络平均等问题。为了解决这个问题,Spark设计了性能更好的treeAggregate()操作,使用树形聚合方法来减少网络和计算延迟。

3. PySpark实现代码

PySpark的完整实现代码如下:

from sklearn.datasets import load_breast_cancer
import numpy as np
from pyspark.sql import SparkSession
from operator import add
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score n_slices = 3 # Number of Slices
n_iterations = 300 # Number of iterations
alpha = 0.01 # iteration step_size def logistic_f(x, w):
return 1 / (np.exp(-x.dot(w)) + 1) def gradient(point: np.ndarray, w: np.ndarray) -> np.ndarray:
""" Compute linear regression gradient for a matrix of data points
"""
y = point[-1] # point label
x = point[:-1] # point coordinate
# For each point (x, y), compute gradient function, then sum these up
return - (y - logistic_f(x, w)) * x if __name__ == "__main__": X, y = load_breast_cancer(return_X_y=True) D = X.shape[1]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=0)
n_train, n_test = X_train.shape[0], X_test.shape[0] spark = SparkSession\
.builder\
.appName("Logistic Regression")\
.getOrCreate() matrix = np.concatenate(
[X_train, np.ones((n_train, 1)), y_train.reshape(-1, 1)], axis=1) points = spark.sparkContext.parallelize(matrix, n_slices).cache() # Initialize w to a random value
w = 2 * np.random.ranf(size=D + 1) - 1
print("Initial w: " + str(w)) for t in range(n_iterations):
print("On iteration %d" % (t + 1))
g = points.map(lambda point: gradient(point, w)).reduce(add)
w -= alpha * g y_pred = logistic_f(np.concatenate(
[X_test, np.ones((n_test, 1))], axis=1), w)
pred_label = np.where(y_pred < 0.5, 0, 1)
acc = accuracy_score(y_test, pred_label)
print("iterations: %d, accuracy: %f" % (t, acc)) print("Final w: %s " % w)
print("Final acc: %f" % acc) spark.stop()

注意spark.sparkContext.parallelize(matrix, n_slices)中的n_slices就是Spark中的分区数。我们在代码中采用breast cancer数据集进行训练和测试,该数据集是个二分类数据集。模型初始权重采用随机初始化。

最后,我们来看一下算法的输出结果。

初始权重如下:

Initial w: [-0.0575882   0.79680833  0.96928013  0.98983501 -0.59487909 -0.23279241
-0.34157571 0.93084048 -0.10126002 0.19124314 0.7163746 -0.49597826
-0.50197367 0.81784642 0.96319482 0.06248513 -0.46138666 0.76500396
0.30422518 -0.21588114 -0.90260279 -0.07102884 -0.98577817 -0.09454256
0.07157487 0.9879555 0.36608845 -0.9740067 0.69620032 -0.97704433
-0.30932467]

最终的模型权重与在测试集上的准确率结果如下:

Final w: [ 8.22414803e+02  1.48384087e+03  4.97062125e+03  4.47845441e+03
7.71390166e+00 1.21510016e+00 -7.67338147e+00 -2.54147183e+00
1.55496346e+01 6.52930570e+00 2.02480712e+00 1.09860082e+02
-8.82480263e+00 -2.32991671e+03 1.61742379e+00 8.57741145e-01
1.30270454e-01 1.16399854e+00 2.09101988e+00 5.30845885e-02
8.28547658e+02 1.90597805e+03 4.93391021e+03 -4.69112527e+03
1.10030574e+01 1.49957834e+00 -1.02290791e+01 -3.11020744e+00
2.37012097e+01 5.97116694e+00 1.03680530e+02]
Final acc: 0.923977

可见我们的算法收敛良好。

参考

分布式机器学习:逻辑回归的并行化实现(PySpark)的更多相关文章

  1. 机器学习---逻辑回归(二)(Machine Learning Logistic Regression II)

    在<机器学习---逻辑回归(一)(Machine Learning Logistic Regression I)>一文中,我们讨论了如何用逻辑回归解决二分类问题以及逻辑回归算法的本质.现在 ...

  2. 机器学习/逻辑回归(logistic regression)/--附python代码

    个人分类: 机器学习 本文为吴恩达<机器学习>课程的读书笔记,并用python实现. 前一篇讲了线性回归,这一篇讲逻辑回归,有了上一篇的基础,这一篇的内容会显得比较简单. 逻辑回归(log ...

  3. 机器学习---逻辑回归(一)(Machine Learning Logistic Regression I)

    逻辑回归(Logistic Regression)是一种经典的线性分类算法.逻辑回归虽然叫回归,但是其模型是用来分类的. 让我们先从最简单的二分类问题开始.给定特征向量x=([x1,x2,...,xn ...

  4. 机器学习——逻辑回归(Logistic Regression)

    1 前言 虽然该机器学习算法名字里面有"回归",但是它其实是个分类算法.取名逻辑回归主要是因为是从线性回归转变而来的. logistic回归,又叫对数几率回归. 2 回归模型 2. ...

  5. 吴裕雄 python 机器学习——逻辑回归

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

  6. python机器学习-逻辑回归

    1.逻辑函数 假设数据集有n个独立的特征,x1到xn为样本的n个特征.常规的回归算法的目标是拟合出一个多项式函数,使得预测值与真实值的误差最小: 而我们希望这样的f(x)能够具有很好的逻辑判断性质,最 ...

  7. Spark 机器学习------逻辑回归

    package Spark_MLlib import javassist.bytecode.SignatureAttribute.ArrayType import org.apache.spark.s ...

  8. python机器学习——逻辑回归

    我们知道感知器算法对于不能完全线性分割的数据是无能为力的,在这一篇将会介绍另一种非常有效的二分类模型--逻辑回归.在分类任务中,它被广泛使用 逻辑回归是一个分类模型,在实现之前我们先介绍几个概念: 几 ...

  9. 机器学习-逻辑回归与SVM的联系与区别

    (搬运工) 逻辑回归(LR)与SVM的联系与区别 LR 和 SVM 都可以处理分类问题,且一般都用于处理线性二分类问题(在改进的情况下可以处理多分类问题,如LR的Softmax回归用在深度学习的多分类 ...

随机推荐

  1. VueJs单页应用实现微信网页授权及微信分享功能

    在实际开发中,无论是做PC端.WebApp端还是微信公众号等类型的项目的时候,或多或少都会涉及到微信相关的开发,最近公司项目要求实现微信网页授权,并获取微信用户基本信息的功能及微信分享的功能,现在总算 ...

  2. 夏日葵电商:连锁零售店小程序o2o系统解决方案

    公众平台"附近小程序"功能上线后,一个主体账号可以同时绑定N+个门店,这对连锁零售店铺来说是重磅福利呀,无论你是通过搜索还是线下扫码进入小程序,线上与线下都完全贯通了,线上多种入口 ...

  3. 基于canvas和web audio实现低配版MikuTap

    导言 最近发掘了一个特别happy的网页小游戏--MikuTap.打开之后沉迷了一下午,导致开发工作没做完差点就要删库跑路了,还好boss瞥了我一眼就没下文了.于是第二天我就继续沉迷,随着一阵抽搐,这 ...

  4. 动态规划 洛谷P4017 最大食物链计数——图上动态规划 拓扑排序

    洛谷P4017 最大食物链计数 这是洛谷一题普及/提高-的题目,也是我第一次做的一题 图上动态规划/拓扑排序 ,我认为这题是很好的学习拓扑排序的题目. 在这题中,我学到了几个名词,入度,出度,及没有环 ...

  5. python爬虫---污言污语网站数据采集

    代码: import requests from lxml import etree headers = { "user-agent": "Mozilla/5.0 (Wi ...

  6. rabitmq 登录报错:User can only log in via localhost

    安装教程参考:https://blog.csdn.net/qq_43672652/article/details/107349063 修改了配置文件仍然报错,无法登录.解决办法:新建一个用户登录: 查 ...

  7. 在tomcat布置项目

    1.将项目打成war包复制到tomcat-webapps 2.修改tomcat端口号 3.指定jdk 一.找到tomcat目录/bin 文件夹下的 catalina.bat文件 二.在文件中找到 ec ...

  8. JavaScript实现指定格式字符串表单校验

    运行效果: 源代码: 1 <!DOCTYPE html> 2 <html lang="zh"> 3 <head> 4 <meta char ...

  9. zabbix 服务器500错误,解决故障。

    ZABBIX 500错误,查看apache错误日志,index.php 32 line.写着语法错误!!! 到路径下打开/var/www/html/zabbix/index.php文件,定位32行,可 ...

  10. Python入门-初识变量类型

    上一篇我们学习了第一行代码,我们print()了很多代码,我们可以print哪些东西呢,这一篇来讲. print()括号里面可以放哪些东西呢?..可以放很多东西,只要是Python的全部数据类型都可以 ...