tensorflow添加自定义的auc计算operator
tensorflow可以很方便的添加用户自定义的operator(如果不添加也可以采用sklearn的auc计算函数或者自己写一个
但是会在python执行,这里希望在graph中也就是c++端执行这个计算)
这里根据工作需要添加一个计算auc的operator,只给出最简单实现,后续高级功能还是参考官方wiki
https://www.tensorflow.org/versions/r0.7/how_tos/adding_an_op/index.html
注意tensorflow现在和最初的官方wiki有变化,原wiki貌似是需要重新bazel编译整个tensorflow,然后使用比如tf.user_op.auc这样。
目前wiki给出的方式>=0.6.0版本,采用plug-in的方式,更加灵活可以直接用g++编译一个so载入,解耦合,省去了编译tensorflow过程,即插即用。
首先auc的operator计算的文件
tensorflow/core/user_ops/auc.cc
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An auc Op.
#include
"tensorflow/core/framework/op.h"
#include
"tensorflow/core/framework/op_kernel.h"
using
namespace tensorflow;
using
std::vector;
//@TODO add weight as optional input
REGISTER_OP("Auc")
.Input("predicts: T1")
.Input("labels: T2")
.Output("z: float")
.Attr("T1: {float, double}")
.Attr("T2: {float, double}")
//.Attr("T1: {float, double}")
//.Attr("T2: {int32, int64}")
.SetIsCommutative()
.Doc(R"doc(
Given preidicts and labels output it's auc
)doc");
class
AucOp : public OpKernel {
public:
explicit
AucOp(OpKernelConstruction* context) : OpKernel(context) {}
template<typename
ValueVec>
void
index_sort(const
ValueVec& valueVec, vector<int>& indexVec)
{
indexVec.resize(valueVec.size());
for (size_t
i = 0; i < indexVec.size(); i++)
{
indexVec[i] = i;
}
std::sort(indexVec.begin(), indexVec.end(),
[&valueVec](const
int
l, const
int
r) { return
valueVec(l) > valueVec(r); });
}
void
Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& predicts_tensor = context->input(0);
const Tensor& labels_tensor = context->input(1);
auto
predicts = predicts_tensor.flat<float>(); //输入能接受float double那么这里如何都处理?
auto
labels = labels_tensor.flat<float>();
vector<int> indexes;
index_sort(predicts, indexes);
typedef
float
Float;
Float
oldFalsePos = 0;
Float
oldTruePos = 0;
Float
falsePos = 0;
Float
truePos = 0;
Float
oldOut = std::numeric_limits<Float>::infinity();
Float
result = 0;
for (size_t
i = 0; i < indexes.size(); i++)
{
int
index = indexes[i];
Float
label = labels(index);
Float
prediction = predicts(index);
Float
weight = 1.0;
//Pval3(label, output, weight);
if (prediction != oldOut) //存在相同值得情况是特殊处理的
{
result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);
oldOut = prediction;
oldFalsePos = falsePos;
oldTruePos = truePos;
}
if (label > 0)
truePos += weight;
else
falsePos += weight;
}
result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);
Float
AUC = result / (truePos * falsePos);
// Create an output tensor
Tensor* output_tensor = NULL;
TensorShape output_shape;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
output_tensor->scalar<float>()() = AUC;
}
};
REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);
编译:
$cat gen-so.sh
TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
i=$1
o=${i/.cc/.so}
g++ -std=c++11 -shared $i -o $o -I $TF_INC -l tensorflow_framework -L $TF_LIB -fPIC -Wl,-rpath $TF_LIB
$sh gen-so.sh auc.cc
会生成auc.so
使用的时候
auc_module = tf.load_op_library('auc.so')
#auc = tf.user_ops.auc #0.6.0之前的tensorflow 自定义op方式
auc = auc_module.auc
evaluate_op = auc(py_x, Y) #py_x is predicts, Y is labels
tensorflow添加自定义的auc计算operator的更多相关文章
- AUC计算 - 进阶操作
首先AUC值是一个概率值,当你随机挑选一个正样本以及负样本,当前的分类算法根据计算得到的Score值将这个正样本排在负样本前面的概率就是AUC值,AUC值越大,当前分类算法越有可能将正样本排在负样本前 ...
- ROC 曲线,以及AUC计算方式
ROC曲线: roc曲线:接收者操作特征(receiveroperating characteristic),roc曲线上每个点反映着对同一信号刺激的感受性. ROC曲线的横轴: 负正类率(false ...
- AUC计算 - 手把手步进操作
2017-07-10 14:38:24 理论参考: 评估分类器性能的度量,像混淆矩阵.ROC.AUC等 http://www.cnblogs.com/suanec/p/5941630.html ROC ...
- 110、TensorFlow张量值的计算
import tensorflow as tf #placeholders在没有提供具体值的时候不能使用eval方法来计算它的值 # 另外的建模方法可能会使得模型变得复杂 # TensorFlow 不 ...
- TensorFlow两种方式计算Cross Entropy
sparse_softmax_cross_entropy_with_logits与softmax_cross_entropy_with_logits import tensorflow as tf y ...
- 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集
TensorFlow Serving https://tensorflow.github.io/serving/ . 生产环境灵活.高性能机器学习模型服务系统.适合基于实际数据大规模运行,产生多个模型 ...
- tensorflow入门教程和底层机制简单解说——本质就是图计算,自动寻找依赖,想想spark机制就明白了
简介 本章的目的是让你了解和运行 TensorFlow! 在开始之前, 让我们先看一段使用 Python API 撰写的 TensorFlow 示例代码, 让你对将要学习的内容有初步的印象. 这段很短 ...
- 机器学习的敲门砖:手把手教你TensorFlow初级入门
摘要: 在开始使用机器学习算法之前,我们应该首先熟悉如何使用它们. 而本文就是通过对TensorFlow的一些基本特点的介绍,让你了解它是机器学习类库中的一个不错的选择. 本文由北邮@爱可可-爱生活 ...
- TensorFlow基础笔记(6) 图像风格化实验
参考 http://blog.csdn.net/wspba/article/details/53994649 https://www.ctolib.com/AdaIN-style.html Ackno ...
随机推荐
- bzoj2683
2683: 简单题 Time Limit: 50 Sec Memory Limit: 128 MBSubmit: 1018 Solved: 413[Submit][Status][Discuss] ...
- COGS461. [网络流24题] 餐巾
[问题描述] 一个餐厅在相继的N天里,第i天需要Ri块餐巾(i=l,2,…,N).餐厅可以从三种途径获得餐巾. (1)购买新的餐巾,每块需p分: (2)把用过的餐巾送到快洗部,洗一块需m天,费用需f分 ...
- Win7上安装Linux双系统
今天帮同学在Win7上安装Linux,感觉一篇教程很不错,mark一下 原地址:Win7下U盘安装Ubuntu14.04双系统步骤详解 一.前期准备 1.大于2G的U盘一个(我的系统盘制作完成后大约占 ...
- linux之sed用法
参考 http://www.cnblogs.com/dong008259/archive/2011/12/07/2279897.html sed是一个很好的文件处理工具,本身是一个管道命令,主要是以行 ...
- [UML]UML系列——类图class的关联关系(聚合、组合)
关联的概念 关联用来表示两个或多个类的对象之间的结构关系,它在代码中表现为一个类以属性的形式包含对另一个类的一个或多个对象的应用. 程序演示:关联关系(code/assocation) 假设:一个公司 ...
- 【09-14】eclipse学习笔记
eclipse安装class文件反编译插件jadClipse /** 1. 下载JadClipse的jar包 2. 下载Jad反编译器 3. 将JarClipse jar包放到eclipse plug ...
- visio二次开发初始化问题
(转发请注明来源:http://www.cnblogs.com/EminemJK/) 问题: axDrawingControl1初始化失败((System.ComponentModel.ISuppor ...
- 大熊君学习html5系列之------History API(SPA单页应用的必备------重构完结版)
一,开篇分析 Hi,大家好!大熊君又和大家见面了,(*^__^*) 嘻嘻……,这系列文章主要是学习Html5相关的知识点,以学习API知识点为入口,由浅入深的引入实例, 让大家一步一步的体会" ...
- java8入门 错误:找不到或者无法加载主类
如果你也遇上的这个问题,但是如果你的Java版本不是6以上,这个解决方案可能就不适合你... 最近在跟着李兴华老湿的视频<<编程开发入门Java 8>>的学习Java... 但 ...
- python 学习 : 一个简单的秒表
游戏说明:绿色数字(左边表示成功停止在整秒的次数,右边表示停止的总次数) 点击stop,如果小数点后为0,即你停止的时间是整秒数,右上方斜杠左边数字加一 把代码复制到这个网页code run he ...