在机器学习多分类任务中有时候需要针对类别进行分层采样,比如说类别不均衡的数据,这时候随机采样会造成训练集、验证集、测试集中不同类别的数据比例不一样,这是会在一定程度上影响分类器的性能的,这时候就需要进行分层采样保证训练集、验证集、测试集中每一个类别的数据比例差不多持平。

下面python代码。

# 将数据按照类别进行分层划分
def save_file_stratified(filename, ssdfile_dir, categories):
"""
将文件分流到3个文件中
filename: 原数据地址,一个csv文件
文件内容格式: 类别\t内容
"""
f_train = open('../data/usefuldata-711depart/train.txt', 'w', encoding='utf-8')
f_val = open('../data/usefuldata-711depart/val.txt', 'w', encoding='utf-8')
f_test = open('../data/usefuldata-711depart/test.txt', 'w', encoding='utf-8')
# f_class = open('../data/usefuldata-37depart/class.txt', 'w', encoding='utf-8')
dict_ssdqw = {}
for ssdfile in os.listdir(ssdfile_dir):
ssdfile_name = os.path.join(ssdfile_dir, ssdfile)
f = open(ssdfile_name, 'r', encoding='utf-8')
content_qw = ''
content = f.readline()
# 以下部分,因为统计整个案件基本情况他有换行,所以将多行处理在一行里面
while content:
content_qw += content
content_qw = content_qw.replace('\n', '')
content = f.readline()
ssdfile_key = str(ssdfile).replace('.txt','')
dict_ssdqw[ssdfile_key] = content_qw
# doc_count代表每一类数据总共有多少个
doc_count_0 = 0
doc_count_1 = 0
doc_count_2 = 0
doc_count_3 = 0
doc_count_4 = 0
doc_count_5 = 0
doc_count_6 = 0
doc_count_7 = 0
doc_count_8 = 0
doc_count_9 = 0
doc_count_10 = 0
doc_count_11 = 0
doc_count_12 = 0
temp_file = open(filename, 'r', encoding='utf-8')
line = temp_file.readline()
while line:
line_content = line.split(',')
name = line_content[0]
if name in dict_ssdqw:
label = line_content[1]
if label == categories[0]:
doc_count_0 += 1
elif label == categories[1]:
doc_count_1 += 1
elif label == categories[2]:
doc_count_2 += 1
elif label == categories[3]:
doc_count_3 += 1
elif label == categories[4]:
doc_count_4 += 1
elif label == categories[5]:
doc_count_5 += 1
elif label == categories[6]:
doc_count_6 += 1
elif label == categories[7]:
doc_count_7 += 1
elif label == categories[8]:
doc_count_8 += 1
elif label == categories[9]:
doc_count_9 += 1
elif label == categories[10]:
doc_count_10 += 1
elif label == categories[11]:
doc_count_11 += 1
elif label == categories[12]:
doc_count_12 += 1
line = temp_file.readline()
temp_file.close()
# 总数量
doc_count = doc_count_0 + doc_count_1 + doc_count_2 + doc_count_3 +\
doc_count_4 + doc_count_5 + doc_count_6 + doc_count_7 +\
doc_count_8 + doc_count_9 + doc_count_10 + doc_count_11 + doc_count_12
class_set = set()
tag_train_0 = doc_count_0 * 70 / 100
tag_train_1 = doc_count_1 * 70 / 100
tag_train_2 = doc_count_2 * 70 / 100
tag_train_3 = doc_count_3 * 70 / 100
tag_train_4 = doc_count_4 * 70 / 100
tag_train_5 = doc_count_5 * 70 / 100
tag_train_6 = doc_count_6 * 70 / 100
tag_train_7 = doc_count_7 * 70 / 100
tag_train_8 = doc_count_8 * 70 / 100
tag_train_9 = doc_count_9 * 70 / 100
tag_train_10 = doc_count_10 * 70 / 100
tag_train_11= doc_count_11 * 70 / 100
tag_train_12 = doc_count_12 * 70 / 100
tag_val_0 = doc_count_0 * 85 / 100
tag_val_1 = doc_count_1 * 85 / 100
tag_val_2 = doc_count_2 * 85 / 100
tag_val_3 = doc_count_3 * 85 / 100
tag_val_4 = doc_count_4 * 85 / 100
tag_val_5 = doc_count_5 * 85 / 100
tag_val_6 = doc_count_6 * 85 / 100
tag_val_7 = doc_count_7 * 85 / 100
tag_val_8 = doc_count_8 * 85 / 100
tag_val_9 = doc_count_9 * 85 / 100
tag_val_10 = doc_count_10 * 85 / 100
tag_val_11 = doc_count_11 * 85 / 100
tag_val_12 = doc_count_12 * 85 / 100
# tag_test = doc_count * 70 / 100
tag_0 = 0
tag_1 = 0
tag_2 = 0
tag_3 = 0
tag_4 = 0
tag_5 = 0
tag_6 = 0
tag_7 = 0
tag_8 = 0
tag_9 = 0
tag_10 = 0
tag_11 = 0
tag_12 = 0
# 有些文书行业标记是空!!我想看看有多少条?
blank_tag = 0
# 标记一下,每个类别有多少个训练集、验证集、测试集?
train_class_tag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
val_class_tag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
test_class_tag = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# csvfile = open(filename, 'r', encoding='utf-8')
txtfile = open(filename, 'r', encoding='utf-8')
process_line = txtfile.readline()
while process_line:
line_content = process_line.split(',')
name = line_content[0]
if name in dict_ssdqw:
content = dict_ssdqw[name]
label = line_content[1]
# if label != '' and label != '其他行业':
if label != '':
class_set.add(label)
# 对每一类进行分层采样
if label == categories[0]:
tag_0 += 1
if tag_0 < tag_train_0:
f_train.write(label + '\t' + content + '\n')
train_class_tag[0] += 1
elif tag_0 < tag_val_0:
f_val.write(label + '\t' + content + '\n')
val_class_tag[0] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[0] += 1
elif label == categories[1]:
tag_1 += 1
if tag_1 < tag_train_1:
f_train.write(label + '\t' + content + '\n')
train_class_tag[1] += 1
elif tag_1 < tag_val_1:
f_val.write(label + '\t' + content + '\n')
val_class_tag[1] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[1] += 1
elif label == categories[2]:
tag_2 += 1
if tag_2 < tag_train_2:
f_train.write(label + '\t' + content + '\n')
train_class_tag[2] += 1
elif tag_2 < tag_val_2:
f_val.write(label + '\t' + content + '\n')
val_class_tag[2] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[2] += 1
elif label == categories[3]:
tag_3 += 1
if tag_3 < tag_train_3:
f_train.write(label + '\t' + content + '\n')
train_class_tag[3] += 1
elif tag_3 < tag_val_3:
f_val.write(label + '\t' + content + '\n')
val_class_tag[3] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[3] += 1
elif label == categories[4]:
tag_4 += 1
if tag_4 < tag_train_4:
f_train.write(label + '\t' + content + '\n')
train_class_tag[4] += 1
elif tag_4 < tag_val_4:
f_val.write(label + '\t' + content + '\n')
val_class_tag[4] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[4] += 1
elif label == categories[5]:
tag_5 += 1
if tag_5 < tag_train_5:
f_train.write(label + '\t' + content + '\n')
train_class_tag[5] += 1
elif tag_5 < tag_val_5:
f_val.write(label + '\t' + content + '\n')
val_class_tag[5] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[5] += 1
elif label == categories[6]:
tag_6 += 1
if tag_6 < tag_train_6:
f_train.write(label + '\t' + content + '\n')
train_class_tag[6] += 1
elif tag_6 < tag_val_6:
f_val.write(label + '\t' + content + '\n')
val_class_tag[6] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[6] += 1
elif label == categories[7]:
tag_7 += 1
if tag_7 < tag_train_7:
f_train.write(label + '\t' + content + '\n')
train_class_tag[7] += 1
elif tag_7 < tag_val_7:
f_val.write(label + '\t' + content + '\n')
val_class_tag[7] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[7] += 1
elif label == categories[8]:
tag_8 += 1
if tag_8 < tag_train_8:
f_train.write(label + '\t' + content + '\n')
train_class_tag[8] += 1
elif tag_8 < tag_val_8:
f_val.write(label + '\t' + content + '\n')
val_class_tag[8] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[8] += 1
elif label == categories[9]:
tag_9 += 1
if tag_9 < tag_train_9:
f_train.write(label + '\t' + content + '\n')
train_class_tag[9] += 1
elif tag_9 < tag_val_9:
f_val.write(label + '\t' + content + '\n')
val_class_tag[9] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[9] += 1
elif label == categories[10]:
tag_10 += 1
if tag_10 < tag_train_10:
f_train.write(label + '\t' + content + '\n')
train_class_tag[10] += 1
elif tag_10 < tag_val_10:
f_val.write(label + '\t' + content + '\n')
val_class_tag[10] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[10] += 1
elif label == categories[11]:
tag_11 += 1
if tag_11 < tag_train_11:
f_train.write(label + '\t' + content + '\n')
train_class_tag[11] += 1
elif tag_11 < tag_val_11:
f_val.write(label + '\t' + content + '\n')
val_class_tag[11] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[11] += 1
elif label == categories[12]:
tag_12 += 1
if tag_12 < tag_train_12:
f_train.write(label + '\t' + content + '\n')
train_class_tag[12] += 1
elif tag_12 < tag_val_12:
f_val.write(label + '\t' + content + '\n')
val_class_tag[12] += 1
else:
f_test.write(label + '\t' + content + '\n')
test_class_tag[12] += 1
else:
blank_tag += 1
process_line = txtfile.readline()
txtfile.close()
print("有" + str(blank_tag) + "个文书的行业标记为空!")
print("train:")
print(train_class_tag)
train_tag_total =0
for i_total in train_class_tag:
train_tag_total += i_total
train_class_tag_distribute = []
for i in train_class_tag:
train_class_tag_distribute.append((i / train_tag_total) * 100)
print("分布:")
print(train_class_tag_distribute)
print("val:")
print(val_class_tag)
val_tag_total = 0
for i_total in val_class_tag:
val_tag_total += i_total
val_class_tag_distribute = []
for i in val_class_tag:
val_class_tag_distribute.append((i / val_tag_total) * 100)
print("分布:")
print(val_class_tag_distribute)
print("test:")
print(test_class_tag)
test_tag_total = 0
for i_total in test_class_tag:
test_tag_total += i_total
test_class_tag_distribute = []
for i in test_class_tag:
test_class_tag_distribute.append((i / test_tag_total) * 100)
print("分布:")
print(test_class_tag_distribute)
f_train.close()
f_test.close()
f_val.close()
if __name__ == '__main__':
categories = [
"class1",
"class2",
"class3",
"class4",
"class5",
"class6",
"class7",
"class8",
"class9",
"class10",
"class11",
"class12",
"class13"
]
save_file_stratified('../data/qwdata/shuffle-try3/classified_table_ms.txt', '../data/qwdata/ms-ygscplusssdqw',categories)

后面可以看到类别划分


这里要注意的一点是:这是我早期写的文章,需要注意的一点是,我们通常在训练集和验证集上做分层采样即可,测试集最好保持原样不要动。

python 多分类任务中按照类别分层采样的更多相关文章

  1. map集合中取出分类优先级最高的类别名称

    import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map ...

  2. 如何使用Python在Kaggle竞赛中成为Top15

    如何使用Python在Kaggle竞赛中成为Top15 Kaggle比赛是一个学习数据科学和投资时间的非常的方式,我自己通过Kaggle学习到了很多数据科学的概念和思想,在我学习编程之后的几个月就开始 ...

  3. Objective-C /iphone开发基础:分类(category,又称类别)

    在c++中我们可以多继承来实现代码复用和封装使程序更加简练.在objective-c中只能单继承,不能多继承,那么除了协议protocol之外,我们可以实现类似多继承的一个方法就是,分类(catego ...

  4. [置顶] Objective-C,/,ios,/iphone开发基础:分类(category,又称类别)

    在c++中我们可以多继承来实现代码复用和封装使程序更加简练.在objective-c中只能单继承,不能多继承,那么除了协议protocol之外,我们可以实现类似多继承的一个方法就是,分类(catego ...

  5. python入门-分类和回归各种初级算法

    引自:http://www.cnblogs.com/taichu/p/5251332.html ########################### #说明: # 撰写本文的原因是,笔者在研究博文“ ...

  6. 13、Selenium+python+API分类总结

    Selenium+python+API分类总结 http://selenium-python.readthedocs.org/index.html 分类 方法 方法描述 客户端操作 __init__( ...

  7. [ios]objective-c中Category类别(扩展类)专题总结

    本文转载至 http://yul100887.blog.163.com/blog/static/20033613520126333344127/   objective-c类别的作用?通过类别的方式, ...

  8. [技术博客]Pyqt中View类别容器和Widget类别容器的区别

    Pyqt中View类别容器和Widget类别容器的区别 简介 在beta迭代中,我们选择用pyqt5来重写alpha迭代中使用tkinter库编写的界面. ​ 按钮之类的与tkiner使用无异,在显示 ...

  9. 【分类问题中模型的性能度量(二)】超强整理,超详细解析,一文彻底搞懂ROC、AUC

    文章目录 1.背景 2.ROC曲线 2.1 ROC名称溯源(选看) 2.2 ROC曲线的绘制 3.AUC(Area Under ROC Curve) 3.1 AUC来历 3.2 AUC几何意义 3.3 ...

随机推荐

  1. 【Leetcode_easy】892. Surface Area of 3D Shapes

    problem 892. Surface Area of 3D Shapes 题意:感觉不清楚立方体是如何堆积的,所以也不清楚立方体之间是如何combine的.. Essentially, compu ...

  2. Cas(04)——更改认证方式

    在Cas Server的WEB-INF目录下有一个deployerConfigContext.xml文件,该文件是基于Spring的配置文件,里面存放的内容常常是部署人员需要修改的内容.其中认证方式也 ...

  3. Python标准库: functools (cmp_to_key, lru_cache, total_ordering, partial, partialmethod, reduce, singledispatch, update_wrapper, wraps)

    functools模块处理的对象都是其他的函数,任何可调用对象都可以被视为用于此模块的函数. 1. functools.cmp_to_key(func) 因为Python3不支持比较函数,cmp_to ...

  4. ZooKeeper 相关问题

    [为什么部署个数是奇数个?] zookeeper有这样一个特性:集群中只要有过半的机器是正常工作的,那么整个集群对外就是可用的.即 2n 个机器的集群,最多可以容忍 n-1 个机器不可用,这个容忍度与 ...

  5. hdu 2476 题解

    题目 题意 给出两个字符串 $ s1,s2 $,每次操作可以使一段连续的子串全变成一个字母,问最少多少次操作可以使 $ s1 $ 变为 $ s2 $. 例如 $ zzzzzfzzzzz $,长度为 $ ...

  6. 转换器1:ThinkPhp模板转Flask模板

    Template Converter 网上的PHP资源很多,项目要用Python,所以想起做一个模板转换器,从ThinkPhp转成Flask的Jinja模板. 直接指定两个目录,将目录下的模板文件转换 ...

  7. 【自学系列一】HTML5大前端学习路线+视频教程(完整版)

    今年,本公司全新发布了囊括Java.HTML5前端.大数据.Python爬虫.全链UI设计.软件测试.Unity 3D.Go语言等多个技术方向的全套视频. 面对这么多的知识点,有的盆友就麻爪了…… 我 ...

  8. 第二次用map23333

    度熊所居住的 D 国,是一个完全尊重人权的国度.以至于这个国家的所有人命名自己的名字都非常奇怪.一个人的名字由若干个字符组成,同样的,这些字符的全排列的结果中的每一个字符串,也都是这个人的名字.例如, ...

  9. 在论坛中出现的比较难的sql问题:39(动态行转列 动态日期列问题)

    原文:在论坛中出现的比较难的sql问题:39(动态行转列 动态日期列问题) 最近,在论坛中,遇到了不少比较难的sql问题,虽然自己都能解决,但发现过几天后,就记不起来了,也忘记解决的方法了. 所以,觉 ...

  10. 在论坛中出现的比较难的sql问题:21(递归问题 检索某个节点下所有叶子节点)

    原文:在论坛中出现的比较难的sql问题:21(递归问题 检索某个节点下所有叶子节点) 最近,在论坛中,遇到了不少比较难的sql问题,虽然自己都能解决,但发现过几天后,就记不起来了,也忘记解决的方法了. ...