ScoringModelUtils.scala
@@ -16,12 +16,13 @@
 package org.trustedanalytics.sparktk.models

 import java.io.{ FileOutputStream, File }
-import java.net.URI
+import java.net.{ URL, URI }
 import java.nio.DoubleBuffer
 import java.nio.file.{ Files, Path }
 import org.apache.commons.lang.StringUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.commons.io.{ IOUtils, FileUtils }
+import org.apache.spark.SparkContext
 import org.trustedanalytics.model.archive.format.ModelArchiveFormat
 import org.trustedanalytics.sparktk.saveload.SaveLoad

@@ -91,7 +92,8 @@ object ScoringModelUtils {
    * @param sourcePath Path to source location. Defaults to use the path to the currently running jar.
    * @return full path to the location of the MAR file for Scoring Engine
    */
-  def saveToMar(marSavePath: String,
+  def saveToMar(sc: SparkContext,
+                marSavePath: String,
                 modelClass: String,
                 modelSrcDir: java.nio.file.Path,
                 modelReader: String = classOf[SparkTkModelAdapter].getName,
@@ -114,20 +116,21 @@
       val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/")))
       var jarFileList = x.jarsInSearchPath.values.toList

-      if (marSavePath.startsWith("hdfs")) {
+      val protocol = SaveLoad.getProtocol(marSavePath)
+
+      if ("file".equalsIgnoreCase(protocol)) {
+        jarFileList = jarFileList ::: List(new File(modelSrcDir.toString))
+      }
+      else {
         val modelFile = Files.createTempDirectory("localModel")
         val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString)
         val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration())
         hdfsFileSystem.copyToLocalFile(new org.apache.hadoop.fs.Path(modelSrcDir.toString), localModelPath)
         jarFileList = jarFileList ::: List(new File(localModelPath.toString))
       }
       ModelArchiveFormat.write(jarFileList, modelReader, modelClass, zipOutStream)
-
     }
-    SaveLoad.saveMar(marSavePath, zipFile)
+    SaveLoad.saveMar(sc, marSavePath, zipFile)
   }
   finally {
     FileUtils.deleteQuietly(zipFile)
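With this change, saveToMar threads the caller's SparkContext through to SaveLoad.saveMar and branches on SaveLoad.getProtocol(marSavePath) instead of a startsWith("hdfs") prefix test, so file:// URIs, bare local paths, and hdfs:// URIs are all classified uniformly. The following is a minimal sketch of the two call paths, not part of the PR: it assumes an existing SparkContext, a temp directory that already holds a saved model (as in each model's export-to-MAR method below), hypothetical destination paths, and KMeansModel's package as inferred from the repo layout.

import java.nio.file.Path
import org.apache.spark.SparkContext
import org.trustedanalytics.sparktk.models.ScoringModelUtils
import org.trustedanalytics.sparktk.models.clustering.kmeans.KMeansModel

def exportExamples(sc: SparkContext, tmpDir: Path): Unit = {
  // A bare local path resolves to the "file" protocol, so the saved model
  // directory is bundled straight into the MAR archive.
  val localMarPath = ScoringModelUtils.saveToMar(sc, "/tmp/models/kmeans.mar",
    classOf[KMeansModel].getName, tmpDir)

  // An hdfs:// URI takes the non-"file" branch: the model directory is staged
  // through a local temp copy, and the finished MAR is then uploaded by
  // SaveLoad.saveMar(sc, ...) using sc.hadoopConfiguration.
  val hdfsMarPath = ScoringModelUtils.saveToMar(sc,
    "hdfs://namenode:8020/user/models/kmeans.mar",
    classOf[KMeansModel].getName, tmpDir)
}

A side effect of keying on the protocol rather than an "hdfs" prefix is that any non-"file" scheme now takes the remote branch instead of being treated as a local file.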
LogisticRegressionModel.scala
@@ -442,7 +442,7 @@ case class LogisticRegressionModel private[logistic_regression] (observationColu
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[LogisticRegressionModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LogisticRegressionModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
NaiveBayesModel.scala
@@ -217,7 +217,7 @@ case class NaiveBayesModel private[naive_bayes] (sparkModel: SparkNaiveBayesMode
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[NaiveBayesModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[NaiveBayesModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
RandomForestClassifierModel.scala
@@ -348,7 +348,7 @@ case class RandomForestClassifierModel private[random_forest_classifier] (sparkM
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString, overwrite = true)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
SvmModel.scala
@@ -281,7 +281,7 @@ case class SvmModel private[svm] (sparkModel: SparkSvmModel,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[SvmModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[SvmModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
GaussianMixtureModel.scala
@@ -244,7 +244,7 @@ case class GaussianMixtureModel private[gmm] (observationColumns: Seq[String],
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[GaussianMixtureModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[GaussianMixtureModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
KMeansModel.scala
@@ -312,7 +312,7 @@ case class KMeansModel private[kmeans] (columns: Seq[String],
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[KMeansModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[KMeansModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
LdaModel.scala
@@ -296,7 +296,7 @@ case class LdaModel private[lda] (documentColumnName: String,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[LdaModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LdaModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
PcaModel.scala
@@ -211,7 +211,7 @@ case class PcaModel private[pca] (columns: Seq[String],
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[PcaModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[PcaModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
CollaborativeFilteringModel.scala
@@ -319,7 +319,7 @@ case class CollaborativeFilteringModel(sourceColumnName: String,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
LinearRegressionModel.scala
@@ -291,7 +291,7 @@ case class LinearRegressionModel(observationColumns: Seq[String],
       // The spark linear regression model save will fail, if we don't specify the "overwrite", since the temp
       // directory has already been created.
       save(sc, tmpDir.toString, overwrite = true)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[LinearRegressionModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LinearRegressionModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
RandomForestRegressorModel.scala
@@ -341,7 +341,7 @@ case class RandomForestRegressorModel private[random_forest_regressor] (sparkMod
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString, overwrite = true)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
CoxProportionalHazardsModel.scala
@@ -239,7 +239,7 @@ case class CoxProportionalHazardsModel private[cox_ph] (sparkModel: CoxPhModel,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString, overwrite = true)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
ArimaModel.scala
@@ -267,7 +267,7 @@ case class ArimaModel private[arima] (ts: DenseVector,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
ArimaxModel.scala
@@ -255,7 +255,7 @@ case class ArimaxModel private[arimax] (timeseriesColumn: String,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaxModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaxModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
ArxModel.scala
@@ -234,7 +234,7 @@ case class ArxModel private[arx] (timeseriesColumn: String,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[ArxModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArxModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
MaxModel.scala
@@ -250,7 +250,7 @@ case class MaxModel private[max] (timeseriesColumn: String,
     try {
       tmpDir = Files.createTempDirectory("sparktk-scoring-model")
       save(sc, tmpDir.toString)
-      ScoringModelUtils.saveToMar(marSavePath, classOf[MaxModel].getName, tmpDir)
+      ScoringModelUtils.saveToMar(sc, marSavePath, classOf[MaxModel].getName, tmpDir)
     }
     finally {
       sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
SaveLoad.scala
@@ -16,9 +16,10 @@
 package org.trustedanalytics.sparktk.saveload

 import java.io.File
+import java.nio.file.Files
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.fs.Path
-import java.net.URI
+import java.net.{ URL, URI }
 import org.apache.hadoop.fs.permission.{ FsPermission, FsAction }
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkContext
@@ -28,6 +29,7 @@ import org.json4s.jackson.Serialization
 import org.json4s.{ NoTypeHints, Extraction, DefaultFormats }
 import org.json4s.jackson.JsonMethods._
 import org.json4s.JsonDSL._
+import org.trustedanalytics.sparktk.models.ScoringModelUtils

 /**
  * Simple save/load library which uses json4s to read/write text files, including info for format validation
@@ -56,21 +58,58 @@ object SaveLoad {
    * @param zipFile the MAR file to be stored
    * @return full path to the location of the MAR file
    */
-  def saveMar(storagePath: String, zipFile: File): String = {
-    if (storagePath.startsWith("hdfs")) {
+  def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = {
+
+    val protocol = getProtocol(storagePath)
+
+    if ("file".equalsIgnoreCase(protocol)) {
+      val file = new File(storagePath)
+      FileUtils.copyFile(zipFile, file)
+      file.getCanonicalPath
+    }
+    else {
       val hdfsPath = new Path(storagePath)
-      val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), new Configuration())
+      val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), sc.hadoopConfiguration)
       val localPath = new Path(zipFile.getAbsolutePath)
       hdfsFileSystem.copyFromLocalFile(false, true, localPath, hdfsPath)
       hdfsFileSystem.setPermission(hdfsPath, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.NONE))
       storagePath
     }
-    else {
-      val file = new File(storagePath)
-      FileUtils.copyFile(zipFile, file)
-      file.getCanonicalPath
-    }
   }
+
+  /**
+   * Returns the protocol for a given URI or filename.
+   *
+   * @param source Determine the protocol for this URI or filename.
+   *
+   * @return The protocol for the given source.
+   */
+  def getProtocol(source: String): String = {
+    require(source != null && !source.isEmpty, "marfile source must not be null")
+
+    val protocol: String = try {
+      val uri = new URI(source)
+
+      if (uri.isAbsolute) {
+        uri.getScheme
+      }
+      else {
+        val url = new URL(source)
+        url.getProtocol
+      }
+
+    }
+    catch {
+      case ex: Exception =>
+        if (source.startsWith("//")) {
+          throw new IllegalArgumentException("Does not support Relative context starting with // : " + source)
+        }
+        else {
+          val file = new File(source)
+          file.toURI.toURL.getProtocol
+        }
+    }
+    protocol
+  }

 /**
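Both call sites dispatch on the new getProtocol helper. The snippet below is a rough sketch of the behavior implied by the implementation above, with hypothetical inputs; the absolute-URI branch returns the URI scheme, while the fallback goes through java.net.URL and java.io.File, so results for unusual inputs follow the JVM's URL handling.

// Expected results, per the implementation above (inputs are made-up examples):
SaveLoad.getProtocol("hdfs://namenode:8020/user/models/m.mar") // "hdfs" -- absolute URI, scheme returned
SaveLoad.getProtocol("file:///tmp/m.mar")                      // "file" -- absolute URI, scheme returned
SaveLoad.getProtocol("/tmp/m.mar")                             // "file" -- no scheme; new URL(...) fails, so the
                                                               //           File -> toURI -> toURL fallback applies
SaveLoad.getProtocol("//host/share/m.mar")                     // throws IllegalArgumentException -- inputs
                                                               //           starting with "//" are rejected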