estimator = KerasClassifier
如何在scikit-learn模型中使用Keras
通过用 KerasClassifier
或 KerasRegressor
类包装Keras模型,可将其用于scikit-learn。
要使用这些包装,必须定义一个函数,以便按顺序模式创建并返回Keras,然后当构建 KerasClassifier
类时,把该函数传递给 build_fn
参数。
例如:
def create_model():
...
return model model = KerasClassifier(build_fn=create_model)
KerasClassifier类
的构建器为可以采取默认参数,并将其被传递给 model.fit()
的调用函数,比如 epochs数目和批尺寸(batch size)。
例如:
def create_model():
...
return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)
KerasClassifier类的构造也可以使用新的参数,使之能够传递给自定义的create_model()函数。这些新的参数,也必须由使用默认参数的 create_model() 函数的签名定义。
例如:
def create_model(dropout_rate=0.0):
...
return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)
pred = estimator.predict(X_test)#返回给定测试数据的类预测。
pred1=estimator.predict_proba(X_test)#返回给定测试数据的类概率估计。
# pred3=estimator.score(X_test,Y_test)#返回给定测试数据和标签的平均精度。
print(X_test)#
print(Y_test)#实际类别
print(pred)#预测类别
print(pred1)
[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]
estimator = KerasClassifier的更多相关文章
- 【Python与机器学习】:利用Keras进行多类分类
多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...
- Python机器学习笔记:利用Keras进行分类预测
Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...
- Keras人工神经网络多分类(SGD)
import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...
- python多标签分类模版
from sklearn.multioutput import MultiOutputClassifier from sklearn.ensemble import RandomForestClass ...
- np_utils.to_categorical
https://blog.csdn.net/zlrai5895/article/details/79560353 多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用 ...
- 3.2. Grid Search: Searching for estimator parameters
3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...
- 机器学习笔记5-Tensorflow高级API之tf.estimator
前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...
- [sklearn]官方例程-Imputing missing values before building an estimator 随机填充缺失值
官方链接:http://scikit-learn.org/dev/auto_examples/plot_missing_values.html#sphx-glr-auto-examples-plot- ...
- tensorflow estimator API小栗子
TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...
随机推荐
- python脚本netifaces模块的调用
# vim get_ip.py # -*- coding: utf- -*- #complete local network card IP #need install netifaces modem ...
- 9.12 h5日记
9.12 知识点补充: 属性继承例子,color.font(font-size/style/family/weight) 1.浏览器的默认字体大小是16px,谷歌浏览器的最小字体是10px,其他浏览器 ...
- c# 软件绑定网卡mac的实用
一:网上搜c# 绑定网卡Mac 有好多信息,其中有篇分为几种方法获取mac 的方法,结果获得到的是一个list 队列的信息,信息获取到所有的物理网卡,无线网卡,蓝牙,隧道的网卡物理地址.对与软件绑定物 ...
- poj 2785 让和为0 暴力&二分
题目链接:http://poj.org/problem?id=2785 大意是输入一个n行四列的矩阵,每一列取一个数,就是四个数,求有多少种着四个数相加和为0的情况 首先脑海里想到的第一思维必然是一个 ...
- hdu 1757 (矩阵快速幂) 一个简单的问题 一个简单的开始
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1757 题意不难理解,当x小于10的时候,数列f(x)=x,当x大于等于10的时候f(x) = a0 * ...
- 常用MFC宏
最近我在用MFC开发一个智能家居监控平台的软件(用到了MSCOMM串口通信控件),当我通过在一个对话框类A中定义另一个对话框类B的对象访问B的public成员时,提示不可访问.后来经过多天的向朋友求救 ...
- PC 上的 LVM 灾难修复
LVM 介绍 LVM 简介 LVM 是逻辑盘卷管理(Logical Volume Manager)的简称,最早是 IBM 为 AIX 研发的存储管理机制.LVM 通过在硬盘和分区之间建立一个逻辑层,可 ...
- PS故障风海报制作技术分享
1.首先找一张看起来很酷的图(也可以选择自己喜欢的图片): 2. 复制图层,点击添加图层样式,选择混合选项,在高级混合里面的通道选项,有R.G.B三个通道选项,默认是全部勾选的状态,选择其中一个勾掉( ...
- pdo不抛出异常
$pdo->setAttribute(\PDO::ATTR_ERRMODE, \PDO::ERRMODE_EXCEPTION);
- 常用的 composer 命令
一.列表内容 composer composer list二.查看当前镜像源 composer config -l -g [repositories.packagist.org.type] compo ...