diff --git a/moa/src/main/java/moa/classifiers/core/attributeclassobservers/FastNominalAttributeClassObserver.java b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/FastNominalAttributeClassObserver.java new file mode 100644 index 000000000..18d466152 --- /dev/null +++ b/moa/src/main/java/moa/classifiers/core/attributeclassobservers/FastNominalAttributeClassObserver.java @@ -0,0 +1,173 @@ +/* + * FastNominalAttributeClassObserver.java + * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand + * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ +package moa.classifiers.core.attributeclassobservers; + +import moa.classifiers.core.AttributeSplitSuggestion; +import moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest; +import moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest; +import moa.classifiers.core.splitcriteria.SplitCriterion; +import moa.core.DoubleVector; +import moa.core.ObjectRepository; +import moa.core.Utils; +import moa.options.AbstractOptionHandler; +import moa.tasks.TaskMonitor; + +import java.util.HashMap; +import java.util.Map; + + +/** + * Class for observing the class data distribution for a nominal attribute. + * This observer monitors the class distribution of a given attribute. + * Used in naive Bayes and decision trees to monitor data statistics on leaves. + * + * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) + * @author Eugene Kamenev (eugene.kamenev@gmail.com) + * @version $Revision: 7 $ + */ +public class FastNominalAttributeClassObserver extends AbstractOptionHandler implements DiscreteAttributeClassObserver { + + private static final long serialVersionUID = 1L; + + protected double totalWeightObserved = 0.0; + + protected double missingWeightObserved = 0.0; + + protected Map> attValDistPerClassCount = new HashMap<>(); + + protected Map classTotalCount = new HashMap<>(); + + protected Map maxAttrValue = new HashMap<>(); + + @Override + public void observeAttributeClass(double attVal, int classVal, double weight) { + if (Utils.isMissingValue(attVal)) { + this.missingWeightObserved += weight; + } else { + int attValInt = (int) attVal; + Map valDistCount = this.attValDistPerClassCount.computeIfAbsent(classVal, k -> new HashMap<>()); + // update distribution count + valDistCount.put(attValInt, valDistCount.getOrDefault(attValInt, 0.0) + weight); + Integer maxValue = this.maxAttrValue.get(classVal); + if (maxValue == null || attVal > maxValue) { + // update max attribute value + this.maxAttrValue.put(classVal, attValInt); + } + // update the total count for the class + this.classTotalCount.put(classVal, this.classTotalCount.getOrDefault(classVal, 0.0) + weight); + } + this.totalWeightObserved += weight; + } + + @Override + public double probabilityOfAttributeValueGivenClass(double attVal, int classVal) { + Map obs = this.attValDistPerClassCount.get(classVal); + Double sumCounts = this.classTotalCount.getOrDefault(classVal, 0.0); + Integer max = this.maxAttrValue.getOrDefault(classVal, (int) attVal) + 1; + return obs != null ? (obs.getOrDefault((int) attVal, 0.0) + 1.0) / (sumCounts + max) : 0.0; + } + + public double[][] getClassDistsResultingFromMultiwaySplit(int maxAttValsObserved) { + DoubleVector[] resultingDists = new DoubleVector[maxAttValsObserved]; + for (int i = 0; i < resultingDists.length; i++) { + resultingDists[i] = new DoubleVector(); + } + for (Map.Entry> entry : this.attValDistPerClassCount.entrySet()) { + int classVal = entry.getKey(); + Map attValDistCount = entry.getValue(); + for (int j = 0; j < maxAttValsObserved; j++) { + resultingDists[j].addToValue(classVal, attValDistCount.getOrDefault(j, 0.0)); + } + } + double[][] distributions = new double[maxAttValsObserved][]; + for (int i = 0; i < distributions.length; i++) { + distributions[i] = resultingDists[i].getArrayRef(); + } + return distributions; + } + + public double[][] getClassDistsResultingFromBinarySplit(int valIndex) { + DoubleVector equalsDist = new DoubleVector(); + DoubleVector notEqualDist = new DoubleVector(); + for (Map.Entry> entry : this.attValDistPerClassCount.entrySet()) { + int classVal = entry.getKey(); + Map attValDistCount = entry.getValue(); + double count = attValDistCount.getOrDefault(valIndex, 0.0); + equalsDist.addToValue(classVal, count); + notEqualDist.addToValue(classVal, this.classTotalCount.get(classVal) - count); + } + return new double[][]{equalsDist.getArrayRef(), + notEqualDist.getArrayRef()}; + } + + @Override + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist, + int attIndex, boolean binaryOnly) { + AttributeSplitSuggestion bestSuggestion = null; + int maxAttValsObserved = 0; + for (Integer max : this.maxAttrValue.values()) { + if (max > maxAttValsObserved) { + maxAttValsObserved = max + 1; + } + } + if (!binaryOnly) { + double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved); + double merit = criterion.getMeritOfSplit(preSplitDist, + postSplitDists); + bestSuggestion = new AttributeSplitSuggestion( + new NominalAttributeMultiwayTest(attIndex), postSplitDists, + merit); + } + for (int valIndex = 0; valIndex < maxAttValsObserved; valIndex++) { + double[][] postSplitDists = getClassDistsResultingFromBinarySplit(valIndex); + double merit = criterion.getMeritOfSplit(preSplitDist, + postSplitDists); + if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) { + bestSuggestion = new AttributeSplitSuggestion( + new NominalAttributeBinaryTest(attIndex, valIndex), + postSplitDists, merit); + } + } + return bestSuggestion; + } + + @Override + public void observeAttributeTarget(double v, double v1) { + throw new UnsupportedOperationException("Not supported yet."); + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + // TODO Auto-generated method stub + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + public double totalWeightOfClassObservations() { + return this.totalWeightObserved; + } + + public double weightOfObservedMissingValues() { + return this.missingWeightObserved; + } +}