画出决策边界线--plot_2d_separator.py源代码【来自python机器学习基础教程】
import numpy as np
import matplotlib.pyplot as plt
from .plot_helpers import cm2, cm3, discrete_scatter def _call_classifier_chunked(classifier_pred_or_decide, X):
# The chunk_size is used to chunk the large arrays to work with x86
# memory models that are restricted to < 2 GB in memory allocation. The
# chunk_size value used here is based on a measurement with the
# MLPClassifier using the following parameters:
# MLPClassifier(solver='lbfgs', random_state=0,
# hidden_layer_sizes=[1000,1000,1000])
# by reducing the value it is possible to trade in time for memory.
# It is possible to chunk the array as the calculations are independent of
# each other.
# Note: an intermittent version made a distinction between
# 32- and 64 bit architectures avoiding the chunking. Testing revealed
# that even on 64 bit architectures the chunking increases the
# performance by a factor of 3-5, largely due to the avoidance of memory
# swapping.
chunk_size = 10000 # We use a list to collect all result chunks
Y_result_chunks = [] # Call the classifier in chunks.
for x_chunk in np.array_split(X, np.arange(chunk_size, X.shape[0],
chunk_size, dtype=np.int32),
axis=0):
Y_result_chunks.append(classifier_pred_or_decide(x_chunk)) return np.concatenate(Y_result_chunks) def plot_2d_classification(classifier, X, fill=False, ax=None, eps=None,
alpha=1, cm=cm3):
# multiclass
if eps is None:
eps = X.std() / 2. if ax is None:
ax = plt.gca() x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
xx = np.linspace(x_min, x_max, 1000)
yy = np.linspace(y_min, y_max, 1000) X1, X2 = np.meshgrid(xx, yy)
X_grid = np.c_[X1.ravel(), X2.ravel()]
decision_values = classifier.predict(X_grid)
ax.imshow(decision_values.reshape(X1.shape), extent=(x_min, x_max,
y_min, y_max),
aspect='auto', origin='lower', alpha=alpha, cmap=cm)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(()) def plot_2d_scores(classifier, X, ax=None, eps=None, alpha=1, cm="viridis",
function=None):
# binary with fill
if eps is None:
eps = X.std() / 2. if ax is None:
ax = plt.gca() x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
xx = np.linspace(x_min, x_max, 100)
yy = np.linspace(y_min, y_max, 100) X1, X2 = np.meshgrid(xx, yy)
X_grid = np.c_[X1.ravel(), X2.ravel()]
if function is None:
function = getattr(classifier, "decision_function",
getattr(classifier, "predict_proba"))
else:
function = getattr(classifier, function)
decision_values = function(X_grid)
if decision_values.ndim > 1 and decision_values.shape[1] > 1:
# predict_proba
decision_values = decision_values[:, 1]
grr = ax.imshow(decision_values.reshape(X1.shape),
extent=(x_min, x_max, y_min, y_max), aspect='auto',
origin='lower', alpha=alpha, cmap=cm) ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(())
return grr def plot_2d_separator(classifier, X, fill=False, ax=None, eps=None, alpha=1,
cm=cm2, linewidth=None, threshold=None,
linestyle="solid"):
# binary?
if eps is None:
eps = X.std() / 2. if ax is None:
ax = plt.gca() x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
xx = np.linspace(x_min, x_max, 1000)
yy = np.linspace(y_min, y_max, 1000) X1, X2 = np.meshgrid(xx, yy)
X_grid = np.c_[X1.ravel(), X2.ravel()]
if hasattr(classifier, "decision_function"):
decision_values = _call_classifier_chunked(classifier.decision_function,
X_grid)
levels = [0] if threshold is None else [threshold]
fill_levels = [decision_values.min()] + levels + [
decision_values.max()]
else:
# no decision_function
decision_values = _call_classifier_chunked(classifier.predict_proba,
X_grid)[:, 1]
levels = [.5] if threshold is None else [threshold]
fill_levels = [0] + levels + [1]
if fill:
ax.contourf(X1, X2, decision_values.reshape(X1.shape),
levels=fill_levels, alpha=alpha, cmap=cm)
else:
ax.contour(X1, X2, decision_values.reshape(X1.shape), levels=levels,
colors="black", alpha=alpha, linewidths=linewidth,
linestyles=linestyle, zorder=5) ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xticks(())
ax.set_yticks(()) if __name__ == '__main__':
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
X, y = make_blobs(centers=2, random_state=42)
clf = LogisticRegression(solver='lbfgs').fit(X, y)
plot_2d_separator(clf, X, fill=True)
discrete_scatter(X[:, 0], X[:, 1], y)
plt.show()
画出决策边界线--plot_2d_separator.py源代码【来自python机器学习基础教程】的更多相关文章
- WPF 如何画出1像素的线
如何有人告诉你,请你画出1像素的线,是不是觉得很简单,实际上在 WPF 上还是比较难的. 本文告诉大家,如何让画出的线不模糊 画出线的第一个方法,创建一个 Canvas ,添加一个线 界面代码 < ...
- python运用turtle 画出汉诺塔搬运过程
python运用turtle 画出汉诺塔搬运过程 1.打开 IDLE 点击File-New File 新建立一个py文件 2.向py文件中输入如下代码 import turtle class Stac ...
- caffe 中 plot accuracy和loss, 并画出网络结构图
plot accuracy + loss 详情可见:http://www.2cto.com/kf/201612/575739.html 1. caffe保存训练输出到log 并绘制accuracy l ...
- 如何用DOM 元素就能画出国宝熊猫
效果预览 在线演示 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/odKrpy 可交互视频教 ...
- scikit-learn机器学习(四)使用决策树做分类,并画出决策树,随机森林对比
数据来自 UCI 数据集 匹马印第安人糖尿病数据集 载入数据 # -*- coding: utf-8 -*- import pandas as pd import matplotlib matplot ...
- 前端每日实战:35# 视频演示如何把 CSS 径向渐变用得出神入化,只用一个 DOM 元素就能画出国宝熊猫
效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/odKrpy 可交互视频教程 此视频 ...
- H5坦克大战之【画出坦克】
今天是个特殊的日子,圣诞节,也是周末,在这里先祝大家圣诞快乐!喜庆的日子,我们可以稍微放松一下,扯一扯昨天雷霆对战凯尔特人的比赛,这场比赛大威少又双叒叕拿下三双,而且是一个45+11+11的超级三双, ...
- 像画笔一样慢慢画出Path的三种方法(补充第四种)
今天大家在群里大家非常热闹的讨论像画笔一样慢慢画出Path的这种效果该如何实现. 北京-LGL 博客号@ligl007发起了这个话题.然后各路高手踊跃发表意见.最后雷叔 上海-雷蒙 博客号@雷蒙之星 ...
- 用css画出三角形
看到有面试题里会有问到如何用css画出三角形 众所周知好多图形都可以拆分成三角形,所以说会了画三角形就可以画出很多有意思的形状 画出三角形的原理是调整border(边框)的四个方向的宽度,线条样式以及 ...
随机推荐
- 使用docker搭建自己的博客(一)
购买服务器 首先服务器选择腾讯云学生服务器,25岁以下实名认证后月租10块,还是很适合我这种简约派的 又财大气粗买了个一年的域名,后面涨价再说吧 安装docker 使用xshell连上服务器 安装必要 ...
- 题目分享N
题意:有辆车,有r行,s*2列,在第s列和第s+1列之间有个过道,出口在第r+1行的过道处,现在给出每个人的位置(行号和列号),每人每次只能动一格,问最少耗费多长时间全员才能逃出去 分析:假如车上只有 ...
- Qt for Android (三) 打开Android相册并选一个图片进行显示
Qt for Android (三) 这两天弄了一下android相册的相关功能.还是花了挺长时间的,这里总结一下,避免以后再踩坑.同时也在这篇文章里面补齐一些android开发的基础支持 打开And ...
- Java常见的集合的数据结构
数据结构 数据结构__栈:先进后出 栈:stack,又称堆栈,它是运算受限的线性表,其限制是仅允许在标的一端进行插入和删除操作,不允许在其他任何位置进行添加.查找.删除等操作. 简单的说:采用该结构的 ...
- 【Spark】必须要用CDH版本的Spark?那你是不是需要重新编译?
目录 为什么要重新编译? 步骤 一.下载Spark的源码 二.准备linux环境,安装必须软件 三.解压spark源码,修改配置,准备编译 四.开始编译 为什么要重新编译? 由于我们所有的环境统一使用 ...
- 基于C语言的Q格式使用详解
用过DSP的应该都知道Q格式吧: 目录 1 前言 2 Q数据的表示 2.1 范围和精度 2.2 推导 3 Q数据的运算 3.1 0x7FFF 3.2 0x8000 3.3 加法 3.4 减法 3.5 ...
- c++11 符号修饰与函数签名、函数指针、匿名函数、仿函数、std::function与std::bind
一.符号修饰与函数签名 1.符号修饰 编译器将c++源代码编译成目标文件时,用函数签名的信息对函数名进行改编,形成修饰名.GCC的C++符号修饰方法如下: 1)所有符号都以_z开头 2)名字空间的名字 ...
- JUC之ReentrantLock源码分析
ReentrantLock:实现了Lock接口,是一个可重入锁,并且支持线程公平竞争和非公平竞争两种模式,默认情况下是非公平模式.ReentrantLock算是synchronized的补充和替代方案 ...
- android progressbar 自定义图片匀速旋转
项目中需要使用圆形进度条进行数据加载的显示,所以需要两个步骤 1:自定义progressbar滚动图片 2:匀速旋转图片 步骤一:自定义progressbar图片 <ProgressBar an ...
- tomcat 添加 ssl 证书
1. 将证书提供方给的证书(server.crt)及密钥文件(server.key)上传到服务器 tomcat 的 conf 目录 2. 在tomcat conf 目录下执行如下命令 (1) 生成P1 ...