BERT预训练模型在诸多NLP任务中都取得最优的结果。在处理文本分类问题时,即可以直接用BERT模型作为文本分类的模型,也可以将BERT模型的最后层输出的结果作为word embedding导入到我们定制的文本分类模型中(如text-CNN等)。总之现在只要你的计算资源能满足,一般问题都可以用BERT来处理,此次针对公司的一个实际项目——一个多类别(61类)的文本分类问题,其就取得了很好的结果。

  我们此次的任务是一个数据分布极度不平衡的多类别文本分类(有的类别下只有几个或者十几个样本,有的类别下又有几千个样本),在不做不平衡数据处理且不采用BERT模型时,其取得的F1值只有50%,而在不做不平衡数据处理但采用BERT模型时,其F1值能达到65%,但是在用bert模型时获得F1值时却存在一些问题。

  在tensorflow中只提供了二分类的precision,recall,f1值的计算接口,而bert源代码中的run_classifier.py文件中训练模型,验证模型等都是用的estimator API,这些高层API极大的限制了修改代码的灵活性。好在tensorflow源码中有一个方法可以计算混淆矩阵的方法,并且会返回一个operation。注意:这个和tf.confusion_matrix()不同,具体看源代码中下面这段代码:

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits, num_labels):
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(
labels=label_ids, predictions=predictions)
          
          # 这里的metrics时我们定义的一个python文件,在下面会介绍
conf_mat = metrics.get_metrics_ops(label_ids, predictions, num_labels) loss = tf.metrics.mean(values=per_example_loss)
return {
"eval_accuracy": accuracy,
"eval_cm": conf_mat,
"eval_loss": loss,
}

  验证时的性能指标计算都在这个方法里面,而且在return的这个字典中每个值必须是一个tuple。以accuracy为例,tf.metrics.accuracy返回的是一个(accuracy, update_op)这样一个tuple,而我们上一段说的tf.confusion_matrix只返回一个混淆矩阵。因此在这里我们使用一个内部的方法,方法导入如下:

from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix

这个方法会返回一个(confusion_matrix, update_op)的tuple。我们新建一个metrics.py文件,里面的代码如下:

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix def get_metrics_ops(labels, predictions, num_labels):
  # 得到混淆矩阵和update_op,在这里我们需要将生成的混淆矩阵转换成tensor
cm, op = _streaming_confusion_matrix(labels, predictions, num_labels)
tf.logging.info(type(cm))
tf.logging.info(type(op)) return (tf.convert_to_tensor(cm), op) def get_metrics(conf_mat, num_labels):
  # 得到numpy类型的混淆矩阵,然后计算precision,recall,f1值。
precisions = []
recalls = []
for i in range(num_labels):
tp = conf_mat[i][i].sum()
col_sum = conf_mat[:, i].sum()
row_sum = conf_mat[i].sum() precision = tp / col_sum if col_sum > 0 else 0
recall = tp / row_sum if row_sum > 0 else 0 precisions.append(precision)
recalls.append(recall) pre = sum(precisions) / len(precisions)
rec = sum(recalls) / len(recalls)
f1 = 2 * pre * rec / (pre + rec) return pre, rec, f1

最上面一段代码中return的字典中的值可以在run_classifier.py中main函数中的下面一段代码中得到:

    if FLAGS.do_eval:
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
num_actual_eval_examples = len(eval_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on. These do NOT count towards the metric (all tf.metrics
# support a per-instance weight, and these get a weight of 0.0).
while len(eval_examples) % FLAGS.eval_batch_size != 0:
eval_examples.append(PaddingInputExample()) eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d (%d actual, %d padding)",
len(eval_examples), num_actual_eval_examples,
len(eval_examples) - num_actual_eval_examples)
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) # This tells the estimator to run through the entire set.
eval_steps = None
# However, if running eval on the TPU, you will need to specify the
# number of steps.
if FLAGS.use_tpu:
assert len(eval_examples) % FLAGS.eval_batch_size == 0
eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) eval_drop_remainder = True if FLAGS.use_tpu else False
eval_input_fn = file_based_input_fn_builder(
input_file=eval_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)

     # result中就是return返回的字典
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
with tf.gfile.GFile(output_eval_file, "w") as writer:
tf.logging.info("***** Eval results *****")
       
       # 我们可以拿到混淆矩阵(现在时numpy的形式),调用metrics.py文件中的方法来得到precision,recall,f1值
pre, rec, f1 = metrics.get_metrics(result["eval_cm"], len(label_list))
tf.logging.info("eval_precision: {}".format(pre))
tf.logging.info("eval_recall: {}".format(rec))
tf.logging.info("eval_f1: {}".format(f1))
tf.logging.info("eval_accuracy: {}".format(result["eval_accuracy"]))
tf.logging.info("eval_loss: {}".format(result["eval_loss"])) np.save("conf_mat.npy", result["eval_cm"])

通过上面的代码拿到混淆矩阵后,调用metrics.py文件中的get_metrics方法就可以得到precision,recall,f1值。

BERT模型在多类别文本分类时的precision, recall, f1值的计算的更多相关文章

  1. NLP(十五)让模型来告诉你文本中的时间

    背景介绍   在文章NLP入门(十一)从文本中提取时间 中,笔者演示了如何利用分词.词性标注的方法从文本中获取时间.当时的想法比较简单快捷,只是利用了词性标注这个功能而已,因此,在某些地方,时间的识别 ...

  2. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  3. zz从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史

    从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 Bert最近很火,应该是最近最火爆的AI进展,网上的评价很高,那么Bert值得这么高的评价吗?我个人判断是值得.那为什么 ...

  4. BERT模型图解

    转载于 腾讯Bugly 发表于 腾讯Bugly的专栏 原文链接:https://cloud.tencent.com/developer/article/1389555 本文首先介绍BERT模型要做什么 ...

  5. 想研究BERT模型?先看看这篇文章吧!

    最近,笔者想研究BERT模型,然而发现想弄懂BERT模型,还得先了解Transformer. 本文尽量贴合Transformer的原论文,但考虑到要易于理解,所以并非逐句翻译,而是根据笔者的个人理解进 ...

  6. 图示详解BERT模型的输入与输出

    一.BERT整体结构 BERT主要用了Transformer的Encoder,而没有用其Decoder,我想是因为BERT是一个预训练模型,只要学到其中语义关系即可,不需要去解码完成具体的任务.整体架 ...

  7. NLP学习(3)---Bert模型

    一.BERT模型: 前提:Seq2Seq模型 前提:transformer模型 bert实战教程1 使用BERT生成句向量,BERT做文本分类.文本相似度计算 bert中文分类实践 用bert做中文命 ...

  8. [NLP自然语言处理]谷歌BERT模型深度解析

    我的机器学习教程「美团」算法工程师带你入门机器学习   已经开始更新了,欢迎大家订阅~ 任何关于算法.编程.AI行业知识或博客内容的问题,可以随时扫码关注公众号「图灵的猫」,加入”学习小组“,沙雕博主 ...

  9. Bert模型实现垃圾邮件分类

    近日,对近些年在NLP领域很火的BERT模型进行了学习,并进行实践.今天在这里做一下笔记. 本篇博客包含下列内容: BERT模型简介 概览 BERT模型结构 BERT项目学习及代码走读 项目基本特性介 ...

随机推荐

  1. VS Code实用技能1 - 代码折叠、面包屑

    VS Code实用技能 VS Code实用技能1 - 代码折叠.面包屑 一.代码折叠 ubuntu ctrl + shift + { ctrl + shift + } ctrl + k , ctrl ...

  2. 从壹开始微服务 [ DDD ] 之六 ║聚合 与 聚合根 (下)

    前言 哈喽大家周二好,上次咱们说到了实体与值对象的简单知识,相信大家也是稍微有些了解,其实实体咱们平时用的很多了,基本可以和数据库表进行联系,只不过值对象可能不是很熟悉,值对象简单来说就是在DDD领域 ...

  3. mysql的学习笔记(六)

    1.字符函数 (1).CONCAT(str1,str2,...)函数,将多列信息合并输出. SELECT CATCAT('hello','mysql') as test (2).CONCAT_WS(' ...

  4. nodejs 使用 ethers创建以太坊钱包

    创建钱包创建钱包流程: 生成随机助记词 => 通过助记词创建钱包=>钱包信息和加密明文(私钥和密码加密) 导入钱包通过插件提供方法,根据助记词|keyStore|私钥,找到钱包信息(地址和 ...

  5. 从零开始学习PYTHON3讲义(十一)计算器升级啦

    (内容需要,本讲中再次使用了大量在线公式,如果因为转帖网站不支持公式无法显示的情况,欢迎访问原始博客.) <从零开始PYTHON3>第十一讲 第二讲的时候,我们通过Python的交互模式来 ...

  6. 【Android Studio安装部署系列】十八、Android studio更换APP应用图标

    版权声明:本文为HaiyuKing原创文章,转载请注明出处! 概述 Android Studio新建项目后会有一个默认图标,那么如何更换图标呢? 替换图标 这个方案不建议直接在已有项目上更换图标,建议 ...

  7. Linux~学习笔记目录索引

    回到占占推荐博客索引 本篇文章是对自己学习Linux及在它的环境下部署工具的一个总结,以方便自己查阅,也给他人一个帮助,本文章同时会不断的更新,欢迎大家订阅! 本目录包括的内容会包括linux基础命令 ...

  8. 如何开发AR增强现实应用与产品

    2016年被称为VR元年,可见火爆程度,但是我要告诉你,其实还有一种技术AR(增强现实)技术,才是下一个真正的“风口”技术.可以预见的是,未来AR应用爆发之时,必将超越VR产业规模,开拓千亿级市场空间 ...

  9. 痞子衡嵌入式:恩智浦i.MXRT系列微控制器量产神器RT-Flash用户指南

    RT Flash English | 中文 1 软件概览 1.1 介绍 RT-Flash是一个专为基于NXP i.MX RT系列芯片的产品量产而设计的工具,其功能与官方MfgTool2工具类似,但是解 ...

  10. 【转载】解析 java 按值传递还是按引用传递

    原文链接  说明: (1):“在Java里面参数传递都是按值传递”这句话的意思是:按值传递是传递的值的拷贝,按引用传递其实传递的是引用的地址值,所以统称按值传递. (2):在Java里面只有基本类型和 ...