如何在scikit-learn模型中使用Keras

通过用 KerasClassifier 或 KerasRegressor 类包装Keras模型,可将其用于scikit-learn。

要使用这些包装,必须定义一个函数,以便按顺序模式创建并返回Keras,然后当构建 KerasClassifier 类时,把该函数传递给 build_fn 参数。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model)

KerasClassifier类 的构建器为可以采取默认参数,并将其被传递给 model.fit() 的调用函数,比如 epochs数目和批尺寸(batch size)。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)

KerasClassifier类的构造也可以使用新的参数,使之能够传递给自定义的create_model()函数。这些新的参数,也必须由使用默认参数的 create_model() 函数的签名定义。

例如:

def create_model(dropout_rate=0.0):
...
return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

pred = estimator.predict(X_test)#返回给定测试数据的类预测。
pred1=estimator.predict_proba(X_test)#返回给定测试数据的类概率估计。
# pred3=estimator.score(X_test,Y_test)#返回给定测试数据和标签的平均精度。
print(X_test)#
print(Y_test)#实际类别
print(pred)#预测类别

print(pred1)

[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]

estimator = KerasClassifier的更多相关文章

  1. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  2. Python机器学习笔记:利用Keras进行分类预测

    Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...

  3. Keras人工神经网络多分类(SGD)

    import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...

  4. python多标签分类模版

    from sklearn.multioutput import MultiOutputClassifier from sklearn.ensemble import RandomForestClass ...

  5. np_utils.to_categorical

    https://blog.csdn.net/zlrai5895/article/details/79560353 多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用 ...

  6. 3.2. Grid Search: Searching for estimator parameters

    3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...

  7. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  8. [sklearn]官方例程-Imputing missing values before building an estimator 随机填充缺失值

    官方链接:http://scikit-learn.org/dev/auto_examples/plot_missing_values.html#sphx-glr-auto-examples-plot- ...

  9. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

随机推荐

  1. MVC学习(三)Code-First Demo

    前面两篇文章介绍了DataBase-First例子,这里就介绍Code-First. 个人简单理解:就是在程序中编写代码,然后在数据库中生成相应的表.字段.约束等等.听上去蛮神奇的.Now,begin ...

  2. SAP请求号的传输

    SAP传输目的: SAP传输目的是把开发机中的程序或对象传输到对应的测试机或生成机中,保持各系统的同步性,方便测试和最后的部署! SAP求情号传输的步骤: 1.创建一个请求号 2.用SE10进入如下界 ...

  3. mysql 事务中如果有sql语句出错,会导致自动回滚吗?

    事务,我们都知道具有原子性,操作要么全部成功,要么全部失败.但是有可能会造成误解. 我们先准备一张表,来进行测试 CREATE TABLE `name` ( `id` int(11) unsigned ...

  4. 进程 day36

    python之路——进程   阅读目录 理论知识 操作系统背景知识 什么是进程 进程调度 进程的并发与并行 同步\异步\阻塞\非阻塞 进程的创建与结束 在python程序中的进程操作 multipro ...

  5. Linux_(2)基本命令(下)

    六.文件搜索命令11 :which功能描述:显示系统命令所在目录命令所在路径:/usr/bin/which执行权限:所有用户语法:which [命令名称]范例:$ which ls 12 :find功 ...

  6. PAT 1081 检查密码(15) (代码+思路)

    1081 检查密码(15 分) 本题要求你帮助某网站的用户注册模块写一个密码合法性检查的小功能.该网站要求用户设置的密码必须由不少于6个字符组成,并且只能有英文字母.数字和小数点 .,还必须既有字母也 ...

  7. BZOJ 1969 航线规划 - LCT 维护边双联通分量

    Solution 实际上就是查询 $u$ 到 $v$ 路径上 边双的个数 $ -1$. 并且题目仅有删边, 那么就离线倒序添边. 维护 边双 略有不同: 首先需要一个并查集, 记录 边双内的点. 在 ...

  8. BZOJ 3932 [CQOI2015]任务查询系统 - 差分 + 主席树

    Solution 差分就好了, 在$s_i$ 的点+1, $e_i + 1$ 的点 - 1. 查询的时候注意$l == r$ 要返回 $k * b[l]$ ,而不是$sum[node] $因为当前位置 ...

  9. Oracle VM VirtualBox如何设置网络地址转换NAT

    使用VirtualBox 安装好服务器后,需要设置网络,如果有IP, 则可以直接连接物理网络了, 如果没有,则可以直接使用NAT网络.设置方便快速. 先将虚拟机中的网络设置为自动获取,然后点击Virt ...

  10. 2019,UI设计师必备神器

      2019年将会是你全新起航的一年,相信你已经制定了很多规划,正在开启第一步的推动. 作为对UI设计师更大程度的支持,今天特意为你分享一款释放你双手的设计神器.让你可以把时间和精力投入到设计本身,这 ...