Keras自定义评估函数
1. 比较一般的自定义函数:
需要注意的是,不能像sklearn那样直接定义,因为这里的y_true和y_pred是张量,不是numpy数组。示例如下:
from keras import backend def rmse(y_true, y_pred):
return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))
用的时候直接:
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])
2. 比较复杂的如AUC函数:
AUC的计算需要整体数据,如果直接在batch里算,误差就比较大,不能合理反映整体情况。这里采用回调函数写法,每个epoch计算一次:
from sklearn.metrics import roc_auc_score class roc_callback(keras.callbacks.Callback):
def __init__(self,training_data, validation_data): self.x = training_data[0]
self.y = training_data[1]
self.x_val = validation_data[0]
self.y_val = validation_data[1] def on_train_begin(self, logs={}):
return def on_train_end(self, logs={}):
return def on_epoch_begin(self, epoch, logs={}):
return def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x)
roc = roc_auc_score(self.y, y_pred) y_pred_val = self.model.predict(self.x_val)
roc_val = roc_auc_score(self.y_val, y_pred_val) print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
return def on_batch_begin(self, batch, logs={}):
return def on_batch_end(self, batch, logs={}):
return
调用回调函数示例:
model.fit(X_train, y_train, epochs=10, batch_size=4,
callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )
整体示例:
from tensorflow import keras
from sklearn import datasets
from sklearn import model_selection
from sklearn.metrics import roc_auc_score def rmse(y_true, y_pred):
return keras.backend.sqrt(keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1)) class roc_callback(keras.callbacks.Callback):
def __init__(self,training_data, validation_data): self.x = training_data[0]
self.y = training_data[1]
self.x_val = validation_data[0]
self.y_val = validation_data[1] def on_train_begin(self, logs={}):
return def on_train_end(self, logs={}):
return def on_epoch_begin(self, epoch, logs={}):
return def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x)
roc = roc_auc_score(self.y, y_pred) y_pred_val = self.model.predict(self.x_val)
roc_val = roc_auc_score(self.y_val, y_pred_val) print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
return def on_batch_begin(self, batch, logs={}):
return def on_batch_end(self, batch, logs={}):
return X, y = datasets.make_classification(n_samples=100, n_features=4, n_classes=2, random_state=2018)
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=2018)
print("TrainSet", X_train.shape, "TestSet", X_test.shape) model = keras.models.Sequential()
model.add(keras.layers.Dense(20, input_shape=(4,), activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse]) model.fit(X_train, y_train, epochs=10, batch_size=4,
callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )
运行结果:
TrainSet (80, 4) TestSet (20, 4)
Epoch 1/10
roc-auc: 0.1604 - roc-auc_val: 0.2738
80/80 [==============================] - 0s - loss: 0.8132 - rmse: 0.5298
Epoch 2/10
roc-auc: 0.4874 - roc-auc_val: 0.619
80/80 [==============================] - 0s - loss: 0.7432 - rmse: 0.5049
Epoch 3/10
roc-auc: 0.7715 - roc-auc_val: 0.9643
80/80 [==============================] - 0s - loss: 0.6821 - rmse: 0.4807
Epoch 4/10
roc-auc: 0.9602 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.6268 - rmse: 0.4560
Epoch 5/10
roc-auc: 0.9842 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.5747 - rmse: 0.4301
Epoch 6/10
roc-auc: 0.9956 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.5230 - rmse: 0.4025
Epoch 7/10
roc-auc: 0.9975 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.4743 - rmse: 0.3739
Epoch 8/10
roc-auc: 0.9987 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.4289 - rmse: 0.3454
Epoch 9/10
roc-auc: 0.9987 - roc-auc_val: 1.0...] - ETA: 0s - loss: 0.4019 - rmse: 0.3301
80/80 [==============================] - 0s - loss: 0.3830 - rmse: 0.3149
Epoch 10/10
roc-auc: 0.9987 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.3424 - rmse: 0.2865
Keras自定义评估函数的更多相关文章
- keras 自定义 custom 函数
转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...
- xgboost 自定义目标函数和评估函数
https://zhpmatrix.github.io/2017/06/29/custom-xgboost/ https://www.cnblogs.com/silence-gtx/p/5812012 ...
- TensorFlow自定义训练函数
本文记录了在TensorFlow框架中自定义训练函数的模板并简述了使用自定义训练函数的优势与劣势. 首先需要说明的是,本文中所记录的训练函数模板参考自https://stackoverflow.com ...
- 关于jqGrig如何写自定义格式化函数将JSON数据的字符串转换为表格各个列的值
首先介绍一下jqGrid是一个jQuery的一个表格框架,现在有一个需求就是将数据库表的数据拿出来显示出来,分别有id,name,details三个字段,其中难点就是details字段,它的数据是这样 ...
- 自定义el函数
1.1.1 自定义EL函数(EL调用Java的函数) 第一步:创建一个Java类.方法必须是静态方法. public static String sayHello(String name){ retu ...
- ORACLE 自定义聚合函数
用户可以自定义聚合函数 ODCIAggregate,定义了四个聚集函数:初始化.迭代.合并和终止. Initialization is accomplished by the ODCIAggrega ...
- SQL Server 自定义聚合函数
说明:本文依据网络转载整理而成,因为时间关系,其中原理暂时并未深入研究,只是整理备份留个记录而已. 目标:在SQL Server中自定义聚合函数,在Group BY语句中 ,不是单纯的SUM和MAX等 ...
- Matlab中如何将(自定义)函数作为参数传递给另一个函数
假如我们编写了一个积分通用程序,想使它更具有通用性,那么可以把被积函数也作为一个参数.在c/c++中,可以使用函数指针来实现上边的功能,在matlab中如何实现呢?使用函数句柄--这时类似于函数指针的 ...
- python 自定义排序函数
自定义排序函数 Python内置的 sorted()函数可对list进行排序: >>>sorted([36, 5, 12, 9, 21]) [5, 9, 12, 21, 36] 但 ...
随机推荐
- ural1519-Formula 1
题意 给出一个 \(n\times m\) 的棋盘,上面有一些格子是不能经过的.求有多少种欧拉回路可以经过所有可经过到格子.\(n,m\le 12\) . 分析 上个月就看了一下插头dp,然而这道题写 ...
- bzoj1031-字符加密
环的问题,经典方法倍长串,求出后缀数组,扫一次sa,如果sa[i]小于等于n,那么就输出这个字符串结尾的位置(即s[sa[i]+n-1]). 代码 #include<cstdio> #in ...
- 给定一个字符串 s,找到 s 中最长的回文子串。你可以假设 s 的最大长度为1000。
给定一个字符串 s,找到 s 中最长的回文子串.你可以假设 s 的最大长度为1000. 示例 1: 输入: "babad" 输出: "bab" 注意: &quo ...
- 在 Android开发中,性能优化策略十分重要
在 Android开发中,性能优化策略十分重要本文主要讲解性能优化中的布局优化,希望你们会喜欢.目录 示意图 1. 影响的性能 布局性能的好坏 主要影响 :Android应用中的页面显示速度 2. 如 ...
- 【BZOJ2437】【NOI2011】兔兔与蛋蛋(博弈论,二分图匹配)
[BZOJ2437][NOI2011]兔兔与蛋蛋(博弈论,二分图匹配) 题面 BZOJ 题解 考虑一下暴力吧. 对于每个状态,无非就是要考虑它是否是必胜状态 这个直接用\(dfs\)爆搜即可. 这样子 ...
- 【BZOJ2653】Middle(主席树)
[BZOJ2653]Middle(主席树) 题面 BZOJ 洛谷 Description 一个长度为n的序列a,设其排过序之后为b,其中位数定义为b[n/2],其中a,b从0开始标号,除法取下整.给你 ...
- mac os 启动服务命令 launchctl
参考苹果开发者网址 https://developer.apple.com/library/mac/documentation/MacOSX/Conceptual/BPSystemStartup/Ch ...
- 金牌架构师:我们是这样设计APP数据统计产品的
前言:近期,智能大数据服务商“个推”推出了应用统计产品“个数”,今天我们就和大家来谈一谈个数实时统计与AI数据智能平台整合架构设计. 很多人可能好奇,拥有数百亿SDK的个推,专注消息推送服务多年,现在 ...
- 「Django」rest_framework学习系列-版本认证
1.自己写: class UserView(APIView): versioning_class = ParamVersion def get(self,request,*args,**kwargs) ...
- P3620 [APIO/CTSC 2007]数据备份
P3620 [APIO/CTSC 2007]数据备份 题目描述 你在一家 IT 公司为大型写字楼或办公楼(offices)的计算机数据做备份.然而数据备份的工作是枯燥乏味的,因此你想设计一个系统让不同 ...