tf.keras.metric 里面竟然没有实现 F1 score、recall、precision 等指标,一开始觉得真不可思议。但这是有原因的,这些指标在 batch-wise 上计算都没有意义,需要在整个验证集上计算,而 tf.keras 在训练过程(包括验证集)中计算 acc、loss 都是一个 batch 计算一次的,最后再平均起来。Keras 2.0 版本将 precision, recall, fbeta_score, fmeasure 等 metrics 移除了。

虽然 tf.keras.metric 中没有实现 f1 socre、precision、recall,但我们可以通过 tf.keras.callbacks.Callback 实现。即在每个 epoch 末尾,在整个 val 上计算 f1、precision、recall。

一些博客实现了二分类下的 f1 socre、precision、recall,如下所示:

以下代码实现了多分类下对验证集 F1 值、precision、recall 的计算,并且保存 val_f1 值最好的模型:

import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, recall_score, precision_score
import numpy as np
import os class Metrics(tf.keras.callbacks.Callback):
def __init__(self, valid_data):
super(Metrics, self).__init__()
self.validation_data = valid_data def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
val_predict = np.argmax(self.model.predict(self.validation_data[0]), -1)
val_targ = self.validation_data[1]
if len(val_targ.shape) == 2 and val_targ.shape[1] != 1:
val_targ = np.argmax(val_targ, -1) _val_f1 = f1_score(val_targ, val_predict, average='macro')
_val_recall = recall_score(val_targ, val_predict, average='macro')
_val_precision = precision_score(val_targ, val_predict, average='macro') logs['val_f1'] = _val_f1
logs['val_recall'] = _val_recall
logs['val_precision'] = _val_precision
print(" — val_f1: %f — val_precision: %f — val_recall: %f" % (_val_f1, _val_precision, _val_recall))
return (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=10000, random_state=32) # LeNet-5
model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(32, 32, 3)),
tf.keras.layers.Conv2D(6, 5, activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(16, 5, activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(120, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(84, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
]) model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']) if not os.path.exists('./checkpoints'):
os.makedirs('./checkpoints') # 按照 val_f1 保存模型
ck_callback = tf.keras.callbacks.ModelCheckpoint('./checkpoints/weights.{epoch:02d}-{val_f1:.4f}.hdf5',
monitor='val_f1',
mode='max', verbose=2,
save_best_only=True,
save_weights_only=True)
tb_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=0)
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[Metrics(valid_data=(x_val, y_val)),
ck_callback,
tb_callback])

注意 Metrics()ck_callback 两个 callback 的顺序,互换之后将报错。

References

How to calculate F1 Macro in Keras? -- StackOverflow

How to compute f1 score for each epoch in Keras -- Thong Nguyen

keras如何求分类问题中的准确率和召回率? - 鱼塘邓少的回答 - 知乎

Keras 2.0 release notes -- keras-team/keras

【tf.keras】实现 F1 score、precision、recall 等 metric的更多相关文章

  1. How to compute f1 score for each epoch in Keras

    https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2 https:// ...

  2. Precision,Recall,F1的计算

    Precision又叫查准率,Recall又叫查全率.这两个指标共同衡量才能评价模型输出结果. TP: 预测为1(Positive),实际也为1(Truth-预测对了) TN: 预测为0(Negati ...

  3. 机器学习:评价分类结果(F1 Score)

    一.基础 疑问1:具体使用算法时,怎么通过精准率和召回率判断算法优劣? 根据具体使用场景而定: 例1:股票预测,未来该股票是升还是降?业务要求更精准的找到能够上升的股票:此情况下,模型精准率越高越优. ...

  4. 【笔记】F1 score

    F1 score 关于精准率和召回率 精准率和召回率可以很好的评价对于数据极度偏斜的二分类问题的算法,有个问题,毕竟是两个指标,有的时候这两个指标也会产生差异,对于不同的算法,精准率可能高一些,召回率 ...

  5. 机器学习中的 precision、recall、accuracy、F1 Score

    1. 四个概念定义:TP.FP.TN.FN 先看四个概念定义: - TP,True Positive - FP,False Positive - TN,True Negative - FN,False ...

  6. 机器学习--如何理解Accuracy, Precision, Recall, F1 score

    当我们在谈论一个模型好坏的时候,我们常常会听到准确率(Accuracy)这个词,我们也会听到"如何才能使模型的Accurcy更高".那么是不是准确率最高的模型就一定是最好的模型? 这篇博文会向大家解释 ...

  7. 【tf.keras】使用手册

    目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...

  8. 查准与召回(Precision & Recall)

    Precision & Recall 先看下面这张图来理解了,后面再具体分析.下面用P代表Precision,R代表Recall 通俗的讲,Precision 就是检索出来的条目中(比如网页) ...

  9. hihocoder 1522 : F1 Score

    题目链接   时间限制:10000ms 单点时限:1000ms 内存限制:256MB 描述 小Hi和他的小伙伴们一起写了很多代码.时间一久有些代码究竟是不是自己写的,小Hi也分辨不出来了. 于是他实现 ...

随机推荐

  1. 微服务架构案例(05):SpringCloud 基础组件应用设计

    本文源码:GitHub·点这里 || GitEE·点这里 更新进度(共6节): 01:项目技术选型简介,架构图解说明 02:业务架构设计,系统分层管理 03:数据库选型,业务数据设计规划 04:中间件 ...

  2. 盘点一下Creator星球上的开源工具包!

    晓衡开始写公众号,最早是从上架 Cocos 商店的 pbkiller 插件开始的,到至今有2年2个月了.在这期间,又陆续在公众号上分享了多个实用工具包,在这里统一盘点一下,方便与大家交流学习. 一.u ...

  3. Rust 中的类型转换

    1. as 运算符 as 运算符有点像 C 中的强制类型转换,区别在于,它只能用于原始类型(i32 .i64 .f32 . f64 . u8 . u32 . char 等类型),并且它是安全的. 例 ...

  4. USACO 07DEC 道路建设(Building Roads)

    Farmer John had just acquired several new farms! He wants to connect the farms with roads so that he ...

  5. 硬件内存模型到 Java 内存模型,这些硬核知识你知多少?

    Java 内存模型跟上一篇 JVM 内存结构很像,我经常会把他们搞混,但其实它们不是一回事,而且相差还很大的,希望你没它们搞混,特别是在面试的时候,搞混了的话就会答非所问,影响你的面试成绩,当然也许你 ...

  6. .Net Core Vue Qucik Start

    .Net Core Vue Qucik Start This is a ASP.NET Core 3.0 project seamlessly integrationed with Vue.js te ...

  7. 大数据之路week01--day02我实在时被继承super这些东西搞的头疼,今天来好好整理以下。

    这一周的第一天的内容是面向对象的封装,以及对方法的调用.实在时没法单独拿出来单说,就结合这一节一起说了. 我实在是被继承中的super用法给弄的有点晕,程序总是不能按照我想的那样,不是说结果,而是实现 ...

  8. python协程总结

    概述 python多线程中因为有GIL(Global Interpreter Lock 全局解释器锁 )的存在,所以对CPU密集型程序显得很鸡肋:但对IO密集型的程序,GIL会在调用IO操作前释放,所 ...

  9. python——多线程

    多线程特点: • 线程的并发是利用cpu上下文的切换(是并发,不是并行) • 多线程执行的顺序是无序的 • 多线程共享全局变量 • 线程是继承在进程里的,没有进程就没有线程 • GIL全局解释器锁 • ...

  10. java多线程与线程并发四:线程范围内的共享数据

    当多个线程操作同一个共有数据时,一个线程对共有数据的改变会影响到另一个线程.比如下面这个例子:两个线程调用同一个对象的的方法,一个线程的执行结果会影响另一个线程. package com.sky.th ...