如何在scikit-learn模型中使用Keras

通过用 KerasClassifier 或 KerasRegressor 类包装Keras模型,可将其用于scikit-learn。

要使用这些包装,必须定义一个函数,以便按顺序模式创建并返回Keras,然后当构建 KerasClassifier 类时,把该函数传递给 build_fn 参数。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model)

KerasClassifier类 的构建器为可以采取默认参数,并将其被传递给 model.fit() 的调用函数,比如 epochs数目和批尺寸(batch size)。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)

KerasClassifier类的构造也可以使用新的参数,使之能够传递给自定义的create_model()函数。这些新的参数,也必须由使用默认参数的 create_model() 函数的签名定义。

例如:

def create_model(dropout_rate=0.0):
...
return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

pred = estimator.predict(X_test)#返回给定测试数据的类预测。
pred1=estimator.predict_proba(X_test)#返回给定测试数据的类概率估计。
# pred3=estimator.score(X_test,Y_test)#返回给定测试数据和标签的平均精度。
print(X_test)#
print(Y_test)#实际类别
print(pred)#预测类别

print(pred1)

[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]

estimator = KerasClassifier的更多相关文章

  1. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  2. Python机器学习笔记:利用Keras进行分类预测

    Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...

  3. Keras人工神经网络多分类(SGD)

    import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...

  4. python多标签分类模版

    from sklearn.multioutput import MultiOutputClassifier from sklearn.ensemble import RandomForestClass ...

  5. np_utils.to_categorical

    https://blog.csdn.net/zlrai5895/article/details/79560353 多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用 ...

  6. 3.2. Grid Search: Searching for estimator parameters

    3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...

  7. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  8. [sklearn]官方例程-Imputing missing values before building an estimator 随机填充缺失值

    官方链接:http://scikit-learn.org/dev/auto_examples/plot_missing_values.html#sphx-glr-auto-examples-plot- ...

  9. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

随机推荐

  1. 26-算法训练 Torry的困惑(基本型) 素数打表

      算法训练 Torry的困惑(基本型)   时间限制:1.0s   内存限制:512.0MB      问题描述 Torry从小喜爱数学.一天,老师告诉他,像2.3.5.7……这样的数叫做质数.To ...

  2. public void method(),void前面的泛型T是什么

    public <T>这个T是个修饰符的功能,表示是个泛型方法,就像有static修饰的方法是个静态方法一样. 注意<T> 不是返回值,此处的返回值是void ,此处的<T ...

  3. 26.mysql日志

    26.mysql日志mysql日志包括:错误日志.二进制日志.查询日志.慢查询日志. 26.1 错误日志错误日志记录了mysqld启动到停止之间发生的任何严重错误的相关信息.mysql故障时应首先查看 ...

  4. 8P - 钱币兑换问题

    在一个国家仅有1分,2分,3分硬币,将钱N兑换成硬币有很多种兑法.请你编程序计算出共有多少种兑法. Input 每行只有一个正整数N,N小于32768. Output 对应每个输入,输出兑换方法数. ...

  5. redis 数据类型为set命令整理以及示例

    数据类型为set.可以保证set内数据唯一.场景:生成订单号,因为要求订单号是绝对不能重复的,所以数据库中要设置为unique索引.但是其实可以通过redis,set来做每天的订单集合.比如A客户的订 ...

  6. hdu 5693 && LightOj 1422 区间DP

    hdu 5693 题目链接http://acm.hdu.edu.cn/showproblem.php?pid=5693 等差数列当划分细了后只用比较2个或者3个数就可以了,因为大于3的数都可以由2和3 ...

  7. Taxi

    /* After the lessons n groups of schoolchildren went outside and decided to visit Polycarpus to cele ...

  8. XiaoKL学Python(D)argparse

    该文以Python 2为基础. 1. argparse简介 argparse使得编写用户友好的命令行接口更简单. argparse知道如何解析sys.argv. argparse 模块自动生成 “帮助 ...

  9. 数据结构:链表 >> 链表按结点中第j个数据属性排序(冒泡排序法)

    创建结点类,链表类,测试类 import java.lang.Object; //结点node=数据date+指针pointer public class Node { Object iprop; p ...

  10. sqli-labs:17,增删改

    增 insert into users values(','lcamry','lcamry'); 删 delete from users where id=16 删数据库:drop database ...