sklearn的class_weight设置为'balanced'的计算方法
分类的时候,当不同类别的样本量差异很大时,很容易影响分类结果,因此要么每个类别的数据量大致相同,要么就要进行校正。
sklearn的做法可以是加权,加权就要涉及到class_weight和sample_weight,当不设置class_weight参数时,默认值是所有类别的权值为1。
在python中:
# class_weight的传参
class_weight : {dict, 'balanced'}, optional
Set the parameter C of class i to class_weight[i]*C for
SVC. If not given, all classes are supposed to have
weight one. The "balanced" mode uses the values of y to automatically
adjust weights inversely proportional to class frequencies as
``n_samples / (n_classes * np.bincount(y))``
# 当使用字典时,其形式为:Weights associated with classes in the form ``{class_label: weight}``,比如:{0: 1, 1: 1}表示类0的权值为1,类1的权值为1. # sample_weight的传参
sample_weight : array-like, shape (n_samples,)
Per-sample weights. Rescale C per sample. Higher weights
force the classifier to put more emphasis on these points.
1. 在:from sklearn.utils.class_weight import compute_class_weight 里面可以看到计算的源代码。
2. 除了通过字典形式传入权重参数,还可以设置的是:class_weight = 'balanced',例如使用SVM分类:
clf = SVC(kernel = 'linear', class_weight='balanced', decision_function_shape='ovr')
clf.fit(X_train, y_train)
3. 那么'balanced'的计算方法是什么呢?看例子:
import numpy as np y = [0,0,0,0,0,0,0,0,1,1,1,1,1,1,2,2] #标签值,一共16个样本 a = np.bincount(y) # array([8, 6, 2], dtype=int64) 计算每个类别的样本数量
aa = 1/a #倒数 array([0.125 , 0.16666667, 0.5 ])
print(aa) from sklearn.utils.class_weight import compute_class_weight
class_weight = 'balanced'
classes = np.array([0, 1, 2]) #标签类别
weight = compute_class_weight(class_weight, classes, y)
print(weight) # [0.66666667 0.88888889 2.66666667] print(0.66666667*8) #5.33333336
print(0.88888889*6) #5.33333334
print(2.66666667*2) #5.33333334
# 这三个值非常接近
# 'balanced'计算出来的结果很均衡,使得惩罚项和样本量对应
可以看出计算出来的值,乘以样本量之后,三个类别的数字很接近,我想的是:个人觉得惩罚项就用样本量的倒数未尝不可,因为乘以样本量都是1,相当于'balanced'这里是多乘以了一个常数
4. 真正的魔法到了:还记得上面所给出的python中,当class_weight为'balanced'时的计算公式吗?
# weight_ = n_samples / (n_classes * np.bincount(y))``
# 这里
# n_samples为16
# n_classes为3
# np.bincount(y)实际上就是每个类别的样本数量
于是:
print(16/(3*8)) #输出 0.6666666666666666
print(16/(3*6)) #输出 0.8888888888888888
print(16/(3*2)) #输出 2.6666666666666665
是不是跟计算出来的权值一样?这就是class_weight设置为'balanced'时的计算方法了。
5. 当然,需要说明一下传入字典时的情形
import numpy as np y = [0,0,0,0,0,0,0,0,1,1,1,1,1,1,2,2] #标签值,一共16个样本 from sklearn.utils.class_weight import compute_class_weight
class_weight = {0:1,1:3,2:5} # {class_label_1:weight_1, class_label_2:weight_2, class_label_3:weight_3}
classes = np.array([0, 1, 2]) #标签类别
weight = compute_class_weight(class_weight, classes, y)
print(weight) # 输出:[1. 3. 5.],也就是字典中设置的值
参考:
https://blog.csdn.net/go_og/article/details/81281387
https://www.zhihu.com/question/265420166/answer/293896934
sklearn的class_weight设置为'balanced'的计算方法的更多相关文章
- sklearn逻辑回归(Logistic Regression)类库总结
class sklearn.linear_model.LogisticRegression(penalty=’l2’, dual=False, tol=0.0001, C=1.0, fit_inter ...
- sklearn逻辑回归(Logistic Regression,LR)调参指南
python信用评分卡建模(附代码,博主录制) https://study.163.com/course/introduction.htm?courseId=1005214003&utm_ca ...
- 逻辑回归原理_挑战者飞船事故和乳腺癌案例_Python和R_信用评分卡(AAA推荐)
sklearn实战-乳腺癌细胞数据挖掘(博客主亲自录制视频教程) https://study.163.com/course/introduction.htm?courseId=1005269003&a ...
- XGBoost、LightGBM、Catboost总结
sklearn集成方法 bagging 常见变体(按照样本采样方式的不同划分) Pasting:直接从样本集里随机抽取的到训练样本子集 Bagging:自助采样(有放回的抽样)得到训练子集 Rando ...
- XGBoost、LightGBM的详细对比介绍
sklearn集成方法 集成方法的目的是结合一些基于某些算法训练得到的基学习器来改进其泛化能力和鲁棒性(相对单个的基学习器而言)主流的两种做法分别是: bagging 基本思想 独立的训练一些基学习器 ...
- CART决策树和随机森林
CART 分裂规则 将现有节点的数据分裂成两个子集,计算每个子集的gini index 子集的Gini index: \(gini_{child}=\sum_{i=1}^K p_{ti} \sum_{ ...
- Python解决数据样本类别分布不均衡问题
所谓不平衡指的是:不同类别的样本数量差异非常大. 数据规模上可以分为大数据分布不均衡和小数据分布不均衡.大数据分布不均衡:例如拥有1000万条记录的数据集中,其中占比50万条的少数分类样本便于属于这种 ...
- 【机器学习基础】逻辑回归——LogisticRegression
LR算法作为一种比较经典的分类算法,在实际应用和面试中经常受到青睐,虽然在理论方面不是特别复杂,但LR所牵涉的知识点还是比较多的,同时与概率生成模型.神经网络都有着一定的联系,本节就针对这一算法及其所 ...
- (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探
目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...
随机推荐
- Egret自定义位图文字(自定义+BitmapLabel)
一 自定位图文字 因为egret的位图文字是texturemerger做的,需要多张单图片导入tm,然后导出两个文件来使用,过程比较麻烦. 而Laya的位图文字则是一张整图数字图片,使用FontCli ...
- jstree:重新加载数据集,刷新树
true:表示获得一个已经存在的jstree实例 $('#tree').jstree(true).destroy();// 清除树节点 // 重新设置树的JSON数据集 $('#tree').jstr ...
- ejs不能读取js变量??????
一.问题描述 用express搭了一个nodejs服务端,为了测试接口数据是否能够正常输出,用ejs作为模版引擎的html文件写js发请求. 1.请求正常,能在network看到,但是没有输出cons ...
- Android EditText禁止回车换行
在做一个登录页面的时候,发现了输入手机号的EditText可以输入回车的bug,影响用户体验,在此分享下解决办法. 百度了很多,都是设置singline=true的或者设置maxLines=" ...
- SQLite数据库简介和使用
一.Sqlite简介: SQLite (http://www.sqlite.org/),是一款轻型的数据库,是遵守ACID的关联式数据库管理系统,它的设计目标是嵌入式的,而且目前已经在很多嵌入式产品中 ...
- 防范sql注入值得注意地方
sql注入是大家基本都清楚,一般来说用参数化就能解决注入的问题,也是最好的解决方式. 有次技术群里问到一个问题,如下图 很显然tableName是外部传递过来的,暂时不考虑具体的业务环境,但如果以se ...
- [转帖]Linux中buff/cache内存占用过高解决办法
Linux中buff/cache内存占用过高解决办法 https://www.cnblogs.com/rocky-AGE-24/p/7629500.html /proc/sys/vm/drop_cac ...
- [转帖]/proc/sys目录下各文件参数说明
/proc/sys目录下各文件参数说明 https://blog.csdn.net/luteresa/article/details/68061881 一.前言 本文档针对OOP8生产环境,具体优 ...
- C++实现base64编解码
Base64是常见的加密算法,代码实现了基于C++的对于base64的编码和解码. 其中注释掉的部分为编码部分,取消注释将解码部分注释掉即可实现编码,反之可以实现解码. #include <st ...
- Linux 生成随机mac地址,并固化到本地
前言: 将Mac地址随机化并固化到本地可以有效避免同一个网络内,mac地址冲突导致的网络阻塞问题. 以下是有关的方法: 1.使用$RANDOM和md5sum(嵌入式无需移植其他软件的优秀可选方案) M ...