用sklearn的DecisionTreeClassifer训练模型,然后用roc_auc_score计算模型的auc。代码如下

clf = DecisionTreeClassifier(criterion='gini', max_depth=6, min_samples_split=10, min_samples_leaf=2)
clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)
roc_auc = roc_auc_score(y_test, y_pred)

报错信息如下

/Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/metrics/ranking.pyc in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
297 check_consistent_length(y_true, y_score)
298 y_true = column_or_1d(y_true)
--> 299 y_score = column_or_1d(y_score)
300 assert_all_finite(y_true)
301 assert_all_finite(y_score) /Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in column_or_1d(y, warn)
560 return np.ravel(y)
561
--> 562 raise ValueError("bad input shape {0}".format(shape))
563
564 ValueError: bad input shape (900, 2)

目测是你的y_pred出了问题,你的y_pred是(900, 2)的array,也就是有两列。

因为predict_proba返回的是两列。predict_proba的用法参考这里

简而言之,你上面的代码改成这样就可以了。

y_pred = clf.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_pred)

原文:http://sofasofa.io/forum_main_post.php?postid=1001678

sklearn里计算roc_auc_score,报错ValueError: bad input shape的更多相关文章

  1. 标记编码报错ValueError: bad input shape ()

    <Python机器学习经典实例>2.9小节中,想自己动手实践汽车特征评估质量,所以需要对数据进行预处理,其中代码有把字符串标记编码为对应的数字,如下代码 input_data = ['vh ...

  2. keras 报错 ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 'Tensor("embedding_1/random_uniform:0", shape=(5001, 128), dtype=float32)'

    在服务器上训练并保存模型,复制到本地之后load_model()报错: ValueError: Tensor conversion requested dtype int32 for Tensor w ...

  3. matplotlib.pyplot import报错: ValueError: _getfullpathname: embedded null character in path

    Environment: Windows 10, Anaconda 3.6 matplotlib 2.0 import matplotlib.pyplot 报错: ValueError: _getfu ...

  4. 安装 r 里的 igraph 报错

    转载来源:http://genek.tv/article/40 1186 0 0 安装 r 里的 igraph 报错: foreign-graphml.c: In function ‘igraph_w ...

  5. dbfread报错ValueError错误解决方法

    问题 我在用dbfread处理.dbf数据的时候出现了报错 ValueError("could not convert string to float: b'.'",) 然后查找. ...

  6. moviepy音视频剪辑VideoClip类fl_image方法image_func报错ValueError: assignment destination is read-only解决办法

    ☞ ░ 前往老猿Python博文目录 ░ moviepy音视频剪辑模块的视频剪辑基类VideoClip的fl_image方法用于进行对剪辑帧数据进行变换. 调用语法:fl_image(self, im ...

  7. Linux部署Django:报错 nohup: ignoring input and appending output to ‘nohup.out’

    一.部署 Django 到远程 Linux 服务器 利用 xshell 通过 ssh 连接到 Linux服务器,常规的启动命令是 python3 manage.py runserver 但是,关闭 x ...

  8. tensorflow-TFRecord报错ValueError: Protocol message Feature has no "feature" field.

    编写代码用TFRecord数据结构存储数据集信息是报错:ValueError: Protocol message Feature has no "feature" field.或和 ...

  9. datetime.strptime格式转换报错ValueError

    今天遇到一个报错:ValueError: time data '2018-10-10(Wednesday) AM0:50' does not match format '%Y-%m-%d(%A) %p ...

随机推荐

  1. Android笔记(二十八) Android中图片之简单图片使用

    用户界面很大程度上决定了APP是否被用户接收,为了提供友好的界面,就需要在应用中使用图片了,Android提供了丰富的图片处理功能. 简单使用图片 使用Drawable对象 为Android应用增加了 ...

  2. Vue框架之基础知识

    在没有学习基础知识之前,我们需要下载vue的js文件,在使用vue语法之前引包 <script src='./vue.js'></script> 一.模板语法 模板语法是一种可 ...

  3. oracle concepts学习

    祭图一张!!!

  4. java - day011 - 集合, ArrayList HashMap,HashSet, Iterator 接口, for-each 循环格式

    集合 ArrayList 丑数: 能被3,5,7整除多次, ArrayList     list 接口             | - ArrayList             | - Linked ...

  5. 关于MQ的几件小事(一)消息队列的用途、优缺点、技术选型

    1.为什么使用消息队列? (1)解耦:可以在多个系统之间进行解耦,将原本通过网络之间的调用的方式改为使用MQ进行消息的异步通讯,只要该操作不是需要同步的,就可以改为使用MQ进行不同系统之间的联系,这样 ...

  6. pandas里面过滤列出现ValueError: cannot index with vector containing NA / NaN values错误的解决方法(转)

    ###df_18的字段fuek是否包含 / df_18[df_18['fuel'].str.contains('/')] 报错: ValueError Traceback (most recent c ...

  7. archlinux 使用 postgresql

    一.安装与初始化 1.初始化数据目录 默认安装后已创建 postgres 系统用户 切换到 postgres 用户 $ sudo -iu postgres # Or su - postgres for ...

  8. 创建型模式(一) 单例模式(Singleton)

    一.动机(Motivation) 在软件系统中,经常有这样一些特殊的类,必须保证它们在系统中只存在一个实例,才能确保它们的逻辑正确性.以及良好的效率. 如何绕过常规的构造器,提供一种机制来保证一个类只 ...

  9. 7月新的开始 - Axure学习03 - 布尔运算、表单元件

    布尔运算 布尔运算:是一种数字符号化的逻辑推演法.包含联合.相交.相减等 在图形处理中,通过联合.相交.相减等操作使基本的图形组合产生新的形体 操作 准备 蓝色(底部).绿色(顶部) 合并:两个图形合 ...

  10. CentOS环境部署(Nginx+Mariadb+Java+Tomcat)

    1.安装nginx 安装 yum install nginx 启动 yum install nginx 开机自启 sudo systemctl enable nginx 2.安装mariadb 安装 ...