diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala index b1a277e9..52577d6b 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala @@ -16,12 +16,13 @@ package org.trustedanalytics.sparktk.models import java.io.{ FileOutputStream, File } -import java.net.URI +import java.net.{ URL, URI } import java.nio.DoubleBuffer import java.nio.file.{ Files, Path } import org.apache.commons.lang.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.commons.io.{ IOUtils, FileUtils } +import org.apache.spark.SparkContext import org.trustedanalytics.model.archive.format.ModelArchiveFormat import org.trustedanalytics.sparktk.saveload.SaveLoad @@ -91,7 +92,8 @@ object ScoringModelUtils { * @param sourcePath Path to source location. Defaults to use the path to the currently running jar. * @return full path to the location of the MAR file for Scoring Engine */ - def saveToMar(marSavePath: String, + def saveToMar(sc: SparkContext, + marSavePath: String, modelClass: String, modelSrcDir: java.nio.file.Path, modelReader: String = classOf[SparkTkModelAdapter].getName, @@ -114,20 +116,21 @@ object ScoringModelUtils { val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/"))) var jarFileList = x.jarsInSearchPath.values.toList - if (marSavePath.startsWith("hdfs")) { + val protocol = SaveLoad.getProtocol(marSavePath) + + if ("file".equalsIgnoreCase(protocol)) { + jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) + } + else { val modelFile = Files.createTempDirectory("localModel") val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString) val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration()) hdfsFileSystem.copyToLocalFile(new org.apache.hadoop.fs.Path(modelSrcDir.toString), localModelPath) jarFileList = jarFileList ::: List(new File(localModelPath.toString)) } - else { - jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) - } ModelArchiveFormat.write(jarFileList, modelReader, modelClass, zipOutStream) - } - SaveLoad.saveMar(marSavePath, zipFile) + SaveLoad.saveMar(sc, marSavePath, zipFile) } finally { FileUtils.deleteQuietly(zipFile) diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala index 1190b5a4..c1dc8d7d 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala @@ -442,7 +442,7 @@ case class LogisticRegressionModel private[logistic_regression] (observationColu try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[LogisticRegressionModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LogisticRegressionModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala index a9d1b8e6..6d51d043 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala @@ -217,7 +217,7 @@ case class NaiveBayesModel private[naive_bayes] (sparkModel: SparkNaiveBayesMode try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[NaiveBayesModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[NaiveBayesModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala index 42283436..24c8940c 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala @@ -348,7 +348,7 @@ case class RandomForestClassifierModel private[random_forest_classifier] (sparkM try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala index a1a65ae3..c98cc761 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala @@ -281,7 +281,7 @@ case class SvmModel private[svm] (sparkModel: SparkSvmModel, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[SvmModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[SvmModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala index 4a6f2dc7..a4288733 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala @@ -244,7 +244,7 @@ case class GaussianMixtureModel private[gmm] (observationColumns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[GaussianMixtureModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[GaussianMixtureModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala index d677b433..855a65a2 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala @@ -312,7 +312,7 @@ case class KMeansModel private[kmeans] (columns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[KMeansModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[KMeansModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala index d8ca50fd..5a377b2d 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala @@ -296,7 +296,7 @@ case class LdaModel private[lda] (documentColumnName: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[LdaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LdaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala index 2c01520e..4281c931 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala @@ -211,7 +211,7 @@ case class PcaModel private[pca] (columns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[PcaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[PcaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala index c8c85844..3fa252ce 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala @@ -319,7 +319,7 @@ case class CollaborativeFilteringModel(sourceColumnName: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala index 608d8461..344434af 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala @@ -291,7 +291,7 @@ case class LinearRegressionModel(observationColumns: Seq[String], // The spark linear regression model save will fail, if we don't specify the "overwrite", since the temp // directory has already been created. save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[LinearRegressionModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LinearRegressionModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala index cc1da109..cfb0c0a0 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala @@ -341,7 +341,7 @@ case class RandomForestRegressorModel private[random_forest_regressor] (sparkMod try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala index d73a648c..cc145d55 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala @@ -239,7 +239,7 @@ case class CoxProportionalHazardsModel private[cox_ph] (sparkModel: CoxPhModel, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala index f44d29ce..c464712e 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala @@ -267,7 +267,7 @@ case class ArimaModel private[arima] (ts: DenseVector, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala index 331cae32..e63eb390 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala @@ -255,7 +255,7 @@ case class ArimaxModel private[arimax] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala index d81d13c6..e5596155 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala @@ -234,7 +234,7 @@ case class ArxModel private[arx] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala index 7338646e..0342d5ec 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala @@ -250,7 +250,7 @@ case class MaxModel private[max] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[MaxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[MaxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala index 2ce2c03f..ee6b2767 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala @@ -16,9 +16,10 @@ package org.trustedanalytics.sparktk.saveload import java.io.File +import java.nio.file.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path -import java.net.URI +import java.net.{ URL, URI } import org.apache.hadoop.fs.permission.{ FsPermission, FsAction } import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext @@ -28,6 +29,7 @@ import org.json4s.jackson.Serialization import org.json4s.{ NoTypeHints, Extraction, DefaultFormats } import org.json4s.jackson.JsonMethods._ import org.json4s.JsonDSL._ +import org.trustedanalytics.sparktk.models.ScoringModelUtils /** * Simple save/load library which uses json4s to read/write text files, including info for format validation @@ -56,21 +58,58 @@ object SaveLoad { * @param zipFile the MAR file to be stored * @return full path to the location of the MAR file */ - def saveMar(storagePath: String, zipFile: File): String = { - if (storagePath.startsWith("hdfs")) { + def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = { + + val protocol = getProtocol(storagePath) + + if ("file".equalsIgnoreCase(protocol)) { + val file = new File(storagePath) + FileUtils.copyFile(zipFile, file) + file.getCanonicalPath + } + else { val hdfsPath = new Path(storagePath) - val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), new Configuration()) + val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), sc.hadoopConfiguration) val localPath = new Path(zipFile.getAbsolutePath) hdfsFileSystem.copyFromLocalFile(false, true, localPath, hdfsPath) hdfsFileSystem.setPermission(hdfsPath, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.NONE)) storagePath } - else { - val file = new File(storagePath) - FileUtils.copyFile(zipFile, file) - file.getCanonicalPath - } + } + + /** + * Returns the protocol for a given URI or filename. + * + * @param source Determine the protocol for this URI or filename. + * + * @return The protocol for the given source. + */ + def getProtocol(source: String): String = { + require(source != null && !source.isEmpty, "marfile source must not be null") + + val protocol: String = try { + val uri = new URI(source) + + if (uri.isAbsolute) { + uri.getScheme + } + else { + val url = new URL(source) + url.getProtocol + } + } + catch { + case ex: Exception => + if (source.startsWith("//")) { + throw new IllegalArgumentException("Does not support Relative context starting with // : " + source) + } + else { + val file = new File(source) + file.toURI.toURL.getProtocol + } + } + protocol } /**