From 02481448cb26c83e27e8456340c07ef22e631e49 Mon Sep 17 00:00:00 2001 From: Anjali Sood Date: Mon, 6 Feb 2017 09:45:22 -0800 Subject: [PATCH 1/2] AWS S3 support --- .../sparktk/models/ScoringModelUtils.scala | 75 +++++++++++++++++-- .../LogisticRegressionModel.scala | 2 +- .../naive_bayes/NaiveBayesModel.scala | 2 +- .../RandomForestClassifierModel.scala | 2 +- .../models/classification/svm/SvmModel.scala | 2 +- .../clustering/gmm/GaussianMixtureModel.scala | 2 +- .../clustering/kmeans/KMeansModel.scala | 2 +- .../models/clustering/lda/LdaModel.scala | 2 +- .../models/dimreduction/pca/PcaModel.scala | 2 +- .../CollaborativeFilteringModel.scala | 2 +- .../LinearRegressionModel.scala | 2 +- .../RandomForestRegressorModel.scala | 2 +- .../cox_ph/CoxProportionalHazardsModel.scala | 2 +- .../models/timeseries/arima/ArimaModel.scala | 2 +- .../timeseries/arimax/ArimaxModel.scala | 2 +- .../models/timeseries/arx/ArxModel.scala | 2 +- .../models/timeseries/max/MaxModel.scala | 2 +- .../sparktk/saveload/SaveLoad.scala | 23 +++--- 18 files changed, 97 insertions(+), 33 deletions(-) diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala index b1a277e9..9706ee13 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala @@ -16,12 +16,13 @@ package org.trustedanalytics.sparktk.models import java.io.{ FileOutputStream, File } -import java.net.URI +import java.net.{ URL, URI } import java.nio.DoubleBuffer import java.nio.file.{ Files, Path } import org.apache.commons.lang.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.commons.io.{ IOUtils, FileUtils } +import org.apache.spark.SparkContext import org.trustedanalytics.model.archive.format.ModelArchiveFormat import org.trustedanalytics.sparktk.saveload.SaveLoad @@ -91,7 +92,8 @@ object ScoringModelUtils { * @param sourcePath Path to source location. Defaults to use the path to the currently running jar. * @return full path to the location of the MAR file for Scoring Engine */ - def saveToMar(marSavePath: String, + def saveToMar(sc: SparkContext, + marSavePath: String, modelClass: String, modelSrcDir: java.nio.file.Path, modelReader: String = classOf[SparkTkModelAdapter].getName, @@ -114,24 +116,81 @@ object ScoringModelUtils { val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/"))) var jarFileList = x.jarsInSearchPath.values.toList - if (marSavePath.startsWith("hdfs")) { + val protocol = getProtocol(marSavePath) + + if ("file".equalsIgnoreCase(protocol)) { + print("Local") + jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) + } + else { + print("not local") val modelFile = Files.createTempDirectory("localModel") val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString) val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration()) hdfsFileSystem.copyToLocalFile(new org.apache.hadoop.fs.Path(modelSrcDir.toString), localModelPath) jarFileList = jarFileList ::: List(new File(localModelPath.toString)) } - else { - jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) - } ModelArchiveFormat.write(jarFileList, modelReader, modelClass, zipOutStream) - } - SaveLoad.saveMar(marSavePath, zipFile) + SaveLoad.saveMar(sc, marSavePath, zipFile) } finally { FileUtils.deleteQuietly(zipFile) IOUtils.closeQuietly(zipOutStream) } } + + /** + * Returns the protocol for a given URI or filename. + * + * @param source Determine the protocol for this URI or filename. + * + * @return The protocol for the given source. + */ + def getProtocol(source: String): String = { + require(source != null, "marfile source must not be null") + + var protocol: String = null + try { + val uri = new URI(source) + + if (uri.isAbsolute) { + protocol = uri.getScheme + } + else { + val url = new URL(source) + protocol = url.getProtocol + } + + } + catch { + case ex: Exception => + if (source.startsWith("//")) { + throw new IllegalArgumentException("Relative context: " + source) + } + else { + val file = new File(source) + protocol = getProtocol(file) + } + } + protocol + } + + /** + * Returns the protocol for a given file. + * + * @param file Determine the protocol for this file. + * + * @return The protocol for the given file. + */ + private def getProtocol(file: File): String = { + var result: String = null + try { + result = file.toURI.toURL.getProtocol + } + catch { + case ex: Exception => result = "unknown" + } + result + } } \ No newline at end of file diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala index 1190b5a4..c1dc8d7d 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/logistic_regression/LogisticRegressionModel.scala @@ -442,7 +442,7 @@ case class LogisticRegressionModel private[logistic_regression] (observationColu try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[LogisticRegressionModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LogisticRegressionModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala index a9d1b8e6..6d51d043 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/naive_bayes/NaiveBayesModel.scala @@ -217,7 +217,7 @@ case class NaiveBayesModel private[naive_bayes] (sparkModel: SparkNaiveBayesMode try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[NaiveBayesModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[NaiveBayesModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala index 42283436..24c8940c 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/random_forest_classifier/RandomForestClassifierModel.scala @@ -348,7 +348,7 @@ case class RandomForestClassifierModel private[random_forest_classifier] (sparkM try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala index a1a65ae3..c98cc761 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/classification/svm/SvmModel.scala @@ -281,7 +281,7 @@ case class SvmModel private[svm] (sparkModel: SparkSvmModel, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[SvmModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[SvmModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala index 4a6f2dc7..a4288733 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/gmm/GaussianMixtureModel.scala @@ -244,7 +244,7 @@ case class GaussianMixtureModel private[gmm] (observationColumns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[GaussianMixtureModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[GaussianMixtureModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala index d677b433..855a65a2 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/kmeans/KMeansModel.scala @@ -312,7 +312,7 @@ case class KMeansModel private[kmeans] (columns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[KMeansModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[KMeansModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala index d8ca50fd..5a377b2d 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/clustering/lda/LdaModel.scala @@ -296,7 +296,7 @@ case class LdaModel private[lda] (documentColumnName: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[LdaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LdaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala index 2c01520e..4281c931 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/dimreduction/pca/PcaModel.scala @@ -211,7 +211,7 @@ case class PcaModel private[pca] (columns: Seq[String], try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[PcaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[PcaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala index c8c85844..3fa252ce 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/recommendation/collaborative_filtering/CollaborativeFilteringModel.scala @@ -319,7 +319,7 @@ case class CollaborativeFilteringModel(sourceColumnName: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala index 608d8461..344434af 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/linear_regression/LinearRegressionModel.scala @@ -291,7 +291,7 @@ case class LinearRegressionModel(observationColumns: Seq[String], // The spark linear regression model save will fail, if we don't specify the "overwrite", since the temp // directory has already been created. save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[LinearRegressionModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LinearRegressionModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala index cc1da109..cfb0c0a0 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/regression/random_forest_regressor/RandomForestRegressorModel.scala @@ -341,7 +341,7 @@ case class RandomForestRegressorModel private[random_forest_regressor] (sparkMod try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala index d73a648c..cc145d55 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/survivalanalysis/cox_ph/CoxProportionalHazardsModel.scala @@ -239,7 +239,7 @@ case class CoxProportionalHazardsModel private[cox_ph] (sparkModel: CoxPhModel, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString, overwrite = true) - ScoringModelUtils.saveToMar(marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala index f44d29ce..c464712e 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arima/ArimaModel.scala @@ -267,7 +267,7 @@ case class ArimaModel private[arima] (ts: DenseVector, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala index 331cae32..e63eb390 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arimax/ArimaxModel.scala @@ -255,7 +255,7 @@ case class ArimaxModel private[arimax] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala index d81d13c6..e5596155 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/arx/ArxModel.scala @@ -234,7 +234,7 @@ case class ArxModel private[arx] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[ArxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala index 7338646e..0342d5ec 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/timeseries/max/MaxModel.scala @@ -250,7 +250,7 @@ case class MaxModel private[max] (timeseriesColumn: String, try { tmpDir = Files.createTempDirectory("sparktk-scoring-model") save(sc, tmpDir.toString) - ScoringModelUtils.saveToMar(marSavePath, classOf[MaxModel].getName, tmpDir) + ScoringModelUtils.saveToMar(sc, marSavePath, classOf[MaxModel].getName, tmpDir) } finally { sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala index 2ce2c03f..f225a926 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala @@ -16,6 +16,7 @@ package org.trustedanalytics.sparktk.saveload import java.io.File +import java.nio.file.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import java.net.URI @@ -28,6 +29,7 @@ import org.json4s.jackson.Serialization import org.json4s.{ NoTypeHints, Extraction, DefaultFormats } import org.json4s.jackson.JsonMethods._ import org.json4s.JsonDSL._ +import org.trustedanalytics.sparktk.models.ScoringModelUtils /** * Simple save/load library which uses json4s to read/write text files, including info for format validation @@ -56,21 +58,24 @@ object SaveLoad { * @param zipFile the MAR file to be stored * @return full path to the location of the MAR file */ - def saveMar(storagePath: String, zipFile: File): String = { - if (storagePath.startsWith("hdfs")) { + def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = { + + val protocol = ScoringModelUtils.getProtocol(storagePath) + + if ("file".equalsIgnoreCase(protocol)) { + print("Local") + val file = new File(storagePath) + FileUtils.copyFile(zipFile, file) + file.getCanonicalPath + } + else { val hdfsPath = new Path(storagePath) - val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), new Configuration()) + val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), sc.hadoopConfiguration) val localPath = new Path(zipFile.getAbsolutePath) hdfsFileSystem.copyFromLocalFile(false, true, localPath, hdfsPath) hdfsFileSystem.setPermission(hdfsPath, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.NONE)) storagePath } - else { - val file = new File(storagePath) - FileUtils.copyFile(zipFile, file) - file.getCanonicalPath - } - } /** From 87ec88155ea1038386122f5059ab12d6c983b45a Mon Sep 17 00:00:00 2001 From: Anjali Sood Date: Tue, 7 Feb 2017 11:01:43 -0800 Subject: [PATCH 2/2] updated with code review feedback --- .../sparktk/models/ScoringModelUtils.scala | 58 +------------------ .../sparktk/saveload/SaveLoad.scala | 40 ++++++++++++- 2 files changed, 38 insertions(+), 60 deletions(-) diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala index 9706ee13..52577d6b 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/models/ScoringModelUtils.scala @@ -116,14 +116,12 @@ object ScoringModelUtils { val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/"))) var jarFileList = x.jarsInSearchPath.values.toList - val protocol = getProtocol(marSavePath) + val protocol = SaveLoad.getProtocol(marSavePath) if ("file".equalsIgnoreCase(protocol)) { - print("Local") jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) } else { - print("not local") val modelFile = Files.createTempDirectory("localModel") val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString) val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration()) @@ -139,58 +137,4 @@ object ScoringModelUtils { IOUtils.closeQuietly(zipOutStream) } } - - /** - * Returns the protocol for a given URI or filename. - * - * @param source Determine the protocol for this URI or filename. - * - * @return The protocol for the given source. - */ - def getProtocol(source: String): String = { - require(source != null, "marfile source must not be null") - - var protocol: String = null - try { - val uri = new URI(source) - - if (uri.isAbsolute) { - protocol = uri.getScheme - } - else { - val url = new URL(source) - protocol = url.getProtocol - } - - } - catch { - case ex: Exception => - if (source.startsWith("//")) { - throw new IllegalArgumentException("Relative context: " + source) - } - else { - val file = new File(source) - protocol = getProtocol(file) - } - } - protocol - } - - /** - * Returns the protocol for a given file. - * - * @param file Determine the protocol for this file. - * - * @return The protocol for the given file. - */ - private def getProtocol(file: File): String = { - var result: String = null - try { - result = file.toURI.toURL.getProtocol - } - catch { - case ex: Exception => result = "unknown" - } - result - } } \ No newline at end of file diff --git a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala index f225a926..ee6b2767 100644 --- a/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala +++ b/sparktk-core/src/main/scala/org/trustedanalytics/sparktk/saveload/SaveLoad.scala @@ -19,7 +19,7 @@ import java.io.File import java.nio.file.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path -import java.net.URI +import java.net.{ URL, URI } import org.apache.hadoop.fs.permission.{ FsPermission, FsAction } import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext @@ -60,10 +60,9 @@ object SaveLoad { */ def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = { - val protocol = ScoringModelUtils.getProtocol(storagePath) + val protocol = getProtocol(storagePath) if ("file".equalsIgnoreCase(protocol)) { - print("Local") val file = new File(storagePath) FileUtils.copyFile(zipFile, file) file.getCanonicalPath @@ -78,6 +77,41 @@ object SaveLoad { } } + /** + * Returns the protocol for a given URI or filename. + * + * @param source Determine the protocol for this URI or filename. + * + * @return The protocol for the given source. + */ + def getProtocol(source: String): String = { + require(source != null && !source.isEmpty, "marfile source must not be null") + + val protocol: String = try { + val uri = new URI(source) + + if (uri.isAbsolute) { + uri.getScheme + } + else { + val url = new URL(source) + url.getProtocol + } + + } + catch { + case ex: Exception => + if (source.startsWith("//")) { + throw new IllegalArgumentException("Does not support Relative context starting with // : " + source) + } + else { + val file = new File(source) + file.toURI.toURL.getProtocol + } + } + protocol + } + /** * Loads data from a file into a json4s JValue and provides format identifier and version * @param sc active spark context