diff --git a/.gitignore b/.gitignore
index 700ce67..768bf87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ project/plugins/project/
 #Eclipse specific
 .classpath
 .project
+
+#IDEA specific
+.idea
+.idea_modules
diff --git a/src/main/scala/ml/tree/DecisionTree.scala b/src/main/scala/ml/tree/DecisionTree.scala
new file mode 100644
index 0000000..7a6f1ce
--- /dev/null
+++ b/src/main/scala/ml/tree/DecisionTree.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ml.tree
+
+import javax.naming.OperationNotSupportedException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.classification.ClassificationModel
+import org.apache.spark.util.StatCounter
+import org.apache.spark.Logging
+import org.apache.spark.broadcast.Broadcast
+import ml.tree.impurity.{Variance, Entropy, Gini, Impurity}
+import ml.tree.strategy.Strategy
+import ml.tree.split.{SplitPredicate, Split}
+import ml.tree.node._
+import ml.tree.Metrics._
+
+/*
+ * Class for building the Decision Tree model. Should be used for both classification and regression trees.
+ */
+class DecisionTree(
+  val input: RDD[(Double, Array[Double])], // input RDD
+  val maxDepth: Int, // maximum depth of the tree
+  val numSplitPredicates: Int, // number of bins per feature
+  val fraction: Double, // fraction of the data to be used for the quantile calculation
+  val strategy: Strategy, // classification or regression
+  val impurity: Impurity, // impurity calculation strategy (variance, gini, entropy, etc.)
+  val sparkContext: SparkContext) {
+
+  //Calculate the number of features
+  val featureLength = input.first._2.length
+  println("feature length = " + featureLength)
+
+  //Sample a fraction of the input RDD
+  val sampledData = input.sample(false, fraction, 42).cache()
+  println("sampled data size for quantile calculation = " + sampledData.count)
+
+  //Sort the sampled data along each feature and store it for the quantile calculation
+  println("started sorting sampled data")
+  val sortedSampledFeatures = {
+    val sortedFeatureArray = new Array[Array[Double]](featureLength)
+    0 until featureLength foreach {
+      i => sortedFeatureArray(i) = sampledData.map(x => x._2(i) -> None).sortByKey(true).map(_._1).collect()
+    }
+    sortedFeatureArray
+  }
+  println("finished sorting sampled data")
+
+  val numSamples = sampledData.count
+  println("num samples = " + numSamples)
+
+  // Calculate the index stride used to find the quantile points
+  val stride = scala.math.max(numSamples / numSplitPredicates, 1)
+  println("stride = " + stride)
+
+  //Calculate all possible splits for the features
+  println("calculating all possible splits for features")
+  val allSplitsList = for {
+    featureIndex <- 0 until featureLength;
+    index <- stride until numSamples - 1 by stride
+  } yield createSplit(featureIndex, index)
+  println("finished calculating all possible splits for features")
+
+  //Remove duplicate splits. Especially helpful for one-hot encoded categorical variables.
+  val allSplits = sparkContext.broadcast(allSplitsList.toSet)
+
+  //for (split <- allSplits) yield println(split)
+
+  /*
+   * Find the exact value using the feature index and the index into the sorted features
+   */
+  def valueAtRDDIndex(featuresIndex: Long, index: Long): Double = {
+    sortedSampledFeatures(featuresIndex.toInt)(index.toInt)
+  }
+
+  /*
+   * Create a split using the feature index and the index into the sorted features
+   */
+  def createSplit(featureIndex: Int, index: Long): Split = {
+    new Split(featureIndex, valueAtRDDIndex(featureIndex, index))
+  }
+
+  def buildTree(): Node = {
+
+    println("building decision tree")
+
+    strategy match {
+      case Strategy("Classification") => new TopClassificationNode(input, allSplits, impurity, strategy, maxDepth)
+      case Strategy("Regression") => {
+        val count = input.count
+        //TODO: calculate mean and variance together
+        val variance = input.map(x => x._1).variance
+        val mean = input.map(x => x._1).mean
+        val nodeStats = new NodeStats(count = Some(count), variance = Some(variance), mean = Some(mean))
+        new TopRegressionNode(input, nodeStats, allSplits, impurity, strategy, maxDepth)
+      }
+    }
+  }
+
+}
+
+object DecisionTree {
+  def train(
+    input: RDD[(Double, Array[Double])],
+    numSplitPredicates: Int,
+    strategy: Strategy,
+    impurity: Impurity,
+    maxDepth: Int,
+    fraction: Double,
+    sparkContext: SparkContext): Option[NodeModel] = {
+    val tree = new DecisionTree(
+      input = input,
+      numSplitPredicates = numSplitPredicates,
+      strategy = strategy,
+      impurity = impurity,
+      maxDepth = maxDepth,
+      fraction = fraction,
+      sparkContext = sparkContext)
+      .buildTree
+      .extractModel
+
+    println("calculating performance on training data")
+    val trainingError = {
+      strategy match {
+        case Strategy("Classification") => accuracyScore(tree, input)
+        case Strategy("Regression") => meanSquaredError(tree, input)
+      }
+    }
+    println("training error = " + trainingError)
+
+    tree
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/ml/tree/Metrics.scala b/src/main/scala/ml/tree/Metrics.scala
new file mode 100644
index 0000000..225b636
--- /dev/null
+++ b/src/main/scala/ml/tree/Metrics.scala
@@ -0,0 +1,31 @@
+package ml.tree
+
+import org.apache.spark.SparkContext._
+import ml.tree.node.NodeModel
+import org.apache.spark.rdd.RDD
+
+/*
+ * Helper methods for measuring the performance of ML algorithms
+ */
+object Metrics {
+
+  //TODO: Make these generic MLTable metrics.
+  def accuracyScore(tree: Option[NodeModel], data: RDD[(Double, Array[Double])]): Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val correctCount = data.filter(y => tree.get.predict(y._2) == y._1).count()
+    val count = data.count()
+    println("correct prediction count = " + correctCount)
+    println("data count = " + count)
+    correctCount.toDouble / count
+  }
+
+  //TODO: Make these generic MLTable metrics.
+  def meanSquaredError(tree: Option[NodeModel], data: RDD[(Double, Array[Double])]): Double = {
+    if (tree.isEmpty) return 1 //TODO: Throw exception
+    val meanSumOfSquares = data.map { y =>
+      val error = tree.get.predict(y._2) - y._1 // compute the prediction once per point
+      error * error
+    }.mean()
+    println("meanSumOfSquares = " + meanSumOfSquares)
+    meanSumOfSquares
+  }
+
+}
diff --git a/src/main/scala/ml/tree/README.md b/src/main/scala/ml/tree/README.md
new file mode 100644
index 0000000..56a6f9f
--- /dev/null
+++ b/src/main/scala/ml/tree/README.md
@@ -0,0 +1,38 @@
+#Decision Tree
+Decision tree classifiers are popular supervised learning algorithms as well as building blocks for ensemble learning algorithms such as random forests and boosting. This document discusses their implementation in the Spark project.
+
+#Usage
+```
+ml.tree.TreeRunner <master> [slices]
+--strategy <Classification,Regression>
+--trainDataDir path
+--testDataDir path
+[--maxDepth num]
+[--impurity <Gini,Entropy,Variance>]
+[--samplingFractionForSplitCalculation num]
+```
+
+#Example
+```
+sbt/sbt "run-main ml.tree.TreeRunner local[2] --strategy Classification
+--trainDataDir ../train_data --testDataDir ../test_data
+--maxDepth 1 --impurity Gini --samplingFractionForSplitCalculation 1"
+```
+
+This command will create a decision tree model using the training data in the *trainDataDir* and calculate the test error using the data in the *testDataDir*. The misclassification error is calculated for the Classification *strategy* and the mean squared error is calculated for the Regression *strategy*.
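+
+For programmatic use, the same model can be built by calling `DecisionTree.train` directly. Below is a minimal sketch, assuming an existing `SparkContext` named `sc` and training data in the `TreeUtils.loadLabeledData` format (the data path is illustrative):
+
+```
+import ml.tree.{DecisionTree, TreeUtils}
+import ml.tree.strategy.Strategy
+import ml.tree.impurity.Gini
+
+// Load (label, features) tuples and train a depth-1 classification tree.
+val trainData = TreeUtils.loadLabeledData(sc, "../train_data")
+val model = DecisionTree.train(
+  input = trainData,
+  numSplitPredicates = 1000,
+  strategy = new Strategy("Classification"),
+  impurity = Gini,
+  maxDepth = 1,
+  fraction = 1.0,
+  sparkContext = sc)
+```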
+
+#Performance testing
+To be done
+
+#Improvements
+* Print to dot files
+* Unit tests
+* Change fractions to quantiles
+* Add logging
+* Move metrics to a different package
+
+#Extensions
+* Extremely randomized trees
+* Random forest
+* Boosting
diff --git a/src/main/scala/ml/tree/TreeRunner.scala b/src/main/scala/ml/tree/TreeRunner.scala
new file mode 100644
index 0000000..3343180
--- /dev/null
+++ b/src/main/scala/ml/tree/TreeRunner.scala
@@ -0,0 +1,93 @@
+package ml.tree
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.{Logging, SparkContext}
+import ml.tree.impurity.{Variance, Entropy, Gini}
+import ml.tree.strategy.Strategy
+
+import ml.tree.node.NodeModel
+import org.apache.spark.rdd.RDD
+
+import ml.tree.Metrics.{accuracyScore, meanSquaredError}
+
+object TreeRunner extends Logging {
+  val usage = """
+    Usage: TreeRunner <master> [slices] --strategy <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--samplingFractionForSplitCalculation num]
+  """
+
+  def main(args: Array[String]) {
+
+    if (args.length < 2) {
+      System.err.println(usage)
+      System.exit(1)
+    }
+
+    /**START Experimental*/
+    System.setProperty("spark.cores.max", "8")
+    /**END Experimental*/
+    val sc = new SparkContext(args(0), "Decision Tree Runner",
+      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
+    val arglist = args.toList.drop(1)
+    type OptionMap = Map[Symbol, Any]
+
+    def nextOption(map: OptionMap, list: List[String]): OptionMap = {
+      list match {
+        case Nil => map
+        case "--strategy" :: string :: tail => nextOption(map ++ Map('strategy -> string), tail)
+        case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail)
+        case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail)
+        case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
+        case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
+        case "--samplingFractionForSplitCalculation" :: string :: tail => nextOption(map ++ Map('samplingFractionForSplitCalculation -> string), tail)
+        case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
+        case option :: tail =>
+          println("Unknown option " + option)
+          sys.exit(1)
+      }
+    }
+    val options = nextOption(Map(), arglist)
+    println(options)
+    //TODO: Add check for acceptable string inputs
+
+    val trainData = TreeUtils.loadLabeledData(sc, options.get('trainDataDir).get.toString)
+    val strategyStr = options.get('strategy).get.toString
+    val impurityStr = options.getOrElse('impurity, "Gini").toString
+    val impurity = {
+      impurityStr match {
+        case "Gini" => Gini
+        case "Entropy" => Entropy
+        case "Variance" => Variance
+      }
+    }
+    val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
+    val fraction = options.getOrElse('samplingFractionForSplitCalculation, "1.0").toString.toDouble
+
+    val tree = DecisionTree.train(
+      input = trainData,
+      numSplitPredicates = 1000,
+      strategy = new Strategy(strategyStr),
+      impurity = impurity,
+      maxDepth = maxDepth,
+      fraction = fraction,
+      sparkContext = sc)
+    println(tree)
+    //println("prediction = " + tree.get.predict(Array(1.0, 2.0)))
+
+    println("loading test data")
+    val testData = TreeUtils.loadLabeledData(sc, options.get('testDataDir).get.toString)
+
+    println("calculating performance on test data")
+    val testError = {
+      strategyStr match {
+        case "Classification" => accuracyScore(tree, testData)
+        case "Regression" =>
+          meanSquaredError(tree, testData)
+      }
+    }
+    println("test error = " + testError)
+
+  }
+
+}
diff --git a/src/main/scala/ml/tree/TreeUtils.scala b/src/main/scala/ml/tree/TreeUtils.scala
new file mode 100644
index 0000000..29abc4e
--- /dev/null
+++ b/src/main/scala/ml/tree/TreeUtils.scala
@@ -0,0 +1,37 @@
+package ml.tree
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+
+//TODO: Deprecate this when we find something equivalent in ml utils
+/**
+ * Helper methods to load and save data.
+ * Data format:
+ *   label, f1, f2, ...
+ * where f1, f2, ... are the feature values as Doubles and label is the corresponding label as a Double.
+ */
+object TreeUtils {
+
+  /**
+   * @param sc SparkContext
+   * @param dir Directory of the input data files.
+   * @return An RDD of tuples. For each tuple, the first element is the label, and the second
+   *         element represents the feature values (an array of Double).
+   */
+  def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
+    sc.textFile(dir).map { line =>
+      val parts = line.trim().split(",")
+      val label = parts(0).toDouble
+      val features = parts.slice(1, parts.length).map(_.toDouble)
+      //val features = parts.slice(1, 30).map(_.toDouble)
+      (label, features)
+    }
+  }
+
+  def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) {
+    // Use the same comma-separated format that loadLabeledData expects.
+    val dataStr = data.map(x => x._1 + "," + x._2.mkString(","))
+    dataStr.saveAsTextFile(dir)
+  }
+
+}
diff --git a/src/main/scala/ml/tree/design.md b/src/main/scala/ml/tree/design.md
new file mode 100644
index 0000000..7864978
--- /dev/null
+++ b/src/main/scala/ml/tree/design.md
@@ -0,0 +1,83 @@
+#Tree design doc
+Decision tree classifiers are popular supervised learning algorithms as well as building blocks for ensemble learning algorithms such as random forests and boosting. This document discusses their implementation in the Spark project.
+
+**The current design is optimized for the scenario where all the data fits in the in-cluster memory.**
+
+##Algorithm
+A decision tree classifier is built by creating recursive binary partitions using the optimal splitting criterion that maximizes the information gain at each step. It handles both ordered (numeric) and unordered (categorical) features.
+
+###Identifying Split Predicates
+The split predicates are calculated by performing a single pass over the data at the start of the tree model building. The binning of the data can be performed using two techniques:
+
+1. Sorting the ordered features and finding the exact quantile points. Complexity: O(N*logN) * #features
+2. Using an [approximate quantile calculation algorithm](http://infolab.stanford.edu/~manku/papers/99sigmod-unknown.pdf) cited by the PLANET paper.
+
+###Optimal Splitting Criterion
+The splitting criterion is calculated using one of two popular criteria:
+
+1. [Gini impurity](http://en.wikipedia.org/wiki/Gini_coefficient)
+2. [Entropy](http://en.wikipedia.org/wiki/Information_gain_in_decision_trees)
+
+Each split is stored in a model for future predictions.
+
+###Stopping criterion
+There are various criteria that can be used to stop adding more levels to the tree. The first implementation is kept simple and uses the following criteria: no further information gain can be achieved, or the maximum depth has been reached. Once a stopping criterion is met, the current node becomes a leaf of the tree, and the model is updated with the distribution of the remaining classes at the node.
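+
+A minimal sketch of this check, with illustrative names (the implementation folds it into child-node creation):
+
+    def isTerminal(bestGain: Double, depth: Int, maxDepth: Int): Boolean =
+      bestGain <= 0 || depth > maxDepth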
+
+###Prediction
+To make a prediction, a new sample is run through the decision tree model until it arrives at a leaf node. Upon reaching the leaf node, a prediction is made using the distribution of the underlying samples (typically, the distribution itself is the output).
+
+##Implementation
+
+###Code
+For a consistent API, the training code will be consistent with the existing logistic regression algorithm for supervised learning.
+
+The train() method will be of the following format:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {...}
+    def predict(testData: spark.RDD[Array[Double]]) = {...}
+
+All splitting criteria can be evaluated in parallel using the *map* operation. The *reduce* operation will select the best splitting criterion. The split criterion will create a *filter* that should be applied to the RDD at each node to derive the RDDs at the next node.
+
+The pseudocode is given below:
+
+    def train(input: RDD[(Double, Array[Double])]): DecisionTreeModel = {
+      filterList = new List()
+      root = new Node()
+      buildTree(root, input, filterList)
+    }
+
+    def buildTree(node: Node, input: RDD[(Double, Array[Double])], filterList: List): Tree = {
+      splits = find_possible_splits(input)
+      bestSplit = splits.map(split => calculateInformationGain(input, split)).reduce(_ max _)
+      if (bestSplit > threshold) {
+        leftRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit))
+        rightRDD = RDD.filter(sample => sample.satisfiesSplit(bestSplit.invert()))
+        node.split = bestSplit
+        node.left = new Node()
+        node.right = new Node()
+        lefttree = buildTree(node.left, leftRDD, filterList.add(bestSplit))
+        righttree = buildTree(node.right, rightRDD, filterList.add(bestSplit.invert()))
+      }
+      node
+    }
+
+###Testing
+
+####Unit testing
+As a standard programming practice, unit tests will be written to test the important building blocks.
+
+####Comparison with other libraries
+There are several machine learning libraries in other languages. The scikit-learn library will be used as a benchmark for functional tests.
+
+###Constraints
++ Two class labels -- The first implementation will support only binary labels.
++ Class weights -- A class weighting option (useful for highly unbalanced data) will not be supported.
++ Sanity checks -- Input data sanity checks will not be performed. Ideally, a separate pre-processing step (that is common to all ML algorithms) should handle this.
+
+##Future Work
++ Weights to handle unbalanced classes
++ Ensemble methods -- random forests, boosting, etc.
+
+##References
+1. Hastie, Tibshirani, Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer, 2009.
+2. Panda, Herbach, Basu, Bayardo. PLANET: Massively Parallel Learning of Tree Ensembles with MapReduce. VLDB 2009.
\ No newline at end of file
diff --git a/src/main/scala/ml/tree/impurity/Entropy.scala b/src/main/scala/ml/tree/impurity/Entropy.scala
new file mode 100644
index 0000000..4c0e7e0
--- /dev/null
+++ b/src/main/scala/ml/tree/impurity/Entropy.scala
@@ -0,0 +1,18 @@
+package ml.tree.impurity
+
+object Entropy extends Impurity {
+
+  def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+
+  def calculate(c0: Double, c1: Double): Double = {
+    if (c0 == 0 || c1 == 0) {
+      0
+    } else {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      -(f0 * log2(f0)) - (f1 * log2(f1))
+    }
+  }
+
+}
diff --git a/src/main/scala/ml/tree/impurity/Gini.scala b/src/main/scala/ml/tree/impurity/Gini.scala
new file mode 100644
index 0000000..ec349d3
--- /dev/null
+++ b/src/main/scala/ml/tree/impurity/Gini.scala
@@ -0,0 +1,12 @@
+package ml.tree.impurity
+
+object Gini extends Impurity {
+
+  def calculate(c0: Double, c1: Double): Double = {
+    val total = c0 + c1
+    val f0 = c0 / total
+    val f1 = c1 / total
+    1 - f0 * f0 - f1 * f1
+  }
+
+}
diff --git a/src/main/scala/ml/tree/impurity/Impurity.scala b/src/main/scala/ml/tree/impurity/Impurity.scala
new file mode 100644
index 0000000..8b9095b
--- /dev/null
+++ b/src/main/scala/ml/tree/impurity/Impurity.scala
@@ -0,0 +1,56 @@
+package ml.tree.impurity
+
+import ml.tree.node.NodeStats
+import ml.tree.split.Split
+
+trait Impurity {
+
+  def calculateClassificationGain(split: Split, calculations: Map[(Split, String, Double), Long]): Double = {
+    val leftRddZeroCount = calculations.getOrElse((split, "left", 0.0), 0L).toDouble
+    val rightRddZeroCount = calculations.getOrElse((split, "right", 0.0), 0L).toDouble
+    val leftRddOneCount = calculations.getOrElse((split, "left", 1.0), 0L).toDouble
+    val rightRddOneCount = calculations.getOrElse((split, "right", 1.0), 0L).toDouble
+    val leftRddCount = leftRddZeroCount + leftRddOneCount
+    val rightRddCount = rightRddZeroCount + rightRddOneCount
+    val totalZeroCount = leftRddZeroCount + rightRddZeroCount
+    val totalOneCount = leftRddOneCount + rightRddOneCount
+    val totalCount = totalZeroCount + totalOneCount
+    val gain = {
+      if (leftRddCount == 0 || rightRddCount == 0) 0
+      else {
+        val topImpurity = calculate(totalZeroCount, totalOneCount)
+        val leftWeight = leftRddCount / totalCount
+        val leftImpurity = calculate(leftRddZeroCount, leftRddOneCount) * leftWeight
+        val rightWeight = rightRddCount / totalCount
+        val rightImpurity = calculate(rightRddZeroCount, rightRddOneCount) * rightWeight
+        topImpurity - leftImpurity - rightImpurity
+      }
+    }
+    gain
+  }
+
+  def calculateRegressionGain(split: Split, calculations: Map[(Split, String), (Double, Double, Long)], nodeStats: NodeStats): (Double, NodeStats, NodeStats) = {
+    val topCount = nodeStats.count.get
+    val leftCount = calculations.getOrElse((split, "left"), (0.0, 0.0, 0L))._3
+    val rightCount = calculations.getOrElse((split, "right"), (0.0, 0.0, 0L))._3
+    if (leftCount == 0 || rightCount == 0) {
+      // No gain return values
+      //println("leftCount = " + leftCount + " rightCount = " + rightCount + " topCount = " + topCount)
+      (0, new NodeStats, new NodeStats)
+    } else {
+      val topVariance = nodeStats.variance.get
+      val leftMean = calculations((split, "left"))._1
+      val leftVariance = calculations((split, "left"))._2
+      val rightMean = calculations((split, "right"))._1
+      val rightVariance = calculations((split, "right"))._2
+      //TODO: Check and if needed improve these toDouble conversions
+      val gain = topVariance - ((leftCount.toDouble / topCount) * leftVariance) - ((rightCount.toDouble / topCount) * rightVariance)
+      (gain,
+        new NodeStats(mean = Some(leftMean), variance = Some(leftVariance), count = Some(leftCount)),
+        new NodeStats(mean = Some(rightMean), variance = Some(rightVariance), count = Some(rightCount)))
+    }
+  }
+
+  def calculate(c0: Double, c1: Double): Double
+
+}
diff --git a/src/main/scala/ml/tree/impurity/Variance.scala b/src/main/scala/ml/tree/impurity/Variance.scala
new file mode 100644
index 0000000..8dda877
--- /dev/null
+++ b/src/main/scala/ml/tree/impurity/Variance.scala
@@ -0,0 +1,7 @@
+package ml.tree.impurity
+
+import javax.naming.OperationNotSupportedException
+
+object Variance extends Impurity {
+  def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
+}
diff --git a/src/main/scala/ml/tree/node/Node.scala b/src/main/scala/ml/tree/node/Node.scala
new file mode 100644
index 0000000..c179822
--- /dev/null
+++ b/src/main/scala/ml/tree/node/Node.scala
@@ -0,0 +1,44 @@
+package ml.tree.node
+
+import org.apache.spark.rdd.RDD
+import ml.tree.split.SplitPredicate
+import ml.tree.Metrics._
+
+/*
+ * Node trait as a template for implementing various types of nodes in the decision tree.
+ */
+trait Node {
+
+  //Method for checking whether the node has any left/right child nodes.
+  def isLeaf: Boolean
+
+  //Left/Right child nodes
+  def left: Node
+  def right: Node
+
+  //Depth of the node from the top node
+  def depth: Int
+
+  //RDD data as an input to the node
+  def data: RDD[(Double, Array[Double])]
+
+  //List of split predicates applied to the base RDD thus far
+  def splitPredicates: List[SplitPredicate]
+
+  //Split used to arrive at the node
+  def splitPredicate: Option[SplitPredicate]
+
+  //Extract the model
+  def extractModel: Option[NodeModel] = {
+    //TODO: Add probability logic
+    if (splitPredicate.isDefined) {
+      Some(new NodeModel(splitPredicate, left.extractModel, right.extractModel, depth, isLeaf, Some(prediction)))
+    } else {
+      Some(new NodeModel(None, None, None, depth, isLeaf, Some(prediction)))
+    }
+  }
+
+  //Prediction at the node
+  def prediction: Prediction
+}
diff --git a/src/main/scala/ml/tree/node/NodeModel.scala b/src/main/scala/ml/tree/node/NodeModel.scala
new file mode 100644
index 0000000..7045a6e
--- /dev/null
+++ b/src/main/scala/ml/tree/node/NodeModel.scala
@@ -0,0 +1,57 @@
+package ml.tree.node
+
+import org.apache.spark.mllib.classification.ClassificationModel
+import org.apache.spark.rdd.RDD
+import ml.tree.split.SplitPredicate
+import ml.tree.Metrics._
+
+/**
+ * The decision tree model used for making predictions on unseen data.
+ */
+class NodeModel(
+    val splitPredicate: Option[SplitPredicate],
+    val trueNode: Option[NodeModel],
+    val falseNode: Option[NodeModel],
+    val depth: Int,
+    val isLeaf: Boolean,
+    val prediction: Option[Prediction]) extends ClassificationModel {
+
+  override def toString() = if (splitPredicate.isDefined) {
+    "[" + trueNode.get + "\n" + "[" + "depth = " + depth + ", split predicate = " + this.splitPredicate.get + ", predict = " + this.prediction + "]" + "]\n" + falseNode.get
+  } else {
+    "Leaf : " + "depth = " + depth + ", predict = " + prediction + ", isLeaf = " + isLeaf
+  }
+
+  /**
+   * Predict values for the given data set using the model trained.
+   *
+   * @param testData RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+    testData.map { x => predict(x) }
+  }
+
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param testData array representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predict(testData: Array[Double]): Double = {
+
+    //TODO: Modify this logic to handle regression
+    val pred = prediction.get
+    if (this.isLeaf) {
+      if (pred.prob > 0.5) 1 else 0
+    } else {
+      val spPred = splitPredicate.get
+      if (testData(spPred.split.feature) <= spPred.split.threshold) {
+        trueNode.get.predict(testData)
+      } else {
+        falseNode.get.predict(testData)
+      }
+    }
+  }
+
+}
diff --git a/src/main/scala/ml/tree/node/NodeStats.scala b/src/main/scala/ml/tree/node/NodeStats.scala
new file mode 100644
index 0000000..7387e90
--- /dev/null
+++ b/src/main/scala/ml/tree/node/NodeStats.scala
@@ -0,0 +1,10 @@
+package ml.tree.node
+
+class NodeStats(
+    val gini: Option[Double] = None,
+    val entropy: Option[Double] = None,
+    val mean: Option[Double] = None,
+    val variance: Option[Double] = None,
+    val count: Option[Long] = None) extends Serializable {
+  override def toString = "variance = " + variance + ", count = " + count + ", mean = " + mean
+}
diff --git a/src/main/scala/ml/tree/node/Prediction.scala b/src/main/scala/ml/tree/node/Prediction.scala
new file mode 100644
index 0000000..7417e20
--- /dev/null
+++ b/src/main/scala/ml/tree/node/Prediction.scala
@@ -0,0 +1,8 @@
+package ml.tree.node
+
+/*
+ * Class used to store the prediction values at each node of the tree.
+ */
+class Prediction(val prob: Double, val distribution: Map[Double, Double]) extends Serializable {
+  override def toString = { "probability = " + prob + ", distribution = " + distribution }
+}
diff --git a/src/main/scala/ml/tree/node/decisionNodes.scala b/src/main/scala/ml/tree/node/decisionNodes.scala
new file mode 100644
index 0000000..5b58444
--- /dev/null
+++ b/src/main/scala/ml/tree/node/decisionNodes.scala
@@ -0,0 +1,271 @@
+package ml.tree.node
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.util.StatCounter
+import javax.naming.OperationNotSupportedException
+import ml.tree.split.{Split, SplitPredicate}
+import ml.tree.impurity.Impurity
+import ml.tree.strategy.Strategy
+import ml.tree.Metrics._
+import scala.collection.mutable
+
+abstract class DecisionNode(
+    val data: RDD[(Double, Array[Double])],
+    val depth: Int,
+    val splitPredicates: List[SplitPredicate],
+    val nodeStats: NodeStats,
+    val allSplits: Broadcast[Set[Split]],
+    val impurity: Impurity,
+    val strategy: Strategy,
+    val maxDepth: Int) extends Node {
+
+  //TODO: Change empty logic
+  val splits = splitPredicates.map(x => x.split)
+
+  //TODO: Think about the merits of doing BFS and removing the parent RDDs from memory instead of doing DFS like below.
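+  // Children are built eagerly here: createLeftRightChild filters the node's RDD into the two
+  // child RDDs and recurses depth-first until a stopping criterion produces leaf nodes.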
+  val (left, right, splitPredicate, isLeaf) = createLeftRightChild()
+
+  override def toString() = "[" + left + "[" + this.splitPredicate + " prediction = " + this.prediction + "]" + right + "]"
+
+  def createNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats): DecisionNode
+
+  def createLeftRightChild(): (Node, Node, Option[SplitPredicate], Boolean) = {
+    if (depth > maxDepth) {
+      (new LeafNode(data), new LeafNode(data), None, true)
+    } else {
+      println("split count " + splits.length)
+      val (split, gain, leftNodeStats, rightNodeStats) = findBestSplit(nodeStats)
+      println("Selected split = " + split + " with gain = " + gain + ", left node stats = " + leftNodeStats + ", right node stats = " + rightNodeStats)
+      if (gain > 0) {
+        println("creating new nodes at depth = " + depth)
+        val leftPredicate = new SplitPredicate(split, true)
+        val rightPredicate = new SplitPredicate(split, false)
+        val leftData = data.filter(sample => sample._2(leftPredicate.split.feature) <= leftPredicate.split.threshold).cache
+        val rightData = data.filter(sample => sample._2(rightPredicate.split.feature) > rightPredicate.split.threshold).cache
+        val leftNode = if (leftData.count != 0) createNode(leftData, depth + 1, splitPredicates ::: List(leftPredicate), leftNodeStats) else new LeafNode(data)
+        val rightNode = if (rightData.count != 0) createNode(rightData, depth + 1, splitPredicates ::: List(rightPredicate), rightNodeStats) else new LeafNode(data)
+        (leftNode, rightNode, Some(leftPredicate), false)
+      } else {
+        println("not creating more child nodes since gain is not greater than 0")
+        (new LeafNode(data), new LeafNode(data), None, true)
+      }
+    }
+  }
+
+  def comparePair(x: (Split, Double), y: (Split, Double)): (Split, Double) = {
+    if (x._2 > y._2) x else y
+  }
+
+  def compareRegressionPair(x: (Split, Double, NodeStats, NodeStats), y: (Split, Double, NodeStats, NodeStats)): (Split, Double, NodeStats, NodeStats) = {
+    if (x._2 > y._2) x else y
+  }
+
+  def findBestSplit(nodeStats: NodeStats): (Split, Double, NodeStats, NodeStats) = {
+
+    //TODO: Also remove splits that are subsets of previous splits
+    val availableSplits = allSplits.value filterNot (split => splits contains split)
+    println("availableSplit count " + availableSplits.size)
+
+    strategy match {
+      case Strategy("Classification") => {
+
+        //Take the RDD and the list of splits and compute a
+        //map of (split, left/right, label) -> count.
+
+        val splits = availableSplits.toSeq
+
+        //Modify numLabels to support multiple classes in the future
+        val numLabels = 2
+        val numChildren = 2
+        val lenSplits = splits.length
+        val outputVectorLength = numLabels * numChildren * lenSplits
+        val vecToVec: RDD[Array[Int]] = data.map(
+          sample => {
+            val storage: Array[Int] = new Array[Int](outputVectorLength)
+            val label = sample._1
+            val features = sample._2
+            splits.zipWithIndex.foreach { case (split, i) =>
+              val featureIndex = split.feature
+              val threshold = split.threshold
+              if (features(featureIndex) <= threshold) { //left node
+                val index = i * (numLabels * numChildren) + label.toInt
+                storage(index) = 1
+              } else { //right node
+                val index = i * (numLabels * numChildren) + numLabels + label.toInt
+                storage(index) = 1
+              }
+            }
+            storage
+          }
+        )
+
+        //val countVecToVec : Array[Long] = vecToVec.reduce((a1,a2) => NodeHelper.sumTwoArrays(a1,a2))
+        val countVecToVec: Array[Long] =
+          vecToVec.aggregate(new Array[Long](outputVectorLength))(NodeHelper.sumLongIntArrays, NodeHelper.sumTwoLongArrays)
+
+        //TODO: Unnecessary step. Use indices directly instead of creating a map. Not a big hit in performance. Optimize later.
+        var newGainCalculations = Map[(Split, String, Double), Long]()
+        splits.zipWithIndex.foreach { case (split, i) =>
+          newGainCalculations += ((split, "left", 0.0) -> countVecToVec(i * (numLabels * numChildren) + 0))
+          newGainCalculations += ((split, "left", 1.0) -> countVecToVec(i * (numLabels * numChildren) + 1))
+          newGainCalculations += ((split, "right", 0.0) -> countVecToVec(i * (numLabels * numChildren) + numLabels + 0))
+          newGainCalculations += ((split, "right", 1.0) -> countVecToVec(i * (numLabels * numChildren) + numLabels + 1))
+        }
+
+        val split_gain_list = for (
+          split <- availableSplits;
+          gain = impurity.calculateClassificationGain(split, newGainCalculations)
+        ) yield (split, gain)
+
+        val split_gain = split_gain_list.reduce(comparePair(_, _))
+        (split_gain._1, split_gain._2, new NodeStats, new NodeStats)
+      }
+      case Strategy("Regression") => {
+
+        val splitWiseCalculations = data.flatMap(sample => {
+          val label = sample._1
+          val features = sample._2
+          val leftOrRight = for {
+            split <- availableSplits.toSeq
+            featureIndex = split.feature
+            threshold = split.threshold
+          } yield {
+            if (features(featureIndex) <= threshold) ((split, "left"), label) else ((split, "right"), label)
+          }
+          leftOrRight
+        })
+
+        // Calculate the mean, variance, and count for each (split, left/right) pair
+        val splitVariancePairs = splitWiseCalculations
+          .groupByKey()
+          .map(x => x._1 -> { val stat = StatCounter(x._2); (stat.mean, stat.variance, stat.count) })
+          .collect
+        //Tuple array to map conversion
+        val gainCalculations = scala.collection.immutable.Map(splitVariancePairs: _*)
+
+        val split_gain_list = for (
+          split <- availableSplits;
+          (gain, leftNodeStats, rightNodeStats) = impurity.calculateRegressionGain(split, gainCalculations, nodeStats)
+        ) yield (split, gain, leftNodeStats, rightNodeStats)
+
+        val split_gain = split_gain_list.reduce(compareRegressionPair(_, _))
+        (split_gain._1, split_gain._2, split_gain._3, split_gain._4)
+      }
+    }
+  }
+
+  def calculateVarianceSize(seq: Seq[Double]): (Double, Double, Long) = {
+    val stat = StatCounter(seq)
+    (stat.mean, stat.variance, stat.count)
+  }
+
+}
+
+/*
+ * Top node for building a classification tree
+ */
+class TopClassificationNode(input: RDD[(Double, Array[Double])], allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends ClassificationNode(input.cache, 1, List[SplitPredicate](), new NodeStats, allSplits, impurity, strategy, maxDepth) {
+  override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
+}
+
+/*
+ * Class for each node in the classification tree
+ */
+class ClassificationNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) {
+
+  // Prediction at each classification node
+  val prediction: Prediction = {
+    val countZero: Double = data.filter(x => (x._1 == 0.0)).count
+    val countOne: Double = data.filter(x => (x._1 == 1.0)).count
+    val countTotal: Double = countZero + countOne
+    new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne))
+  }
+
+  //Factory method for child nodes. TODO: Put it in a better location.
+  def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats) =
+    new ClassificationNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth)
+
+}
+
+/*
+ * Top node for building a regression tree
+ */
+class TopRegressionNode(input: RDD[(Double, Array[Double])], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends RegressionNode(input.cache, 1, List[SplitPredicate](), nodeStats, allSplits, impurity, strategy, maxDepth) {
+  override def toString() = "[" + left + "[" + "TopNode" + "]" + right + "]"
+}
+
+/*
+ * Class for each node in the regression tree
+ */
+class RegressionNode(data: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats, allSplits: Broadcast[Set[Split]], impurity: Impurity, strategy: Strategy, maxDepth: Int)
+  extends DecisionNode(data, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth) {
+
+  // Prediction at each regression node
+  val prediction: Prediction = new Prediction(data.map(_._1).mean, Map())
+
+  //Factory method for child nodes. TODO: Put it in a better location.
+  def createNode(anyData: RDD[(Double, Array[Double])], depth: Int, splitPredicates: List[SplitPredicate], nodeStats: NodeStats) =
+    new RegressionNode(anyData, depth, splitPredicates, nodeStats, allSplits, impurity, strategy, maxDepth)
+}
+
+/*
+ * Leaf node used to terminate a branch of the tree
+ */
+class LeafNode(val data: RDD[(Double, Array[Double])]) extends Node {
+  def isLeaf = true
+
+  def left = throw new OperationNotSupportedException("LeafNode.left")
+
+  def right = throw new OperationNotSupportedException("LeafNode.right")
+
+  def depth = throw new OperationNotSupportedException("LeafNode.depth")
+
+  def splitPredicates = throw new OperationNotSupportedException("LeafNode.splitPredicates")
+
+  def splitPredicate = throw new OperationNotSupportedException("LeafNode.splitPredicate")
+
+  override def toString() = "Empty"
+
+  //TODO: This class distribution only makes sense for classification; use the mean for regression leaves.
+  val prediction: Prediction = {
+    val countZero: Double = data.filter(x => (x._1 == 0.0)).count
+    val countOne: Double = data.filter(x => (x._1 == 1.0)).count
+    val countTotal: Double = countZero + countOne
+    new Prediction(countOne / countTotal, Map(0.0 -> countZero, 1.0 -> countOne))
+  }
+}
+
+object NodeHelper extends Serializable {
+
+  //There definitely has to be a library function to do this!
+  def sumTwoLongArrays(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+    val storage = new Array[Long](a1.length)
+    for (i <- 0 until a1.length) { storage(i) = a1(i) + a2(i) }
+    storage
+  }
+
+  //There definitely has to be a library function to do this!
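+  // (One library route: a1.zip(a2).map { case (x, y) => x + y }; the explicit loop just avoids
+  // building an intermediate array of tuples.)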
+  def sumLongIntArrays(a1: Array[Long], a2: Array[Int]): Array[Long] = {
+    val storage = new Array[Long](a1.length)
+    for (i <- 0 until a1.length) { storage(i) = a1(i) + a2(i) }
+    storage
+  }
+
+}
diff --git a/src/main/scala/ml/tree/split/Split.scala b/src/main/scala/ml/tree/split/Split.scala
new file mode 100644
index 0000000..69db788
--- /dev/null
+++ b/src/main/scala/ml/tree/split/Split.scala
@@ -0,0 +1,8 @@
+package ml.tree.split
+
+/*
+ * Class for storing splits -- feature index and threshold
+ */
+case class Split(feature: Int, threshold: Double) {
+  override def toString = "feature = " + feature + ", threshold = " + threshold
+}
diff --git a/src/main/scala/ml/tree/split/SplitPredicate.scala b/src/main/scala/ml/tree/split/SplitPredicate.scala
new file mode 100644
index 0000000..63a1701
--- /dev/null
+++ b/src/main/scala/ml/tree/split/SplitPredicate.scala
@@ -0,0 +1,8 @@
+package ml.tree.split
+
+/*
+ * Class for storing the split predicate -- a split plus the side (<= or >) it selects.
+ */
+class SplitPredicate(val split: Split, val lessThanEqualTo: Boolean = true) extends Serializable {
+  override def toString = "split = " + split.toString + ", lessThanEqualTo = " + lessThanEqualTo
+}
diff --git a/src/main/scala/ml/tree/strategy/Strategy.scala b/src/main/scala/ml/tree/strategy/Strategy.scala
new file mode 100644
index 0000000..9a0e144
--- /dev/null
+++ b/src/main/scala/ml/tree/strategy/Strategy.scala
@@ -0,0 +1,3 @@
+package ml.tree.strategy
+
+case class Strategy(name: String)
diff --git a/src/test/scala/ml/tree/DecisionTreeTest.scala b/src/test/scala/ml/tree/DecisionTreeTest.scala
new file mode 100644
index 0000000..d288c7a
--- /dev/null
+++ b/src/test/scala/ml/tree/DecisionTreeTest.scala
@@ -0,0 +1,8 @@
+package ml.tree
+import org.scalatest.FunSuite
+import ml.tree.impurity.{Entropy, Gini}
+
+class DecisionTreeTest extends FunSuite {
+  test("Basic decision tree test") {
+    //Placeholder: exercise the impurity building blocks until full tree-building tests exist
+    assert(Gini.calculate(1.0, 1.0) === 0.5)
+    assert(math.abs(Entropy.calculate(1.0, 1.0) - 1.0) < 1e-12)
+    assert(Entropy.calculate(0.0, 5.0) === 0.0)
+  }
+}