Skip to content

Commit

Permalink
[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Browse files Browse the repository at this point in the history
Similar to `MatrixFactorizaionModel`, we only need wrappers to support save/load for tree models in Python.

jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#4854 from mengxr/SPARK-6097 and squashes the following commits:

4586a4d [Xiangrui Meng] fix more typos
8ebcac2 [Xiangrui Meng] fix python style
91172d8 [Xiangrui Meng] fix typos
201b3b9 [Xiangrui Meng] update user guide
b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
  • Loading branch information
mengxr committed Mar 3, 2015
1 parent 54d1968 commit 7e53a79
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 33 deletions.
16 changes: 10 additions & 6 deletions docs/mllib-decision-tree.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
Expand All @@ -317,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down Expand Up @@ -440,11 +442,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
Expand All @@ -464,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down
32 changes: 20 additions & 12 deletions docs/mllib-ensembles.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.tree import RandomForest, RandomForestModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
Expand All @@ -228,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
print('Test Error = ' + str(testErr))
print('Learned classification forest model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down Expand Up @@ -354,10 +356,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.tree import RandomForest, RandomForestModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
Expand All @@ -380,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression forest model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down Expand Up @@ -581,10 +585,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file.
Expand All @@ -605,6 +607,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
print('Test Error = ' + str(testErr))
print('Learned classification GBT model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down Expand Up @@ -732,10 +738,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m

<div data-lang="python">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file.
Expand All @@ -756,6 +760,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression GBT model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "myModelPath")
sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
</div>

Expand Down
9 changes: 3 additions & 6 deletions python/pyspark/mllib/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
from pyspark.mllib.util import Saveable, JavaLoader
from pyspark.mllib.util import JavaLoader, JavaSaveable

__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']

Expand All @@ -41,7 +41,7 @@ def __reduce__(self):


@inherit_doc
class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):

"""A matrix factorisation model trained by regularized alternating
least-squares.
Expand Down Expand Up @@ -92,7 +92,7 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
0.43...
>>> try:
... os.removedirs(path)
... except:
... except OSError:
... pass
"""
def predict(self, user, product):
Expand All @@ -111,9 +111,6 @@ def userFeatures(self):
def productFeatures(self):
return self.call("getProductFeatures")

def save(self, sc, path):
self.call("save", sc._jsc.sc(), path)


class ALS(object):

Expand Down
27 changes: 26 additions & 1 deletion python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
Fuller unit tests for Python MLlib.
"""

import os
import sys
import tempfile
import array as pyarray

from numpy import array, array_equal
Expand Down Expand Up @@ -195,7 +197,8 @@ def test_gmm_deterministic(self):

def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
Expand All @@ -205,6 +208,8 @@ def test_classification(self):
rdd = self.sc.parallelize(data)
features = [p.features.tolist() for p in data]

temp_dir = tempfile.mkdtemp()

lr_model = LogisticRegressionWithSGD.train(rdd)
self.assertTrue(lr_model.predict(features[0]) <= 0)
self.assertTrue(lr_model.predict(features[1]) > 0)
Expand All @@ -231,20 +236,40 @@ def test_classification(self):
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)

dt_model_dir = os.path.join(temp_dir, "dt")
dt_model.save(self.sc, dt_model_dir)
same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())

rf_model = RandomForest.trainClassifier(
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
self.assertTrue(rf_model.predict(features[0]) <= 0)
self.assertTrue(rf_model.predict(features[1]) > 0)
self.assertTrue(rf_model.predict(features[2]) <= 0)
self.assertTrue(rf_model.predict(features[3]) > 0)

rf_model_dir = os.path.join(temp_dir, "rf")
rf_model.save(self.sc, rf_model_dir)
same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())

gbt_model = GradientBoostedTrees.trainClassifier(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(gbt_model.predict(features[0]) <= 0)
self.assertTrue(gbt_model.predict(features[1]) > 0)
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)

gbt_model_dir = os.path.join(temp_dir, "gbt")
gbt_model.save(self.sc, gbt_model_dir)
same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())

try:
os.removedirs(temp_dir)
except OSError:
pass

def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
Expand Down
21 changes: 17 additions & 4 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import JavaLoader, JavaSaveable

__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']


class TreeEnsembleModel(JavaModelWrapper):
class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
Expand Down Expand Up @@ -66,7 +67,7 @@ def toDebugString(self):
return self._java_model.toDebugString()


class DecisionTreeModel(JavaModelWrapper):
class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
"""
.. note:: Experimental
Expand Down Expand Up @@ -103,6 +104,10 @@ def toDebugString(self):
""" full model. """
return self._java_model.toDebugString()

@classmethod
def _java_loader_class(cls):
return "org.apache.spark.mllib.tree.model.DecisionTreeModel"


class DecisionTree(object):
"""
Expand Down Expand Up @@ -227,13 +232,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,


@inherit_doc
class RandomForestModel(TreeEnsembleModel):
class RandomForestModel(TreeEnsembleModel, JavaLoader):
"""
.. note:: Experimental
Represents a random forest model.
"""

@classmethod
def _java_loader_class(cls):
return "org.apache.spark.mllib.tree.model.RandomForestModel"


class RandomForest(object):
"""
Expand Down Expand Up @@ -406,13 +415,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt


@inherit_doc
class GradientBoostedTreesModel(TreeEnsembleModel):
class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
"""
.. note:: Experimental
Represents a gradient-boosted tree model.
"""

@classmethod
def _java_loader_class(cls):
return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"


class GradientBoostedTrees(object):
"""
Expand Down
37 changes: 33 additions & 4 deletions python/pyspark/mllib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import warnings

from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint

Expand Down Expand Up @@ -191,6 +191,17 @@ def save(self, sc, path):
raise NotImplementedError


@inherit_doc
class JavaSaveable(Saveable):
"""
Mixin for models that provide save() through their Scala
implementation.
"""

def save(self, sc, path):
self._java_model.save(sc._jsc.sc(), path)


class Loader(object):
"""
Mixin for classes which can load saved models from files.
Expand All @@ -210,20 +221,38 @@ def load(cls, sc, path):
raise NotImplemented


@inherit_doc
class JavaLoader(Loader):
"""
Mixin for classes which can load saved models using its Scala
implementation.
"""

@classmethod
def load(cls, sc, path):
def _java_loader_class(cls):
"""
Returns the full class name of the Java loader. The default
implementation replaces "pyspark" by "org.apache.spark" in
the Python full class name.
"""
java_package = cls.__module__.replace("pyspark", "org.apache.spark")
java_class = ".".join([java_package, cls.__name__])
return ".".join([java_package, cls.__name__])

@classmethod
def _load_java(cls, sc, path):
"""
Load a Java model from the given path.
"""
java_class = cls._java_loader_class()
java_obj = sc._jvm
for name in java_class.split("."):
java_obj = getattr(java_obj, name)
return cls(java_obj.load(sc._jsc.sc(), path))
return java_obj.load(sc._jsc.sc(), path)

@classmethod
def load(cls, sc, path):
java_model = cls._load_java(sc, path)
return cls(java_model)


def _test():
Expand Down

0 comments on commit 7e53a79

Please sign in to comment.