diff --git a/README.md b/README.md index a9c2f43..a9f5863 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ This library requires following options: * `delimiter`: (Optional) Set the field delimiter. Applicable only for csv fileType. Default is comma. * `quote`: (Optional) Set the quote character. Applicable only for csv fileType. Default is ". * `escape`: (Optional) Set the escape character. Applicable only for csv fileType. Default is \. +* `encoding`: (Optional) Set the file encoding. Applicable only for csv fileType. Default is UTF-8. * `multiLine`: (Optional) Set the multiline. Applicable only for csv fileType. Default is false. * `codec`: (Optional) Applicable only for csv fileType. Compression codec to use when saving to file. Should be the fully qualified name of a class implementing org.apache.hadoop.io.compress.CompressionCodec or one of case-insensitive shorten names (bzip2, gzip, lz4, and snappy). Defaults to no compression when a codec is not specified. diff --git a/src/main/scala/com/springml/spark/sftp/DatasetRelation.scala b/src/main/scala/com/springml/spark/sftp/DatasetRelation.scala index 60b341c..65c75b9 100644 --- a/src/main/scala/com/springml/spark/sftp/DatasetRelation.scala +++ b/src/main/scala/com/springml/spark/sftp/DatasetRelation.scala @@ -18,6 +18,7 @@ case class DatasetRelation( delimiter: String, quote: String, escape: String, + encoding: String, multiLine: String, rowTag: String, customSchema: StructType, @@ -46,6 +47,7 @@ case class DatasetRelation( option("delimiter", delimiter). option("quote", quote). option("escape", escape). + option("encoding", encoding). option("multiLine", multiLine). option("inferSchema", inferSchema). 
csv(fileLocation) diff --git a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala index a62e57a..b2133fd 100644 --- a/src/main/scala/com/springml/spark/sftp/DefaultSource.scala +++ b/src/main/scala/com/springml/spark/sftp/DefaultSource.scala @@ -57,6 +57,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr val header = parameters.getOrElse("header", "true") val delimiter = parameters.getOrElse("delimiter", ",") val quote = parameters.getOrElse("quote", "\"") + val encoding = parameters.getOrElse("encoding", "UTF-8") val escape = parameters.getOrElse("escape", "\\") val multiLine = parameters.getOrElse("multiLine", "false") val createDF = parameters.getOrElse("createDF", "true") @@ -87,7 +88,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr logger.info("Returning an empty dataframe after copying files...") createReturnRelation(sqlContext, schema) } else { - DatasetRelation(fileLocation, fileType, inferSchemaFlag, header, delimiter, quote, escape, multiLine, rowTag, schema, + DatasetRelation(fileLocation, fileType, inferSchemaFlag, header, delimiter, quote, escape, encoding, multiLine, rowTag, schema, sqlContext) } } @@ -115,6 +116,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr val delimiter = parameters.getOrElse("delimiter", ",") val quote = parameters.getOrElse("quote", "\"") val escape = parameters.getOrElse("escape", "\\") + val encoding = parameters.getOrElse("encoding", "UTF-8") val multiLine = parameters.getOrElse("multiLine", "false") val codec = parameters.getOrElse("codec", null) val rowTag = parameters.getOrElse(constants.xmlRowTag, null) @@ -127,7 +129,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr val sftpClient = getSFTPClient(username, password, pemFileLocation, pemPassphrase, host, port, cryptoKey, cryptoAlgorithm) - val tempFile = 
writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape, multiLine, codec, rowTag, rootTag) + val tempFile = writeToTemp(sqlContext, data, hdfsTemp, tmpFolder, fileType, header, delimiter, quote, escape, encoding, multiLine, codec, rowTag, rootTag) upload(tempFile, path, sftpClient) return createReturnRelation(data) @@ -234,7 +236,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr private def writeToTemp(sqlContext: SQLContext, df: DataFrame, hdfsTemp: String, tempFolder: String, fileType: String, header: String, - delimiter: String, quote: String, escape: String, multiLine: String, codec: String, rowTag: String, rootTag: String) : String = { + delimiter: String, quote: String, escape: String, encoding: String, multiLine: String, codec: String, rowTag: String, rootTag: String) : String = { val randomSuffix = "spark_sftp_connection_temp_" + UUID.randomUUID val hdfsTempLocation = hdfsTemp + File.separator + randomSuffix val localTempLocation = tempFolder + File.separator + randomSuffix @@ -252,6 +254,7 @@ class DefaultSource extends RelationProvider with SchemaRelationProvider with Cr option("delimiter", delimiter). option("quote", quote). option("escape", escape). + option("encoding", encoding). option("multiLine", multiLine). optionNoNull("codec", Option(codec)). 
csv(hdfsTempLocation) diff --git a/src/test/resources/sample_utf-32be.csv b/src/test/resources/sample_utf-32be.csv new file mode 100644 index 0000000..9bc2a06 Binary files /dev/null and b/src/test/resources/sample_utf-32be.csv differ diff --git a/src/test/scala/com/springml/spark/sftp/CustomSchemaTest.scala b/src/test/scala/com/springml/spark/sftp/CustomSchemaTest.scala index 0a190e5..f0d01a5 100644 --- a/src/test/scala/com/springml/spark/sftp/CustomSchemaTest.scala +++ b/src/test/scala/com/springml/spark/sftp/CustomSchemaTest.scala @@ -43,7 +43,7 @@ class CustomSchemaTest extends FunSuite with BeforeAndAfterEach { val expectedSchema = StructType(columnStruct) val fileLocation = getClass.getResource("/sample.csv").getPath - val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, expectedSchema, ss.sqlContext) val rdd = dsr.buildScan() assert(dsr.schema.fields.length == columnStruct.length) @@ -55,7 +55,7 @@ class CustomSchemaTest extends FunSuite with BeforeAndAfterEach { val expectedSchema = StructType(columnStruct) val fileLocation = getClass.getResource("/people.json").getPath - val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, expectedSchema, ss.sqlContext) val rdd = dsr.buildScan() assert(dsr.schema.fields.length == columnStruct.length) diff --git a/src/test/scala/com/springml/spark/sftp/TestDatasetRelation.scala b/src/test/scala/com/springml/spark/sftp/TestDatasetRelation.scala index fb7527a..d50d56e 100644 --- a/src/test/scala/com/springml/spark/sftp/TestDatasetRelation.scala +++ b/src/test/scala/com/springml/spark/sftp/TestDatasetRelation.scala @@ -1,6 +1,7 @@ package com.springml.spark.sftp import 
org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfterEach, FunSuite} /** @@ -15,63 +16,81 @@ class TestDatasetRelation extends FunSuite with BeforeAndAfterEach { test ("Read CSV") { val fileLocation = getClass.getResource("/sample.csv").getPath - val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "UTF-8","false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(3 == rdd.count()) } test ("Read CSV using custom delimiter") { val fileLocation = getClass.getResource("/sample.csv").getPath - val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ";", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ";", "\"", "\\", "UTF-8", "false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(3 == rdd.count()) } test ("Read multiline CSV using custom quote and escape") { val fileLocation = getClass.getResource("/sample_quoted_multiline.csv").getPath - val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "true", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "UTF-8", "true", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(3 == rdd.count()) } + test ("Read CSV encoded as UTF-32be") { + val schema = StructType( Array( + StructField("ProposalId", StringType,true), + StructField("OpportunityId", StringType,true), + StructField("Clicks", StringType,true), + StructField("Impressions", StringType,true), + StructField("Currency", StringType,true) + )) + + val fileLocation = getClass.getResource("/sample_utf-32be.csv").getPath + val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "UTF-32BE", "true", null, null, ss.sqlContext) + val rdd = dsr.buildScan() + val df = 
ss.createDataFrame(rdd, schema) + assert(5 == df.columns.size) + assert("£" == df.head.getString(4)) + assert(3 == rdd.count()) + } + test ("Read JSON") { val fileLocation = getClass.getResource("/people.json").getPath - val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(3 == rdd.count()) } test ("Read AVRO") { val fileLocation = getClass.getResource("/users.avro").getPath - val dsr = DatasetRelation(fileLocation, "avro", "false", "true", ",", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "avro", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(2 == rdd.count()) } test ("Read parquet") { val fileLocation = getClass.getResource("/users.parquet").getPath - val dsr = DatasetRelation(fileLocation, "parquet", "false", "true", ",", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "parquet", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(2 == rdd.count()) } test ("Read text file") { val fileLocation = getClass.getResource("/plaintext.txt").getPath - val dsr = DatasetRelation(fileLocation, "txt", "false", "true", ",", "\"", "\\", "false", null, null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "txt", "false", "true", ",", "\"", "\\", "UTF-8", "false", null, null, ss.sqlContext) val rdd = dsr.buildScan() assert(3 == rdd.count()) } test ("Read xml file") { val fileLocation = getClass.getResource("/books.xml").getPath - val dsr = DatasetRelation(fileLocation, "xml", "false", "true", ",", "\"", "\\", "false", "book", null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "xml", "false", "true", ",", "\"", "\\", "UTF-8", 
"false", "book", null, ss.sqlContext) val rdd = dsr.buildScan() assert(12 == rdd.count()) } test ("Read orc file") { val fileLocation = getClass.getResource("/books.orc").getPath - val dsr = DatasetRelation(fileLocation, "orc", "false", "true", ",", "\"", "\\", "false", "book", null, ss.sqlContext) + val dsr = DatasetRelation(fileLocation, "orc", "false", "true", ",", "\"", "\\", "UTF-8", "false", "book", null, ss.sqlContext) val rdd = dsr.buildScan() assert(12 == rdd.count()) }