diff --git a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala index a62e57a..8567b6e 100644 --- a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala +++ b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala @@ -98,27 +98,30 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr parameters: Map[String, String], data: DataFrame): BaseRelation = { + val username = parameters.get("username") val password = parameters.get("password") val pemFileLocation = parameters.get("pem") - val pemPassphrase = parameters.get("pemPassphrase") + val pemPassphrase = parameters.get("pempassphrase") val host = parameters.getOrElse("host", sys.error("SFTP Host has to be provided using 'host' option")) val port = parameters.get("port") val path = parameters.getOrElse("path", sys.error("'path' must be specified")) - val fileType = parameters.getOrElse("fileType", sys.error("File type has to be provided using 'fileType' option")) + val fileType = parameters.getOrElse("filetype", sys.error("File type has to be provided using 'fileType' option")) val header = parameters.getOrElse("header", "true") - val copyLatest = parameters.getOrElse("copyLatest", "false") - val tmpFolder = parameters.getOrElse("tempLocation", System.getProperty("java.io.tmpdir")) - val hdfsTemp = parameters.getOrElse("hdfsTempLocation", tmpFolder) - val cryptoKey = parameters.getOrElse("cryptoKey", null) - val cryptoAlgorithm = parameters.getOrElse("cryptoAlgorithm", "AES") + val copyLatest = parameters.getOrElse("copylatest", "false") + val tmpFolder = parameters.getOrElse("templocation", System.getProperty("java.io.tmpdir")) + val hdfsTemp = parameters.getOrElse("hdfstemplocation", tmpFolder) + val cryptoKey = parameters.getOrElse("cryptokey", null) + val cryptoAlgorithm = parameters.getOrElse("cryptoalgorithm", "AES") val delimiter = parameters.getOrElse("delimiter", ",") val quote = parameters.getOrElse("quote", "\"") val escape = parameters.getOrElse("escape", "\\") - val multiLine = parameters.getOrElse("multiLine", "false") + val multiLine = parameters.getOrElse("multiline", "false") val codec = parameters.getOrElse("codec", null) val rowTag = parameters.getOrElse(constants.xmlRowTag, null) val rootTag = parameters.getOrElse(constants.xmlRootTag, null) + val azureMountPoint = parameters.get("azuremountpoint") + val gen = parameters.get("gen") val supportedFileTypes = List("csv", "json", "avro", "parquet", "txt", "xml","orc") if (!supportedFileTypes.contains(fileType)) { @@ -127,7 +130,8 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr val sftpClient = getSFTPClient(username, password, pemFileLocation, pemPassphrase, host, port, cryptoKey, cryptoAlgorithm) - val tempFile = writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape, multiLine, codec, rowTag, rootTag) + val tempFile = writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape, + multiLine, codec, rowTag, rootTag, azureMountPoint,gen.getOrElse("gen1")) upload(tempFile, path, sftpClient) return createReturnRelation(data) @@ -234,11 +238,12 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr private def writeToTemp(sqlContext: SQLContext, df: DataFrame, hdfsTemp: String, tempFolder: String, fileType: String, header: String, - delimiter: String, quote: String, escape: String, multiLine: String, codec: String, rowTag: String, rootTag: String) : String = { + delimiter: String, quote: String, escape: String, multiLine: String, + codec: String, rowTag: String, rootTag: String, + azureMountPoint: Option[String], gen: String) : String = { val randomSuffix = "spark_sftp_connection_temp_" + UUID.randomUUID - val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix - val localTempLocation = tempFolder + File.separator + randomSuffix - + val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix + File.separator + val localTempLocation = tempFolder + File.separator + randomSuffix + File.separator addShutdownHook(localTempLocation) fileType match { @@ -261,7 +266,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr } copyFromHdfs(sqlContext, hdfsTempLocation, localTempLocation) - copiedFile(localTempLocation) + copiedFile(localTempLocation, azureMountPoint, gen) } private def addShutdownHook(tempLocation: String) { @@ -270,8 +275,15 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr Runtime.getRuntime.addShutdownHook(hook) } - private def copiedFile(tempFileLocation: String) : String = { - val baseTemp = new File(tempFileLocation) + private def copiedFile(tempFileLocation: String, azureMountPoint: Option[String], gen: String) : String = { + val tempFile = azureMountPoint.map{ + mountPoint => + gen match{ + case "gen1" => mountPoint+File.separator+tempFileLocation.drop(tempFileLocation.indexOf("azuredatalakestore.net")+22) + case "gen2" => mountPoint+ tempFileLocation.substring(tempFileLocation.substring(0,tempFileLocation.lastIndexOf("/")).lastIndexOf("/")) + } + }.getOrElse(tempFileLocation) + val baseTemp = new File(tempFile) val files = baseTemp.listFiles().filter { x => (!x.isDirectory() && !x.getName.contains("SUCCESS") @@ -281,6 +293,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr && !x.getName.contains("_started_") ) } - files(0).getAbsolutePath + println(files) + files.head.getAbsolutePath } }