// CrossValidation.java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.sun.org.apache.bcel.internal.generic.ClassGen;
public class CrossValidation {

    /**
     * Returns the k-fold cross validation score of a classifier on training data.
     *
     * <p>The data is split, in its existing order, into {@code k} contiguous folds whose sizes
     * differ by at most one (exactly {@code n / k} each when {@code k} divides {@code n}). For
     * each fold, a fresh classifier is trained on the other {@code k - 1} folds and its accuracy
     * on the held-out fold is measured. The score is the mean of those {@code k} accuracies.
     *
     * @param clf not used for training or classification: a new {@code NaiveBayesClassifier} is
     *     created for every fold instead. Kept so existing callers still compile.
     * @param trainData the labelled instances to cross-validate on
     * @param k number of folds; must satisfy {@code 1 <= k <= trainData.size()}
     * @param v forwarded unchanged to {@code Classifier.train} (NOTE(review): presumably the
     *     vocabulary size; confirm against the Classifier interface)
     * @return the mean per-fold accuracy, a value in [0, 1]
     * @throws IllegalArgumentException if {@code k < 1} or {@code trainData} has fewer than
     *     {@code k} instances (previously this crashed with ArithmeticException or
     *     IndexOutOfBoundsException)
     */
    public static double kFoldScore(Classifier clf, List<Instance> trainData, int k, int v) {
        int n = trainData.size();
        if (k < 1 || n < k) {
            throw new IllegalArgumentException(
                    "need 1 <= k <= number of instances, got k=" + k + ", n=" + n);
        }
        List<List<Instance>> folds = splitIntoFolds(trainData, k);

        double accuracySum = 0;
        for (int f = 0; f < k; f++) {
            // A brand-new classifier per fold, so nothing learned on one training split
            // can carry over into the evaluation of another fold.
            clf = new NaiveBayesClassifier();
            clf.train(trainingSetExcluding(folds, f), v);
            accuracySum += accuracy(clf, folds.get(f));
        }
        return accuracySum / k;
    }

    /** Splits data, in order, into k contiguous folds whose sizes differ by at most one. */
    private static List<List<Instance>> splitIntoFolds(List<Instance> data, int k) {
        int n = data.size();
        List<List<Instance>> folds = new ArrayList<List<Instance>>(k);
        for (int f = 0; f < k; f++) {
            folds.add(new ArrayList<Instance>());
        }
        int i = 0;
        for (Instance inst : data) {
            // floor(i * k / n) always lies in [0, k). When k divides n it equals i / (n / k),
            // the original assignment. Otherwise it spreads the remainder across folds instead
            // of overflowing past the last fold. long arithmetic guards against i * k overflow.
            int fold = (int) ((long) i * k / n);
            folds.get(fold).add(inst);
            i++;
        }
        return folds;
    }

    /** Concatenates every fold except the one at index heldOut, preserving fold order. */
    private static List<Instance> trainingSetExcluding(List<List<Instance>> folds, int heldOut) {
        List<Instance> train = new ArrayList<Instance>();
        for (int j = 0; j < folds.size(); j++) {
            if (j != heldOut) {
                train.addAll(folds.get(j));
            }
        }
        return train;
    }

    /** Returns the fraction of instances in test (must be non-empty) that clf labels correctly. */
    private static double accuracy(Classifier clf, List<Instance> test) {
        int correct = 0;
        for (Instance inst : test) {
            ClassifyResult result = clf.classify(inst.words);
            // NOTE(review): == is only right if label is an enum or primitive; confirm it is
            // not a String or boxed type.
            if (inst.label == result.label) {
                correct++;
            }
        }
        return (double) correct / (double) test.size();
    }
}