如何在scikit-learn模型中使用Keras

通过用 KerasClassifier 或 KerasRegressor 类包装Keras模型,可将其用于scikit-learn。

要使用这些包装,必须定义一个函数,以便按顺序模式创建并返回Keras,然后当构建 KerasClassifier 类时,把该函数传递给 build_fn 参数。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model)

KerasClassifier类 的构建器为可以采取默认参数,并将其被传递给 model.fit() 的调用函数,比如 epochs数目和批尺寸(batch size)。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)

KerasClassifier类的构造也可以使用新的参数,使之能够传递给自定义的create_model()函数。这些新的参数,也必须由使用默认参数的 create_model() 函数的签名定义。

例如:

def create_model(dropout_rate=0.0):
...
return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

pred = estimator.predict(X_test)#返回给定测试数据的类预测。
pred1=estimator.predict_proba(X_test)#返回给定测试数据的类概率估计。
# pred3=estimator.score(X_test,Y_test)#返回给定测试数据和标签的平均精度。
print(X_test)#
print(Y_test)#实际类别
print(pred)#预测类别

print(pred1)

[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]

estimator = KerasClassifier的更多相关文章

  1. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  2. Python机器学习笔记:利用Keras进行分类预测

    Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...

  3. Keras人工神经网络多分类(SGD)

    import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...

  4. python多标签分类模版

    from sklearn.multioutput import MultiOutputClassifier from sklearn.ensemble import RandomForestClass ...

  5. np_utils.to_categorical

    https://blog.csdn.net/zlrai5895/article/details/79560353 多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用 ...

  6. 3.2. Grid Search: Searching for estimator parameters

    3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...

  7. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  8. [sklearn]官方例程-Imputing missing values before building an estimator 随机填充缺失值

    官方链接:http://scikit-learn.org/dev/auto_examples/plot_missing_values.html#sphx-glr-auto-examples-plot- ...

  9. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

随机推荐

  1. springmvc入门(1)

    一..springmvc框架 1.什么是springmvc springmvc是spring框架的一个模块,springmvc和spring无需通过中间整合层进行整合.springmvc是一个基于mv ...

  2. background 和渐变 总结

    一,background-position:(图片定位) 三种写法: 1):按%比,左上角最小(0%,0%),右下角最大(100%,%100): 2):(x,y)左上角最小(0,0),右下角最大(ma ...

  3. 基于Confluent.Kafka实现的KafkaConsumer消费者类和KafkaProducer消息生产者类型

    一.引言 研究Kafka有一段时间了,略有心得,基于此自己就写了一个Kafka的消费者的类和Kafka消息生产者的类,进行了单元测试和生产环境的测试,还是挺可靠的. 二.源码 话不多说,直接上代码,代 ...

  4. Oracle_PL/SQL(2) 过程控制

    0.检索单行数据0.1使用标量变量接受数据例1: 7788declare v_ename emp.ename%type; v_sal emp.sal%type;begin select ename,s ...

  5. IntelliJ idea 的破解

    ·1.破解的jar包下载链接: https://pan.baidu.com/s/1JV6GwguGQNs5pNQtst29Hw  提取码: u2jd 2.安装和破解地址:https://www.cnb ...

  6. $ each() 小结

    each()方法能使DOM循环结构简洁,不容易出错.each()函数封装了十分强大的遍历功能,使用也很方便,它可以遍历一维数组.多维数组.DOM, JSON 等等在javaScript开发过程中使用$ ...

  7. ATM作业

    关于ATM作业,最近做了很久,才明白,其实看了很久的作业视频讲解,到不如将作业的整个下载下来进行运行,去了解程序本身的结构和运行方式.首先说需求,就感觉是各种懵逼,这才学了函数,和模块之间的简单调用就 ...

  8. PHP 批量移动文件改名

    public function changeCoverName(){ //$type = '考研'; //$coverPath = './Public/course_cover/kaoyan/'; $ ...

  9. Maximum Swap LT670

    Given a non-negative integer, you could swap two digits at most once to get the maximum valued numbe ...

  10. ASP.NET 在OnClientClick中js方法直接调用Eval绑定字段的数据

    最近有一项目中使用到了asp.net的GridView控件.需要在前端被点击某一行数据时,前端获取到改行后台绑定的数据序列号.遍用<%# Bind("ID) %>.<%# ...