Skip to content

Latest commit

 

History

History
63 lines (46 loc) · 1.39 KB

example.md

File metadata and controls

63 lines (46 loc) · 1.39 KB
---
id: example
title: Example
---

Building Base Learner

import org.apache.spark.ml.classification.DecisionTreeClassifier

// Base learner: a decision tree classifier with its maximum depth set to 20.
val baseClassifier = new DecisionTreeClassifier().setMaxDepth(20)

Building Meta Estimator

import org.apache.spark.ml.classification.BaggingClassifier

// Meta estimator: bagging around `baseClassifier`, configured with
// 10 base learners and a parallelism of 4.
val baggingClassifier =
  new BaggingClassifier()
    .setBaseLearner(baseClassifier)
    .setNumBaseLearners(10)
    .setParallelism(4)

Building Param Grid

import org.apache.spark.ml.tuning.ParamGridBuilder

// Search grid over two params, each tried at 10 and 20:
// the ensemble's `numBaseLearners` and the base tree's `maxDepth`.
val paramGrid =
  new ParamGridBuilder()
    .addGrid(baggingClassifier.numBaseLearners, Array(10, 20))
    .addGrid(baseClassifier.maxDepth, Array(10, 20))
    .build()

Grid Search with Cross Validation

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.classification.BaggingClassificationModel

// Cross-validator that tunes `baggingClassifier` over `paramGrid`:
// 5 folds, parallelism 4, scored by a MulticlassClassificationEvaluator
// constructed with no explicit settings.
val cv =
  new CrossValidator()
    .setEstimator(baggingClassifier)
    .setEvaluator(new MulticlassClassificationEvaluator())
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(5)
    .setParallelism(4)

// NOTE(review): `data` is not defined anywhere in this example; it is presumably
// the training DataFrame — confirm and load it before running this snippet.
val cvModel = cv.fit(data)

// Downcast the selected model to its concrete bagging model type.
val bestModel = cvModel.bestModel.asInstanceOf[BaggingClassificationModel]

// Bare expression: evaluates to the selected model (presumably so a REPL/notebook displays it).
bestModel

Save and Load

// Persist the tuned model, replacing any previous save at the same path,
// then read it back from that path.
val modelPath = "/tmp/model"
bestModel.write.overwrite().save(modelPath)
val loaded = BaggingClassificationModel.load(modelPath)