diff --git a/README.md b/README.md index a9c2f43..b02c78b 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,19 @@ df.write. option("delimiter", ";"). option("codec", "bzip2"). save("/ftp/files/sample.csv") + +// Write the input DataFrame as up to n part files by setting the coalesce partition count (default: 1) +spark.conf.set("spark.sftp.coalesce.partitions", 4) + +df.write. + format("com.springml.spark.sftp"). + option("host", "SFTP_HOST"). + option("username", "SFTP_USER"). + option("password", "****"). + option("fileType", "csv"). + option("delimiter", ";"). + save("/ftp/files/") + // Construct spark dataframe using text file in FTP server diff --git a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala index a62e57a..47a284f 100644 --- a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala +++ b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala @@ -20,7 +20,6 @@ import java.util.UUID import com.springml.sftp.client.SFTPClient import com.springml.spark.sftp.util.Utils.ImplicitDataFrameWriter - import org.apache.commons.io.FilenameUtils import org.apache.hadoop.fs.Path import org.apache.log4j.Logger @@ -238,15 +237,17 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr val randomSuffix = "spark_sftp_connection_temp_" + UUID.randomUUID val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix val localTempLocation = tempFolder + File.separator + randomSuffix + val numPartitions = sqlContext.getConf(constants.coalescePartitionsConfKey, "1").toInt + logger.info(s"Applying coalesce with numPartitions=$numPartitions on the input DataFrame!!") addShutdownHook(localTempLocation) fileType match { - case "xml" => df.coalesce(1).write.format(constants.xmlClass) + case "xml" => df.coalesce(numPartitions).write.format(constants.xmlClass) .option(constants.xmlRowTag, rowTag) .option(constants.xmlRootTag, rootTag).save(hdfsTempLocation) - case "csv" => df.coalesce(1). 
+ case "csv" => df.coalesce(numPartitions). write. option("header", header). option("delimiter", delimiter). @@ -255,9 +256,9 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr option("multiLine", multiLine). optionNoNull("codec", Option(codec)). csv(hdfsTempLocation) - case "txt" => df.coalesce(1).write.text(hdfsTempLocation) - case "avro" => df.coalesce(1).write.format("com.databricks.spark.avro").save(hdfsTempLocation) - case _ => df.coalesce(1).write.format(fileType).save(hdfsTempLocation) + case "txt" => df.coalesce(numPartitions).write.text(hdfsTempLocation) + case "avro" => df.coalesce(numPartitions).write.format("com.databricks.spark.avro").save(hdfsTempLocation) + case _ => df.coalesce(numPartitions).write.format(fileType).save(hdfsTempLocation) } copyFromHdfs(sqlContext, hdfsTempLocation, localTempLocation) @@ -281,6 +282,16 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr && !x.getName.contains("_started_") ) } - files(0).getAbsolutePath + if (files.length > 1) { + val partsDir = new File(baseTemp, "parts") + partsDir.mkdir() + val ext: String = files.head.getName.split("\\.", 2)(1) + for (i <- files.indices) { + files(i).renameTo(new File(partsDir, s"part-$i.$ext")) + } + partsDir.getAbsolutePath + } else { + files(0).getAbsolutePath + } } } diff --git a/src/main/scala/com/springml/spark/sftp/constants.scala b/src/main/scala/com/springml/spark/sftp/constants.scala index 262a64d..cc9c0ff 100644 --- a/src/main/scala/com/springml/spark/sftp/constants.scala +++ b/src/main/scala/com/springml/spark/sftp/constants.scala @@ -8,5 +8,6 @@ object constants { val xmlClass: String = "com.databricks.spark.xml" val xmlRowTag: String = "rowTag" val xmlRootTag: String = "rootTag" + val coalescePartitionsConfKey: String = "spark.sftp.coalesce.partitions" }