Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/main/scala/com/springml/spark/sftp/DefaultSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,30 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
parameters: Map[String, String],
data: DataFrame): BaseRelation = {


val username = parameters.get("username")
val password = parameters.get("password")
val pemFileLocation = parameters.get("pem")
val pemPassphrase = parameters.get("pemPassphrase")
val pemPassphrase = parameters.get("pempassphrase")
val host = parameters.getOrElse("host", sys.error("SFTP Host has to be provided using 'host' option"))
val port = parameters.get("port")
val path = parameters.getOrElse("path", sys.error("'path' must be specified"))
val fileType = parameters.getOrElse("fileType", sys.error("File type has to be provided using 'fileType' option"))
val fileType = parameters.getOrElse("filetype", sys.error("File type has to be provided using 'fileType' option"))
val header = parameters.getOrElse("header", "true")
val copyLatest = parameters.getOrElse("copyLatest", "false")
val tmpFolder = parameters.getOrElse("tempLocation", System.getProperty("java.io.tmpdir"))
val hdfsTemp = parameters.getOrElse("hdfsTempLocation", tmpFolder)
val cryptoKey = parameters.getOrElse("cryptoKey", null)
val cryptoAlgorithm = parameters.getOrElse("cryptoAlgorithm", "AES")
val copyLatest = parameters.getOrElse("copylatest", "false")
val tmpFolder = parameters.getOrElse("templocation", System.getProperty("java.io.tmpdir"))
val hdfsTemp = parameters.getOrElse("hdfstemplocation", tmpFolder)
val cryptoKey = parameters.getOrElse("cryptokey", null)
val cryptoAlgorithm = parameters.getOrElse("cryptoalgorithm", "AES")
val delimiter = parameters.getOrElse("delimiter", ",")
val quote = parameters.getOrElse("quote", "\"")
val escape = parameters.getOrElse("escape", "\\")
val multiLine = parameters.getOrElse("multiLine", "false")
val multiLine = parameters.getOrElse("multiline", "false")
val codec = parameters.getOrElse("codec", null)
val rowTag = parameters.getOrElse(constants.xmlRowTag, null)
val rootTag = parameters.getOrElse(constants.xmlRootTag, null)
val azureMountPoint = parameters.get("azuremountpoint")
val gen = parameters.get("gen")

val supportedFileTypes = List("csv", "json", "avro", "parquet", "txt", "xml","orc")
if (!supportedFileTypes.contains(fileType)) {
Expand All @@ -127,7 +130,8 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr

val sftpClient = getSFTPClient(username, password, pemFileLocation, pemPassphrase, host, port,
cryptoKey, cryptoAlgorithm)
val tempFile = writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape, multiLine, codec, rowTag, rootTag)
val tempFile = writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape,
multiLine, codec, rowTag, rootTag, azureMountPoint,gen.getOrElse("gen1"))

upload(tempFile, path, sftpClient)
return createReturnRelation(data)
Expand Down Expand Up @@ -234,11 +238,12 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr

private def writeToTemp(sqlContext: SQLContext, df: DataFrame,
hdfsTemp: String, tempFolder: String, fileType: String, header: String,
delimiter: String, quote: String, escape: String, multiLine: String, codec: String, rowTag: String, rootTag: String) : String = {
delimiter: String, quote: String, escape: String, multiLine: String,
codec: String, rowTag: String, rootTag: String,
azureMountPoint: Option[String], gen: String) : String = {
val randomSuffix = "spark_sftp_connection_temp_" + UUID.randomUUID
val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix
val localTempLocation = tempFolder + File.separator + randomSuffix

val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix + File.separator
val localTempLocation = tempFolder + File.separator + randomSuffix + File.separator
addShutdownHook(localTempLocation)

fileType match {
Expand All @@ -261,7 +266,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
}

copyFromHdfs(sqlContext, hdfsTempLocation, localTempLocation)
copiedFile(localTempLocation)
copiedFile(localTempLocation, azureMountPoint, gen)
}

private def addShutdownHook(tempLocation: String) {
Expand All @@ -270,8 +275,15 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
Runtime.getRuntime.addShutdownHook(hook)
}

private def copiedFile(tempFileLocation: String) : String = {
val baseTemp = new File(tempFileLocation)
private def copiedFile(tempFileLocation: String, azureMountPoint: Option[String], gen: String) : String = {
val tempFile = azureMountPoint.map{
mountPoint =>
gen match{
case "gen1" => mountPoint+File.separator+tempFileLocation.drop(tempFileLocation.indexOf("azuredatalakestore.net")+22)
case "gen2" => mountPoint+ tempFileLocation.substring(tempFileLocation.substring(0,tempFileLocation.lastIndexOf("/")).lastIndexOf("/"))
}
}.getOrElse(tempFileLocation)
val baseTemp = new File(tempFile)
val files = baseTemp.listFiles().filter { x =>
(!x.isDirectory()
&& !x.getName.contains("SUCCESS")
Expand All @@ -281,6 +293,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr
&& !x.getName.contains("_started_")
)
}
files(0).getAbsolutePath
println(files)
files.head.getAbsolutePath
}
}