@@ -16,29 +16,22 @@
package io.delta.kernel.spark.read;

import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.spark.utils.PartitionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.execution.datasources.FileFormat$;
import org.apache.spark.sql.execution.datasources.FilePartition;
import org.apache.spark.sql.execution.datasources.FilePartition$;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat;
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.Option;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.JavaConverters;

public class SparkBatch implements Batch {
@@ -88,7 +81,9 @@ public SparkBatch(
@Override
public InputPartition[] planInputPartitions() {
SparkSession sparkSession = SparkSession.active();
long maxSplitBytes = calculateMaxSplitBytes(sparkSession);
long maxSplitBytes =
PartitionUtils.calculateMaxSplitBytes(
sparkSession, totalBytes, partitionedFiles.size(), sqlConf);

scala.collection.Seq<FilePartition> filePartitions =
FilePartition$.MODULE$.getFilePartitions(
@@ -98,25 +93,14 @@ public InputPartition[] planInputPartitions() {

@Override
public PartitionReaderFactory createReaderFactory() {
boolean enableVectorizedReader =
ParquetUtils.isBatchReadSupportedForSchema(sqlConf, readDataSchema);
scala.collection.immutable.Map<String, String> optionsWithBatch =
scalaOptions.$plus(
new Tuple2<>(
FileFormat$.MODULE$.OPTION_RETURNING_BATCH(),
String.valueOf(enableVectorizedReader)));
Function1<PartitionedFile, Iterator<InternalRow>> readFunc =
new ParquetFileFormat()
.buildReaderWithPartitionValues(
SparkSession.active(),
dataSchema,
partitionSchema,
readDataSchema,
JavaConverters.asScalaBuffer(Arrays.asList(dataFilters)).toSeq(),
optionsWithBatch,
hadoopConf);

return new SparkReaderFactory(readFunc, enableVectorizedReader);
return PartitionUtils.createParquetReaderFactory(
dataSchema,
partitionSchema,
readDataSchema,
dataFilters,
scalaOptions,
hadoopConf,
sqlConf);
}

@Override
@@ -145,24 +129,4 @@ public int hashCode() {
result = 31 * result + Integer.hashCode(partitionedFiles.size());
return result;
}

private long calculateMaxSplitBytes(SparkSession sparkSession) {
long defaultMaxSplitBytes = sqlConf.filesMaxPartitionBytes();
long openCostInBytes = sqlConf.filesOpenCostInBytes();
Option<Object> minPartitionNumOption = sqlConf.filesMinPartitionNum();

int minPartitionNum =
minPartitionNumOption.isDefined()
? ((Number) minPartitionNumOption.get()).intValue()
: sqlConf
.getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM())
.getOrElse(() -> sparkSession.sparkContext().defaultParallelism());
if (minPartitionNum <= 0) {
minPartitionNum = 1;
}
long calculatedTotalBytes = totalBytes + (long) partitionedFiles.size() * openCostInBytes;
long bytesPerCore = calculatedTotalBytes / minPartitionNum;

return Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore));
}
}
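
For reference, a minimal sketch of the PartitionUtils.calculateMaxSplitBytes helper (in io.delta.kernel.spark.utils) that SparkBatch now delegates to. It mirrors the private calculateMaxSplitBytes method removed above, with totalBytes and the file count passed in as arguments; the real helper is not shown in this diff, so the signature and body here are assumptions.

  // Hypothetical sketch: the same split-size heuristic as the removed SparkBatch method,
  // reading totalBytes and the file count from parameters instead of instance fields.
  public static long calculateMaxSplitBytes(
      SparkSession sparkSession, long totalBytes, int fileCount, SQLConf sqlConf) {
    long defaultMaxSplitBytes = sqlConf.filesMaxPartitionBytes();
    long openCostInBytes = sqlConf.filesOpenCostInBytes();
    scala.Option<Object> minPartitionNumOption = sqlConf.filesMinPartitionNum();

    // Fall back to the leaf-node default parallelism, then to the cluster default parallelism.
    int minPartitionNum =
        minPartitionNumOption.isDefined()
            ? ((Number) minPartitionNumOption.get()).intValue()
            : sqlConf
                .getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM())
                .getOrElse(() -> sparkSession.sparkContext().defaultParallelism());
    if (minPartitionNum <= 0) {
      minPartitionNum = 1;
    }

    // Charge an open cost per file, spread the bytes across the available parallelism,
    // and clamp the result between the open cost and the configured max partition size.
    long calculatedTotalBytes = totalBytes + (long) fileCount * openCostInBytes;
    long bytesPerCore = calculatedTotalBytes / minPartitionNum;
    return Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore));
  }
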
@@ -27,10 +27,12 @@
import io.delta.kernel.internal.actions.RemoveFile;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.spark.snapshot.DeltaSnapshotManager;
import io.delta.kernel.spark.utils.PartitionUtils;
import io.delta.kernel.spark.utils.ScalaUtils;
import io.delta.kernel.spark.utils.StreamingHelper;
import io.delta.kernel.utils.CloseableIterator;
import java.io.IOException;
import java.time.ZoneId;
import java.util.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.sql.SparkSession;
@@ -45,8 +47,16 @@
import org.apache.spark.sql.delta.sources.DeltaSQLConf;
import org.apache.spark.sql.delta.sources.DeltaSource;
import org.apache.spark.sql.delta.sources.DeltaSourceOffset;
import org.apache.spark.sql.execution.datasources.FilePartition;
import org.apache.spark.sql.execution.datasources.FilePartition$;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class SparkMicroBatchStream implements MicroBatchStream, SupportsAdmissionControl {

@@ -63,6 +73,14 @@ public class SparkMicroBatchStream implements MicroBatchStream, SupportsAdmissionControl
private final String tableId;
private final boolean shouldValidateOffsets;
private final SparkSession spark;
private final String tablePath;
private final StructType readDataSchema;
private final StructType dataSchema;
private final StructType partitionSchema;
private final Filter[] dataFilters;
private final Configuration hadoopConf;
private final SQLConf sqlConf;
private final scala.collection.immutable.Map<String, String> scalaOptions;

/**
* Tracks whether this is the first batch for this stream (no checkpointed offset).
@@ -74,31 +92,37 @@ public class SparkMicroBatchStream implements MicroBatchStream, SupportsAdmissionControl
*/
private boolean isFirstBatch = false;

public SparkMicroBatchStream(
DeltaSnapshotManager snapshotManager,
Snapshot snapshotAtSourceInit,
Configuration hadoopConf) {
this(
snapshotManager,
snapshotAtSourceInit,
hadoopConf,
SparkSession.active(),
new DeltaOptions(
scala.collection.immutable.Map$.MODULE$.empty(),
SparkSession.active().sessionState().conf()));
}

public SparkMicroBatchStream(
DeltaSnapshotManager snapshotManager,
Snapshot snapshotAtSourceInit,
Configuration hadoopConf,
SparkSession spark,
DeltaOptions options) {
this.spark = spark;
this.snapshotManager = snapshotManager;
this.snapshotAtSourceInit = snapshotAtSourceInit;
DeltaOptions options,
String tablePath,
StructType dataSchema,
StructType partitionSchema,
StructType readDataSchema,
Filter[] dataFilters,
scala.collection.immutable.Map<String, String> scalaOptions) {
this.snapshotManager = Objects.requireNonNull(snapshotManager, "snapshotManager is null");
this.hadoopConf = Objects.requireNonNull(hadoopConf, "hadoopConf is null");
this.spark = Objects.requireNonNull(spark, "spark is null");
this.engine = DefaultEngine.create(hadoopConf);
this.options = options;
this.options = Objects.requireNonNull(options, "options is null");
// Normalize tablePath to ensure it ends with "/" for consistent path construction
String normalizedTablePath = Objects.requireNonNull(tablePath, "tablePath is null");
this.tablePath =
normalizedTablePath.endsWith("/") ? normalizedTablePath : normalizedTablePath + "/";
this.dataSchema = Objects.requireNonNull(dataSchema, "dataSchema is null");
this.partitionSchema = Objects.requireNonNull(partitionSchema, "partitionSchema is null");
this.readDataSchema = Objects.requireNonNull(readDataSchema, "readDataSchema is null");
this.dataFilters =
Arrays.copyOf(
Objects.requireNonNull(dataFilters, "dataFilters is null"), dataFilters.length);
this.sqlConf = SQLConf.get();
this.scalaOptions = Objects.requireNonNull(scalaOptions, "scalaOptions is null");

this.snapshotAtSourceInit = snapshotAtSourceInit;
this.tableId = ((SnapshotImpl) snapshotAtSourceInit).getMetadata().getId();
this.shouldValidateOffsets =
(Boolean) spark.sessionState().conf().getConf(DeltaSQLConf.STREAMING_OFFSET_VALIDATION());
@@ -228,12 +252,58 @@ private Optional<DeltaSourceOffset> getNextOffsetFromPreviousOffset(

@Override
public InputPartition[] planInputPartitions(Offset start, Offset end) {
throw new UnsupportedOperationException("planInputPartitions is not supported");
DeltaSourceOffset startOffset = (DeltaSourceOffset) start;
DeltaSourceOffset endOffset = (DeltaSourceOffset) end;

long fromVersion = startOffset.reservoirVersion();
long fromIndex = startOffset.index();
boolean isInitialSnapshot = startOffset.isInitialSnapshot();

List<PartitionedFile> partitionedFiles = new ArrayList<>();
long totalBytesToRead = 0;
try (CloseableIterator<IndexedFile> fileChanges =
getFileChanges(fromVersion, fromIndex, isInitialSnapshot, Optional.of(endOffset))) {
while (fileChanges.hasNext()) {
IndexedFile indexedFile = fileChanges.next();
if (!indexedFile.hasFileAction() || indexedFile.getAddFile() == null) {
continue;
}
AddFile addFile = indexedFile.getAddFile();
PartitionedFile partitionedFile =
PartitionUtils.buildPartitionedFile(
addFile, partitionSchema, tablePath, ZoneId.of(sqlConf.sessionLocalTimeZone()));

totalBytesToRead += addFile.getSize();
partitionedFiles.add(partitionedFile);
}
} catch (IOException e) {
throw new RuntimeException(
String.format(
"Failed to get file changes for table %s from version %d index %d to offset %s",
tablePath, fromVersion, fromIndex, endOffset),
e);
}

long maxSplitBytes =
PartitionUtils.calculateMaxSplitBytes(
spark, totalBytesToRead, partitionedFiles.size(), sqlConf);
// Group the files into Spark FilePartitions using the computed max split size.
Seq<FilePartition> filePartitions =
FilePartition$.MODULE$.getFilePartitions(
spark, JavaConverters.asScalaBuffer(partitionedFiles).toSeq(), maxSplitBytes);
return JavaConverters.seqAsJavaList(filePartitions).toArray(new InputPartition[0]);
}

@Override
public PartitionReaderFactory createReaderFactory() {
throw new UnsupportedOperationException("createReaderFactory is not supported");
return PartitionUtils.createParquetReaderFactory(
dataSchema,
partitionSchema,
readDataSchema,
dataFilters,
scalaOptions,
hadoopConf,
sqlConf);
}

///////////////
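
Both createReaderFactory implementations above now delegate to PartitionUtils.createParquetReaderFactory. Below is a minimal sketch of that helper, reconstructed from the body removed from SparkBatch#createReaderFactory; the imports (ParquetFileFormat, ParquetUtils, FileFormat$, Function1, Tuple2, scala.collection.Iterator, JavaConverters) are the ones the old SparkBatch used, and the exact shape of the real helper is an assumption.

  // Hypothetical sketch: builds the Parquet row/batch read function and wraps it in the
  // connector's SparkReaderFactory, mirroring the logic removed from SparkBatch.
  public static PartitionReaderFactory createParquetReaderFactory(
      StructType dataSchema,
      StructType partitionSchema,
      StructType readDataSchema,
      Filter[] dataFilters,
      scala.collection.immutable.Map<String, String> scalaOptions,
      Configuration hadoopConf,
      SQLConf sqlConf) {
    boolean enableVectorizedReader =
        ParquetUtils.isBatchReadSupportedForSchema(sqlConf, readDataSchema);
    // Tell the Parquet reader whether it should return columnar batches or rows.
    scala.collection.immutable.Map<String, String> optionsWithBatch =
        scalaOptions.$plus(
            new Tuple2<>(
                FileFormat$.MODULE$.OPTION_RETURNING_BATCH(),
                String.valueOf(enableVectorizedReader)));
    Function1<PartitionedFile, Iterator<InternalRow>> readFunc =
        new ParquetFileFormat()
            .buildReaderWithPartitionValues(
                SparkSession.active(),
                dataSchema,
                partitionSchema,
                readDataSchema,
                JavaConverters.asScalaBuffer(Arrays.asList(dataFilters)).toSeq(),
                optionsWithBatch,
                hadoopConf);
    return new SparkReaderFactory(readFunc, enableVectorizedReader);
  }
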
@@ -19,22 +19,21 @@

import io.delta.kernel.Snapshot;
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.defaults.engine.DefaultEngine;
import io.delta.kernel.engine.Engine;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.internal.actions.AddFile;
import io.delta.kernel.internal.data.ScanStateRow;
import io.delta.kernel.spark.snapshot.DeltaSnapshotManager;
import io.delta.kernel.spark.utils.PartitionUtils;
import io.delta.kernel.spark.utils.ScalaUtils;
import io.delta.kernel.utils.CloseableIterator;
import java.io.IOException;
import java.time.ZoneId;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.paths.SparkPath;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Expression;
@@ -50,7 +49,6 @@
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import scala.collection.JavaConverters;

/** Spark DSV2 Scan implementation backed by Delta Kernel. */
public class SparkScan implements Scan, SupportsReportStatistics, SupportsRuntimeV2Filtering {
@@ -171,7 +169,17 @@ public MicroBatchStream toMicroBatchStream(String checkpointLocation) {
// Validate streaming options immediately after constructing DeltaOptions
validateStreamingOptions(deltaOptions);
return new SparkMicroBatchStream(
snapshotManager, initialSnapshot, hadoopConf, SparkSession.active(), deltaOptions);
snapshotManager,
initialSnapshot,
hadoopConf,
SparkSession.active(),
deltaOptions,
getTablePath(),
dataSchema,
partitionSchema,
readDataSchema,
dataFilters != null ? dataFilters : new Filter[0],
scalaOptions != null ? scalaOptions : scala.collection.immutable.Map$.MODULE$.empty());
}

@Override
@@ -214,45 +222,6 @@ private String getTablePath() {
return tableRoot.endsWith("/") ? tableRoot : tableRoot + "/";
}

/**
* Build the partition {@link InternalRow} from kernel partition values by casting them to the
* desired Spark types using the session time zone for temporal types.
*/
private InternalRow getPartitionRow(MapValue partitionValues) {
final int numPartCols = partitionSchema.fields().length;
assert partitionValues.getSize() == numPartCols
: String.format(
Locale.ROOT,
"Partition values size from add file %d != partition columns size %d",
partitionValues.getSize(),
numPartCols);

final Object[] values = new Object[numPartCols];

// Build field name -> index map once
final Map<String, Integer> fieldIndex = new HashMap<>(numPartCols);
for (int i = 0; i < numPartCols; i++) {
fieldIndex.put(partitionSchema.fields()[i].name(), i);
values[i] = null;
}

// Fill values in a single pass over partitionValues
for (int idx = 0; idx < partitionValues.getSize(); idx++) {
final String key = partitionValues.getKeys().getString(idx);
final String strVal = partitionValues.getValues().getString(idx);
final Integer pos = fieldIndex.get(key);
if (pos != null) {
final StructField field = partitionSchema.fields()[pos];
values[pos] =
(strVal == null)
? null
: PartitioningUtils.castPartValueToDesiredType(field.dataType(), strVal, zoneId);
}
}
return InternalRow.fromSeq(
JavaConverters.asScalaIterator(Arrays.asList(values).iterator()).toSeq());
}

/**
* Plan the files to scan by materializing {@link PartitionedFile}s and aggregating size stats.
* Ensures all iterators are closed to avoid resource leaks.
@@ -275,15 +244,7 @@ private void planScanFiles() {
final AddFile addFile = new AddFile(row.getStruct(0));

final PartitionedFile partitionedFile =
new PartitionedFile(
getPartitionRow(addFile.getPartitionValues()),
SparkPath.fromUrlString(tablePath + addFile.getPath()),
0L,
addFile.getSize(),
locations,
addFile.getModificationTime(),
addFile.getSize(),
otherConstantMetadataColumnValues);
PartitionUtils.buildPartitionedFile(addFile, partitionSchema, tablePath, zoneId);

totalBytes += addFile.getSize();
partitionedFiles.add(partitionedFile);
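
Finally, a minimal sketch of PartitionUtils.buildPartitionedFile, which both SparkMicroBatchStream.planInputPartitions and SparkScan.planScanFiles now call. It is reconstructed from the getPartitionRow method and the PartitionedFile construction removed from SparkScan above; the block locations and extra constant metadata column values are assumed empty here (those locals are not visible in this diff), and the imports (AddFile, MapValue, SparkPath, PartitioningUtils, InternalRow, JavaConverters) follow the old SparkScan.

  // Hypothetical sketch: casts the kernel's string-typed partition values to the Spark
  // types in partitionSchema (using the session time zone for temporal types) and wraps
  // the AddFile in a Spark PartitionedFile.
  public static PartitionedFile buildPartitionedFile(
      AddFile addFile, StructType partitionSchema, String tablePath, ZoneId zoneId) {
    MapValue partitionValues = addFile.getPartitionValues();
    StructField[] fields = partitionSchema.fields();

    // Map partition column name -> ordinal in the partition schema.
    Map<String, Integer> fieldIndex = new HashMap<>(fields.length);
    for (int i = 0; i < fields.length; i++) {
      fieldIndex.put(fields[i].name(), i);
    }

    // Fill the partition row in a single pass over the add file's partition values.
    Object[] values = new Object[fields.length];
    for (int idx = 0; idx < partitionValues.getSize(); idx++) {
      String key = partitionValues.getKeys().getString(idx);
      String strVal = partitionValues.getValues().getString(idx);
      Integer pos = fieldIndex.get(key);
      if (pos != null && strVal != null) {
        values[pos] =
            PartitioningUtils.castPartValueToDesiredType(fields[pos].dataType(), strVal, zoneId);
      }
    }
    InternalRow partitionRow =
        InternalRow.fromSeq(
            JavaConverters.asScalaIterator(Arrays.asList(values).iterator()).toSeq());

    return new PartitionedFile(
        partitionRow,
        SparkPath.fromUrlString(tablePath + addFile.getPath()),
        /* start */ 0L,
        /* length */ addFile.getSize(),
        /* locations (assumed) */ new String[0],
        addFile.getModificationTime(),
        /* fileSize */ addFile.getSize(),
        /* otherConstantMetadataColumnValues (assumed empty) */
        scala.collection.immutable.Map$.MODULE$.empty());
  }
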