Generating cross-validation folds (Java approach)

Reference:

http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29

This article describes how to generate train/test splits for cross-validation using the Weka API directly.

The following variables are given:

Instances data = ...;   // contains the full dataset we want to create train/test sets from
int seed = ...;         // the seed for randomizing the data
int folds = ...;        // the number of folds to generate, >=2

 Randomize the data

First, randomize your data:

Random rand = new Random(seed);              // create seeded number generator
Instances randData = new Instances(data);    // create copy of original data
randData.randomize(rand);                    // randomize data with number generator
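Seeding the generator is what makes a run repeatable: the same seed always yields the same order, and a different seed yields (almost surely) a different one. The sketch below shows this without Weka, on a plain list of row indices. `SeededShuffle` and `shuffledIndices` are hypothetical names, and `Collections.shuffle` is only a stand-in for `Instances.randomize`, so the concrete order is not guaranteed to match Weka's.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: seeded shuffling of row indices, standing in for
// "copy the data, then randData.randomize(new Random(seed))".
final class SeededShuffle {

    // Returns a shuffled copy of 0..n-1; nothing else is modified,
    // mirroring "shuffle a copy, keep the original data intact".
    static List<Integer> shuffledIndices(int n, long seed) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            rows.add(i);
        }
        Collections.shuffle(rows, new Random(seed)); // same seed -> same order
        return rows;
    }

    public static void main(String[] args) {
        System.out.println(shuffledIndices(10, 1));
        System.out.println(shuffledIndices(10, 1)); // identical to the line above
        System.out.println(shuffledIndices(10, 2)); // (almost surely) a different order
    }
}
```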

If your data has a nominal class and you want to perform stratified cross-validation:

randData.stratify(folds);
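Stratification tries to give every fold roughly the class proportions of the whole dataset. As far as I can tell from the Weka source, `stratify(folds)` only acts when the class attribute is nominal: it groups the instances by class and then reorders them with a stride of `folds`, so that the contiguous blocks later cut out by `trainCV`/`testCV` each receive a similar class mix. The plain-Java sketch below imitates that idea on an array of class labels; `StratifySketch` and `stratifiedOrder` are hypothetical names, and this is an illustration under that assumption, not Weka's code.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the stratification idea, working on class labels only.
final class StratifySketch {

    // Returns a new row order in which every contiguous block of about
    // (numRows / folds) rows has roughly the overall class proportions.
    static List<Integer> stratifiedOrder(String[] labels, int folds) {
        if (folds < 2) {
            throw new IllegalArgumentException("folds must be >= 2");
        }
        // 1. Group row indices by class, classes in order of first appearance.
        Map<String, List<Integer>> byClass = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) {
            byClass.computeIfAbsent(labels[i], k -> new ArrayList<>()).add(i);
        }
        List<Integer> grouped = new ArrayList<>();
        for (List<Integer> rows : byClass.values()) {
            grouped.addAll(rows);
        }
        // 2. Take every folds-th grouped row, starting at offset 0, then 1, ...
        List<Integer> order = new ArrayList<>(grouped.size());
        for (int start = 0; start < folds; start++) {
            for (int j = start; j < grouped.size(); j += folds) {
                order.add(grouped.get(j));
            }
        }
        return order;
    }

    public static void main(String[] args) {
        // 15 "white" rows followed by 5 "red" rows, split into 5 folds of 4 rows
        String[] labels = new String[20];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = i < 15 ? "white" : "red";
        }
        List<Integer> order = stratifiedOrder(labels, 5);
        for (int f = 0; f < 5; f++) {
            StringBuilder fold = new StringBuilder("fold " + f + ":");
            for (int k = 4 * f; k < 4 * f + 4; k++) {
                fold.append(' ').append(labels[order.get(k)]);
            }
            System.out.println(fold); // each fold: three "white", one "red"
        }
    }
}
```

Without the stratification step, the unshuffled toy data above would put all five "red" rows into the last fold.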

 Generate the folds

 Single run

Next, we have to create the train and test sets for each fold:

for (int n = 0; n < folds; n++) {
  Instances train = randData.trainCV(folds, n);
  Instances test = randData.testCV(folds, n);

  // further processing, classification, etc.
  ...
}

Note:

  • the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter
  • the weka.classifiers.Evaluation class and the Explorer/Experimenter
    use this method instead for obtaining the train set:

Instances train = randData.trainCV(folds, n, rand);
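Neither `trainCV` nor `testCV` shuffles anything by itself. As far as I can tell from the Weka source, the test set of fold `n` is one contiguous block of the (already randomized, optionally stratified) data, the first `N % folds` folds get one extra instance, and the training set is everything outside that block. The three-argument `trainCV(folds, n, rand)` appears to build the same training set and then shuffle it with `rand`. The plain-Java sketch below reproduces that arithmetic on row indices; `FoldSplit` and its methods are hypothetical names, and `Collections.shuffle` stands in for `Instances.randomize`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of the index arithmetic assumed behind
// Instances.testCV(folds, n), trainCV(folds, n) and trainCV(folds, n, rand).
final class FoldSplit {

    // Returns {first test row, number of test rows} for fold n (0-based).
    static int[] testRange(int numRows, int folds, int n) {
        if (folds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2");
        }
        if (folds > numRows) {
            throw new IllegalArgumentException("Can't have more folds than rows");
        }
        int base = numRows / folds;
        int extra = numRows % folds;            // the first `extra` folds get one more row
        int size = n < extra ? base + 1 : base;
        int first = n * base + Math.min(n, extra);
        return new int[] {first, size};
    }

    // Test rows of fold n: one contiguous block.
    static List<Integer> testRows(int numRows, int folds, int n) {
        int[] r = testRange(numRows, folds, n);
        List<Integer> rows = new ArrayList<>();
        for (int i = r[0]; i < r[0] + r[1]; i++) {
            rows.add(i);
        }
        return rows;
    }

    // Training rows of fold n: everything before and after the test block, in order.
    static List<Integer> trainRows(int numRows, int folds, int n) {
        int[] r = testRange(numRows, folds, n);
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < numRows; i++) {
            if (i < r[0] || i >= r[0] + r[1]) {
                rows.add(i);
            }
        }
        return rows;
    }

    // Assumed behaviour of trainCV(folds, n, rand): same rows, shuffled with rand.
    static List<Integer> shuffledTrainRows(int numRows, int folds, int n, Random rand) {
        List<Integer> rows = trainRows(numRows, folds, n);
        Collections.shuffle(rows, rand);
        return rows;
    }

    public static void main(String[] args) {
        for (int n = 0; n < 10; n++) {
            int[] r = testRange(6497, 10, n);
            System.out.println("fold " + n + ": test rows " + r[0] + ".." + (r[0] + r[1] - 1)
                    + " (" + r[1] + " rows)");
        }
    }
}
```

For the 6,497-instance wine data used in the experiment below and 10 folds, this arithmetic gives seven test folds of 650 instances and three of 649, so every instance lands in exactly one test set. The shuffled variant changes only the order of the training rows, which matters for order-sensitive learners and also advances `rand`.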

 Multiple runs

The example above only performs one run of a cross-validation. If you want to perform 10 runs of 10-fold cross-validation, use the following loop:

Instances data = ...;  // our dataset again, obtained from somewhere
int runs = 10;
for (int i = 0; i < runs; i++) {
  seed = i+1;  // every run gets a new, but defined seed value

  // see: randomize the data
  ...

  // see: generate the folds
  ...
}
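Put together, repeated cross-validation is two nested loops: the outer loop picks seed `i + 1`, copies the data and shuffles the copy; the inner loop walks over the folds. The Weka-free sketch below follows that structure on a list of class labels. It assumes a trivial majority-class predictor (the hypothetical helper `majority`) in place of a real classifier such as J48, purely so the example runs; `RepeatedCv` and `run` are likewise hypothetical names, and `Collections.shuffle` stands in for `Instances.randomize`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;

final class RepeatedCv {

    // Hypothetical stand-in "classifier": predicts the most frequent training label.
    static String majority(List<String> trainLabels) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        for (String y : trainLabels) {
            int c = counts.merge(y, 1, Integer::sum);
            if (best == null || c > counts.get(best)) {
                best = y;
            }
        }
        return best;
    }

    // Returns one pooled accuracy per run: correct predictions / all rows.
    static double[] run(List<String> labels, int runs, int folds) {
        int n = labels.size();
        int base = n / folds, extra = n % folds;
        double[] acc = new double[runs];
        for (int i = 0; i < runs; i++) {
            int seed = i + 1;                              // new, but defined seed per run
            List<String> data = new ArrayList<>(labels);   // copy of the original data
            Collections.shuffle(data, new Random(seed));   // randomize the copy
            int correct = 0;
            for (int f = 0; f < folds; f++) {
                int first = f * base + Math.min(f, extra); // contiguous test block
                int size = f < extra ? base + 1 : base;
                List<String> train = new ArrayList<>(data.subList(0, first));
                train.addAll(data.subList(first + size, n));
                String prediction = majority(train);       // "build" on the training rows
                for (String y : data.subList(first, first + size)) {
                    if (y.equals(prediction)) {
                        correct++;                         // "evaluate" on the test rows
                    }
                }
            }
            acc[i] = (double) correct / n;
        }
        return acc;
    }

    public static void main(String[] args) {
        List<String> labels = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            labels.add(i < 75 ? "white" : "red");
        }
        double[] acc = run(labels, 10, 10);
        double sum = 0;
        for (double a : acc) {
            sum += a;
        }
        System.out.printf(Locale.ROOT, "mean accuracy over %d runs: %.4f%n",
                acc.length, sum / acc.length); // prints 0.7500
    }
}
```

With 75 "white" and 25 "red" labels, every training set keeps a "white" majority, so each run scores exactly 0.75; with a real classifier the per-run accuracies would differ slightly from seed to seed, which is the point of averaging over runs.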

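A simple experiment:

Continuing from the previous section, we again classify red versus white wine. The classifier is unchanged; only the repeated-run procedure has been added.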

package assignment2;

import weka.core.Instances;
import weka.core.Utils;  // only needed by the commented-out option parsing in main()
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import java.io.FileReader;
import java.util.Random;

public class cv_rw {

    // Loads an ARFF file and drops its last attribute
    // (Remove -R <index>; Weka attribute ranges are 1-based).
    public static Instances getFileInstances(String filename) throws Exception {
        Instances data;
        try (FileReader frData = new FileReader(filename)) {
            data = new Instances(frData);
        }
        int length = data.numAttributes();
        String[] options = new String[2];
        options[0] = "-R";
        options[1] = Integer.toString(length);
        Remove remove = new Remove();
        remove.setOptions(options);
        remove.setInputFormat(data);
        Instances newData = Filter.useFilter(data, remove);
        return newData;
    }

    public static void main(String[] args) throws Exception {
        // load data and use the last remaining attribute as the class
        Instances data = getFileInstances("D://Weka_tutorial//WineQuality//RedWhiteWine.arff");
        // System.out.println(data);
        data.setClassIndex(data.numAttributes() - 1);

        // classifier (alternative: take it from the command line via -W)
        // String[] tmpOptions;
        // String classname;
        // tmpOptions     = Utils.splitOptions(Utils.getOption("W", args));
        // classname      = tmpOptions[0];
        // tmpOptions[0]  = "";
        // Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);
        //
        // // other options
        // int runs  = Integer.parseInt(Utils.getOption("r", args)); // number of repeated runs
        // int folds = Integer.parseInt(Utils.getOption("x", args)); // number of folds
        int runs = 1;
        int folds = 10;
        J48 j48 = new J48();
        // j48.buildClassifier(data);

        // perform cross-validation
        for (int i = 0; i < runs; i++) {
            // randomize data
            int seed = i + 1;
            Random rand = new Random(seed);
            Instances randData = new Instances(data);
            randData.randomize(rand);

            // Stratify only when the class is nominal: stratification keeps the class
            // proportions of every fold close to those of the whole dataset, which is
            // undefined for a numeric class.
            // (Original author's note: "I don't understand what this part means;
            // I'd be very grateful for replies from experts.")
            // if (randData.classAttribute().isNominal())
            //     randData.stratify(folds);

            Evaluation eval = new Evaluation(randData);
            for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);
                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate a fresh copy of the classifier on each fold
                // (Weka 3.6 API; from Weka 3.7 on, use AbstractClassifier.makeCopy(j48))
                Classifier j48Copy = Classifier.makeCopy(j48);
                j48Copy.buildClassifier(train);
                eval.evaluateModel(j48Copy, test);
            }

            // output evaluation
            System.out.println();
            System.out.println("=== Setup run " + (i + 1) + " ===");
            System.out.println("Classifier: " + j48.getClass().getName());
            System.out.println("Dataset: " + data.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + seed);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + "===", false));
        }
    }
}

Running the program gives the following results:

=== Setup run 1 ===
Classifier: weka.classifiers.trees.J48
Dataset: RedWhiteWine-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.attribute.Remove-R13
Folds: 10
Seed: 1

=== 10-fold Cross-validation run 1===

Correctly Classified Instances        6415               98.7379 %
Incorrectly Classified Instances        82                1.2621 %
Kappa statistic                          0.9658
Mean absolute error                      0.0159
Root mean squared error                  0.1109
Relative absolute error                  4.2898 %
Root relative squared error             25.7448 %
Total Number of Instances             6497
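Because one `Evaluation` object collects the predictions of all ten folds, this summary is pooled over the whole dataset rather than averaged fold by fold: 6,415 + 82 = 6,497 test predictions, i.e. every instance was tested exactly once. The percentages follow directly from those counts, which the small check below recomputes (the counts are copied from the output above; `PooledAccuracy` is a hypothetical name).

```java
import java.util.Locale;

// Recomputes the pooled percentages of the summary from its raw counts.
final class PooledAccuracy {

    static double percent(int part, int total) {
        return 100.0 * part / total;
    }

    public static void main(String[] args) {
        int correct = 6415;
        int incorrect = 82;
        int total = correct + incorrect;  // each instance appears in exactly one test fold
        System.out.println("total = " + total);                                           // prints total = 6497
        System.out.printf(Locale.ROOT, "correct   = %.4f %%%n", percent(correct, total));   // prints 98.7379 %
        System.out.printf(Locale.ROOT, "incorrect = %.4f %%%n", percent(incorrect, total)); // prints 1.2621 %
    }
}
```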
