diff --git a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
index 43781b77d3a..e028edbe170 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSContext.scala
@@ -11,7 +11,6 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.util.SerializableConfiguration
 import org.slf4j.{Logger, LoggerFactory}
 
-import java.io.File
 import java.util.concurrent.TimeUnit
 
 object LakeFSJobParams {
@@ -127,6 +126,11 @@ object LakeFSContext {
     conf.set(LAKEFS_CONF_JOB_REPO_NAME_KEY, params.repoName)
     conf.setStrings(LAKEFS_CONF_JOB_COMMIT_IDS_KEY, params.commitIDs.toArray: _*)
+
+    val tmpDir = sc.getConf.get("spark.local.dir", null)
+    if (tmpDir != null) {
+      conf.set("spark.local.dir", tmpDir)
+    }
+
     conf.set(LAKEFS_CONF_JOB_STORAGE_NAMESPACE_KEY, params.storageNamespace)
     if (StringUtils.isBlank(conf.get(LAKEFS_CONF_API_URL_KEY))) {
       throw new InvalidJobConfException(s"$LAKEFS_CONF_API_URL_KEY must not be empty")
@@ -186,7 +190,7 @@ object LakeFSContext {
     ranges.flatMap((range: Range) => {
       val path = new Path(apiClient.getRangeURL(repoName, range.id))
       val fs = path.getFileSystem(conf)
-      val localFile = File.createTempFile("lakefs.", ".range")
+      val localFile = StorageUtils.createTempFile(tmpDir, "lakefs.", ".range")
       fs.copyToLocalFile(false, path, new Path(localFile.getAbsolutePath), true)
 
       val companion = Entry.messageCompanion
diff --git a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
index a48e1d4e7cd..1b0aa320bf8 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/LakeFSInputFormat.scala
@@ -17,7 +17,6 @@ import scalapb.GeneratedMessageCompanion
 
 import java.io.DataInput
 import java.io.DataOutput
-import java.io.File
 import java.net.URI
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
@@ -92,7 +91,8 @@ class EntryRecordReader[Proto <: GeneratedMessage with scalapb.Message[Proto]](
   var item: Item[Proto] = _
   var rangeID: String = ""
   override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {
-    localFile = File.createTempFile("lakefs.", ".range")
+    val tmpDir = context.getConfiguration.get("spark.local.dir")
+    localFile = StorageUtils.createTempFile(tmpDir, "lakefs.", ".range")
     // Cleanup the local file - using the same technic as other data sources:
     // https://github.com/apache/spark/blob/c0b1735c0bfeb1ff645d146e262d7ccd036a590e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L123
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => localFile.delete()))
@@ -153,6 +153,7 @@ class EntryRecordReader[Proto <: GeneratedMessage with scalapb.Message[Proto]](
 object LakeFSInputFormat {
   val DummyFileName = "dummy"
   val logger: Logger = LoggerFactory.getLogger(getClass.toString)
+
   def read[Proto <: GeneratedMessage with scalapb.Message[Proto]](
       reader: SSTableReader[Proto]
   ): Seq[Item[Proto]] = reader.newIterator().toSeq
diff --git a/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala b/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
index bf80d2821d0..d34a54b5fa9 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/SSTableReader.scala
@@ -63,7 +63,8 @@ object SSTableReader {
   private def copyToLocal(configuration: Configuration, url: String) = {
     val p = new Path(url)
     val fs = p.getFileSystem(configuration)
-    val localFile = File.createTempFile("lakefs.", ".sstable")
+    val tmpDir = configuration.get("spark.local.dir")
+    val localFile = StorageUtils.createTempFile(tmpDir, "lakefs.", ".sstable")
     // Cleanup the local file - using the same technic as other data sources:
     // https://github.com/apache/spark/blob/c0b1735c0bfeb1ff645d146e262d7ccd036a590e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L123
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => localFile.delete()))
diff --git a/clients/spark/src/main/scala/io/treeverse/clients/StorageUtils.scala b/clients/spark/src/main/scala/io/treeverse/clients/StorageUtils.scala
index 6afc04e4713..c095bac200c 100644
--- a/clients/spark/src/main/scala/io/treeverse/clients/StorageUtils.scala
+++ b/clients/spark/src/main/scala/io/treeverse/clients/StorageUtils.scala
@@ -8,7 +8,9 @@ import com.amazonaws.services.s3.model.{Region, GetBucketLocationRequest}
 import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
 import com.amazonaws._
 import org.slf4j.{Logger, LoggerFactory}
+import org.apache.commons.lang3.StringUtils
 
+import java.io.File
 import java.net.URI
 import java.util.concurrent.TimeUnit
@@ -161,6 +163,20 @@ object StorageUtils {
     val GCSMaxBulkSize = 500 // 1000 is the max size, 500 is the recommended size to avoid timeouts or hitting HTTP size limits
   }
+
+  /** Create a temporary file in the Spark local directory if configured.
+    * This ensures temporary files are stored in executor storage rather than system temp.
+    */
+  def createTempFile(sparkLocalDir: String, prefix: String, suffix: String): File = {
+    if (StringUtils.isNotBlank(sparkLocalDir)) {
+      val dir = new File(sparkLocalDir)
+      if (dir.exists() || dir.mkdirs()) {
+        return File.createTempFile(prefix, suffix, dir)
+      }
+    }
+    // Fallback to system temp directory
+    File.createTempFile(prefix, suffix)
+  }
 }
 
 class S3RetryDeleteObjectsCondition extends SDKDefaultRetryCondition {
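For reference, a minimal usage sketch of the helper introduced in StorageUtils.scala, mirroring the pattern used in EntryRecordReader.initialize. The `/mnt/spark-scratch` path and the `TempFileExample` object are illustrative assumptions, not part of this change:

```scala
import org.apache.spark.SparkConf
import io.treeverse.clients.StorageUtils

// Illustrative sketch only: exercises StorageUtils.createTempFile once
// "spark.local.dir" is set on the job configuration.
object TempFileExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical scratch volume; any writable executor-local path works.
    val conf = new SparkConf().set("spark.local.dir", "/mnt/spark-scratch")

    // Same pattern as EntryRecordReader.initialize in this diff: read the
    // configured directory and hand it to the helper, which falls back to
    // the system temp directory when the value is blank or unusable.
    val tmpDir = conf.get("spark.local.dir", null)
    val localFile = StorageUtils.createTempFile(tmpDir, "lakefs.", ".range")
    println(s"temporary range file: ${localFile.getAbsolutePath}")
    localFile.delete()
  }
}
```

If `spark.local.dir` is left unset, the helper receives a blank value and degrades to the pre-change behaviour of `File.createTempFile` in the system temp directory.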