diff --git a/.github/workflows/pr_benchmark_check.yml b/.github/workflows/pr_benchmark_check.yml
index b07cc03c34..a879493a7f 100644
--- a/.github/workflows/pr_benchmark_check.yml
+++ b/.github/workflows/pr_benchmark_check.yml
@@ -84,9 +84,7 @@ jobs:
${{ runner.os }}-benchmark-maven-
- name: Check Scala compilation and linting
- # Pin to spark-4.0 (Scala 2.13.16) because the default profile is now
- # spark-4.1 / Scala 2.13.17, and semanticdb-scalac_2.13.17 is not yet
- # published, which breaks `-Psemanticdb`. See pr_build_linux.yml for
- # the same exclusion in the main lint matrix.
+ # Pinned to spark-4.0 because semanticdb-scalac_2.13.17 (spark-4.1 default)
+ # is not yet published, which breaks the -Psemanticdb scalafix lint.
run: |
- ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -Pspark-4.0 -DskipTests
+ ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Pspark-4.0 -Psemanticdb -DskipTests
diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index bee0c38e8f..df758c2eaa 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -307,6 +307,7 @@ jobs:
org.apache.comet.CometFuzzAggregateSuite
org.apache.comet.CometFuzzIcebergSuite
org.apache.comet.CometFuzzMathSuite
+ org.apache.comet.CometCodegenDispatchFuzzSuite
org.apache.comet.DataGeneratorSuite
- name: "shuffle"
value: |
@@ -385,6 +386,9 @@ jobs:
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
+ org.apache.comet.CometCodegenDispatchSmokeSuite
+ org.apache.comet.CometCodegenSourceSuite
+ org.apache.comet.CometCodegenHOFSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 3c6953aade..848a78f4a8 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -155,6 +155,7 @@ jobs:
org.apache.comet.CometFuzzAggregateSuite
org.apache.comet.CometFuzzIcebergSuite
org.apache.comet.CometFuzzMathSuite
+ org.apache.comet.CometCodegenDispatchFuzzSuite
org.apache.comet.DataGeneratorSuite
- name: "shuffle"
value: |
@@ -232,6 +233,9 @@ jobs:
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
+ org.apache.comet.CometCodegenDispatchSmokeSuite
+ org.apache.comet.CometCodegenSourceSuite
+ org.apache.comet.CometCodegenHOFSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
diff --git a/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java b/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java
new file mode 100644
index 0000000000..f9fbb775a0
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/codegen/CometBatchKernel.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.codegen;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.ValueVector;
+
+/**
+ * Abstract base extended by the Janino-compiled batch kernel emitted by {@code
+ * CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's
+ * {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries
+ * typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow
+ * read/write fuse into one method per expression tree.
+ *
+ *
+ * <p>Input scope: any {@code ValueVector[]}; the generated subclass casts each slot to the concrete
+ * Arrow type the compile-time schema specified. Output is a generic {@code FieldVector}; the
+ * generated subclass casts to the concrete type matching the bound expression's {@code dataType}.
+ * Widen input support by adding vector classes to the getter switch in {@code
+ * CometBatchKernelCodegen.emitTypedGetters}; widen output support by adding cases in {@code
+ * CometBatchKernelCodegen.allocateOutput} and {@code emitOutputWriter}.
+ */
+public abstract class CometBatchKernel extends CometInternalRow {
+
+ protected final Object[] references;
+
+ protected CometBatchKernel(Object[] references) {
+ this.references = references;
+ }
+
+ /**
+ * Process one batch.
+ *
+ * @param inputs Arrow input vectors; length and concrete classes must match the schema the kernel
+ * was compiled against
+ * @param output Arrow output vector; caller allocates to the expression's {@code dataType}
+ * @param numRows number of rows in this batch
+ */
+ public abstract void process(ValueVector[] inputs, FieldVector output, int numRows);
+
+ /**
+ * Run partition-dependent initialization. The generated subclass overrides this to execute
+ * statements collected via {@code CodegenContext.addPartitionInitializationStatement}, for
+ * example reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}.
+ * Deterministic expressions leave this as a no-op.
+ *
+ *
+ * <p>The caller must invoke this before the first {@code process} call of each partition. The
+ * generated subclass is not thread-safe across concurrent {@code process} calls, so kernels are
+ * allocated per dispatcher invocation and init is run once on the fresh instance.
+ */
+ public void init(int partitionIndex) {}
+}
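
The contract in the Javadoc above (fresh kernel per partition, `init(partitionIndex)` once before the first `process`, no concurrent `process` calls on one instance) can be made concrete with a small caller sketch. This is illustrative only: `KernelDriver` and its factory parameter are hypothetical names, not part of this patch; the real per-partition allocation lives in `CometScalaUDFCodegen.ensureKernel` per the comments later in this diff.

```scala
import org.apache.arrow.vector.{FieldVector, ValueVector}
import org.apache.comet.codegen.CometBatchKernel

// Hypothetical per-partition driver showing the CometBatchKernel lifecycle:
// allocate a fresh kernel per partition (instances are not thread-safe across
// concurrent process calls) and run init exactly once before the first batch.
final class KernelDriver(newKernel: () => CometBatchKernel, partitionIndex: Int) {
  private val kernel: CometBatchKernel = newKernel() // no cross-partition sharing
  private var initialized = false

  def processBatch(inputs: Array[ValueVector], output: FieldVector, numRows: Int): Unit = {
    if (!initialized) {
      // Partition-dependent state (e.g. Rand reseeding from seed + partitionIndex)
      // is set up here; deterministic kernels leave init as a no-op.
      kernel.init(partitionIndex)
      initialized = true
    }
    // inputs' length and concrete vector classes must match the compile-time
    // schema; the caller allocates output to the bound expression's dataType.
    kernel.process(inputs, output, numRows)
  }
}
```
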
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 9b376837f7..feb4129ac5 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -380,6 +380,17 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)
+ val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] =
+ conf("spark.comet.exec.scalaUDF.codegen.enabled")
+ .category(CATEGORY_EXEC)
+ .doc(
+ "Whether to route Spark `ScalaUDF` expressions through Comet's Arrow-direct codegen " +
+ "dispatcher. When enabled, a supported ScalaUDF is compiled into a per-batch kernel " +
+ "that reads and writes Arrow vectors directly from native execution. When disabled, " +
+ "plans containing a ScalaUDF fall back to Spark for the enclosing operator.")
+ .booleanConf
+ .createWithDefault(true)
+
val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
.category(CATEGORY_SHUFFLE)
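
Assuming the new `spark.comet.exec.scalaUDF.codegen.enabled` entry above is runtime-settable like other Comet SQL confs in this file, toggling it from a session is a one-liner; a minimal sketch (the session setup is illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Default is true: supported ScalaUDFs compile into Arrow-direct batch kernels.
spark.conf.set("spark.comet.exec.scalaUDF.codegen.enabled", "true")

// Set to false to force the documented fallback: plans containing a ScalaUDF
// fall back to Spark for the enclosing operator.
spark.conf.set("spark.comet.exec.scalaUDF.codegen.enabled", "false")
```
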
diff --git a/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala b/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala
new file mode 100644
index 0000000000..1696c466a3
--- /dev/null
+++ b/common/src/main/scala/org/apache/comet/codegen/CometArrayData.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.codegen
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+import org.apache.comet.shims.CometInternalRowShim
+
+/**
+ * Throwing-default base for [[ArrayData]] in the Arrow-direct codegen kernel. Subclasses override
+ * only the getters their element type needs (e.g. `numElements`, `isNullAt`, `getUTF8String` for
+ * an `ArrayType(StringType)` input).
+ *
+ * Consumer: `InputArray_${path}` nested classes the input emitter generates per `ArrayType` input
+ * column. They back `getArray(ord)` plus the recursion for `Array<Array<_>>` and array-typed
+ * map keys / struct fields.
+ *
+ * `ArrayData` and [[CometInternalRow]]'s [[InternalRow]] are sibling abstract classes in Spark
+ * (both extend `SpecializedGetters`, neither inherits the other), so a base aimed at one cannot
+ * serve the other. The dispatch body that '''is''' shared between them lives in
+ * [[CometSpecializedGettersDispatch]]. The third sibling, [[CometMapData]], backs `InputMap_*`
+ * and routes `keyArray()` / `valueArray()` through `CometArrayData` instances.
+ *
+ * Mixes in [[CometInternalRowShim]] for the same reason `CometInternalRow` does: Spark 4.x adds
+ * abstract `SpecializedGetters` methods (`getVariant`, `getGeography`, `getGeometry`) that both
+ * `InternalRow` and `ArrayData` inherit; the per-profile shim provides throwing defaults.
+ */
+abstract class CometArrayData extends ArrayData with CometInternalRowShim {
+
+ override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval")
+
+ override def get(ordinal: Int, dataType: DataType): AnyRef =
+ CometSpecializedGettersDispatch.get(this, ordinal, dataType)
+
+ override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt")
+
+ override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean")
+
+ override def getByte(ordinal: Int): Byte = unsupported("getByte")
+
+ override def getShort(ordinal: Int): Short = unsupported("getShort")
+
+ override def getInt(ordinal: Int): Int = unsupported("getInt")
+
+ override def getLong(ordinal: Int): Long = unsupported("getLong")
+
+ override def getFloat(ordinal: Int): Float = unsupported("getFloat")
+
+ override def getDouble(ordinal: Int): Double = unsupported("getDouble")
+
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
+ unsupported("getDecimal")
+
+ override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String")
+
+ override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary")
+
+ override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct")
+
+ override def getArray(ordinal: Int): ArrayData = unsupported("getArray")
+
+ override def getMap(ordinal: Int): MapData = unsupported("getMap")
+
+ override def setNullAt(i: Int): Unit = unsupported("setNullAt")
+
+ protected def unsupported(method: String): Nothing =
+ throw new UnsupportedOperationException(
+ s"${getClass.getSimpleName}: $method not implemented for this array shape")
+
+ override def update(i: Int, value: Any): Unit = unsupported("update")
+
+ override def copy(): ArrayData = unsupported("copy")
+
+ override def array: Array[Any] = unsupported("array")
+
+ override def toString(): String = {
+ val n =
+ try numElements().toString
+ catch {
+ case _: Throwable => "?"
+ }
+ s"${getClass.getSimpleName}(numElements=$n)"
+ }
+
+ override def numElements(): Int = unsupported("numElements")
+}
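
To illustrate the throwing-default pattern above, here is a hand-written analogue of what a generated `InputArray_*` subclass might look like for an `ArrayType(StringType)` column. A sketch only: the class name and field layout are invented, and the real `common` module sees Arrow under its shaded `org.apache.comet.shaded.arrow` package rather than the plain imports used here.

```scala
import org.apache.arrow.vector.VarCharVector
import org.apache.arrow.vector.complex.ListVector
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.codegen.CometArrayData

// Hypothetical stand-in for a generated InputArray_* class: only the getters an
// ArrayType(StringType) element needs are overridden; every other accessor keeps
// CometArrayData's throwing default.
final class StringArrayView(list: ListVector, row: Int) extends CometArrayData {
  private val start = list.getElementStartIndex(row)
  private val end = list.getElementEndIndex(row)
  private val data = list.getDataVector.asInstanceOf[VarCharVector]

  override def numElements(): Int = end - start

  override def isNullAt(ordinal: Int): Boolean = data.isNull(start + ordinal)

  override def getUTF8String(ordinal: Int): UTF8String =
    UTF8String.fromBytes(data.get(start + ordinal))
}
```
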
diff --git a/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala
new file mode 100644
index 0000000000..c29816d470
--- /dev/null
+++ b/common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala
@@ -0,0 +1,559 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.codegen
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
+import org.apache.arrow.vector.types.pojo.Field
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, HigherOrderFunction, LambdaFunction, Literal, NamedLambdaVariable, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+import org.apache.comet.shims.CometExprTraitShim
+
+/**
+ * Compiles a bound [[Expression]] plus an input schema into a [[CometBatchKernel]] that fuses
+ * Arrow input reads, expression evaluation, and Arrow output writes into one Janino-compiled
+ * method per (expression, schema) pair.
+ *
+ * The kernel is generic over Catalyst expressions; it does not know or assume that the bound tree
+ * came from a `ScalaUDF`. Today's only consumer is
+ * [[org.apache.comet.udf.codegen.CometScalaUDFCodegen]], but a future consumer (Spark
+ * `WholeStageCodegenExec` integration, a non-UDF batch evaluator) can drive this class directly.
+ *
+ * Constraints: single output vector per kernel (whole projections need a multi-output extension);
+ * per-row scalar evaluation only (aggregation, window, generator rejected by [[canHandle]]).
+ *
+ * Input- and output-side emission live in [[CometBatchKernelCodegenInput]] and
+ * [[CometBatchKernelCodegenOutput]]. This file owns the [[ArrowColumnSpec]] vocabulary, the
+ * [[canHandle]] / [[allocateOutput]] / [[compile]] / [[generateSource]] entry points, and
+ * cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant).
+ *
+ * The generated kernel '''is''' the `InternalRow` that Spark's `BoundReference.genCode` reads
+ * from: `ctx.INPUT_ROW = "row"` plus `InternalRow row = this;` inside `process` routes
+ * `row.getUTF8String(ord)` to the kernel's own typed getter (final method, constant ordinal; JIT
+ * devirtualizes and folds). `row` rather than `this` because Spark's `splitExpressions` passes
+ * INPUT_ROW as a helper-method parameter name and `this` is a reserved Java keyword.
+ */
+object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
+
+ /**
+ * Resolve an Arrow vector class by its simple name, using the same classloader the codegen uses
+ * internally. Intended for tests: the `common` module shades `org.apache.arrow` to
+ * `org.apache.comet.shaded.arrow`, so `classOf[VarCharVector]` at a call site in an unshaded
+ * module refers to a different [[Class]] object than the one the codegen compares against.
+ * Callers pass a simple name and get back the class the production code actually uses.
+ */
+ def vectorClassBySimpleName(name: String): Class[_ <: ValueVector] = name match {
+ case "BitVector" => classOf[BitVector]
+ case "TinyIntVector" => classOf[TinyIntVector]
+ case "SmallIntVector" => classOf[SmallIntVector]
+ case "IntVector" => classOf[IntVector]
+ case "BigIntVector" => classOf[BigIntVector]
+ case "Float4Vector" => classOf[Float4Vector]
+ case "Float8Vector" => classOf[Float8Vector]
+ case "DecimalVector" => classOf[DecimalVector]
+ case "DateDayVector" => classOf[DateDayVector]
+ case "TimeStampMicroVector" => classOf[TimeStampMicroVector]
+ case "TimeStampMicroTZVector" => classOf[TimeStampMicroTZVector]
+ case "VarCharVector" => classOf[VarCharVector]
+ case "VarBinaryVector" => classOf[VarBinaryVector]
+ case other => throw new IllegalArgumentException(s"unknown Arrow vector class: $other")
+ }
+
+ /**
+ * Type surface the kernel covers, on both the input getter side and the output writer side.
+ * Recursive: `ArrayType` / `StructType` / `MapType` are supported when their children are.
+ * Input and output use a single predicate today; if they ever need to diverge, split this back
+ * into per-direction methods.
+ */
+ def isSupportedDataType(dt: DataType): Boolean = dt match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType => true
+ case FloatType | DoubleType => true
+ case _: DecimalType => true
+ case _: StringType | _: BinaryType => true
+ case DateType | TimestampType | TimestampNTZType => true
+ case ArrayType(inner, _) => isSupportedDataType(inner)
+ case st: StructType => st.fields.forall(f => isSupportedDataType(f.dataType))
+ case mt: MapType => isSupportedDataType(mt.keyType) && isSupportedDataType(mt.valueType)
+ case _ => false
+ }
+
+ /**
+ * Count the number of leaf fields (including nested) in a [[DataType]]. Mirrors WSCG's
+ * `WholeStageCodegenExec.numOfNestedFields` so the [[canHandle]] threshold check uses the same
+ * unit as `spark.sql.codegen.maxFields`.
+ */
+ private def numOfNestedFields(dataType: DataType): Int = dataType match {
+ case st: StructType => st.fields.map(f => numOfNestedFields(f.dataType)).sum
+ case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+ case a: ArrayType => numOfNestedFields(a.elementType)
+ case _ => 1
+ }
+
+ /**
+ * Plan-time predicate: can the codegen dispatcher handle this bound expression end to end?
+ * `None` greenlights the serde to emit the codegen proto; `Some(reason)` forces a Spark
+ * fallback (typically `withInfo(...) + None`) rather than crashing the Janino compile at
+ * execute time.
+ *
+ * Checks every `BoundReference`'s data type and the root `expr.dataType` against
+ * [[isSupportedDataType]], and rejects aggregates / generators. Intermediate nodes are not
+ * checked: only leaves (row reads) and the root (output write) touch Arrow.
+ */
+ def canHandle(boundExpr: Expression): Option[String] = {
+ if (!isSupportedDataType(boundExpr.dataType)) {
+ return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}")
+ }
+ // Mirror WSCG's `spark.sql.codegen.maxFields` gate. Count nested fields in the output type
+ // and in every `BoundReference`'s input type. Wide schemas blow the generated class's typed
+ // input field count, the typed-getter switch, and the constant pool. Refuse here so the
+ // operator falls back to Spark cleanly rather than tripping a Janino compile failure
+ // mid-execution (which Comet has no way to recover from).
+ val maxFields = SQLConf.get.wholeStageMaxNumFields
+ val totalFields = numOfNestedFields(boundExpr.dataType) +
+ boundExpr.collect { case b: BoundReference => numOfNestedFields(b.dataType) }.sum
+ if (totalFields > maxFields) {
+ return Some(
+ s"codegen dispatch: too many nested fields ($totalFields > " +
+ s"spark.sql.codegen.maxFields=$maxFields)")
+ }
+ // Reject expressions that can't be safely compiled or cached:
+ // - AggregateFunction / Generator: non-scalar bridge shape.
+ // - CodegenFallback (other than HOF / lambda nodes admitted below): opts out of
+ // `doGenCode`. The kernel cannot splice the interpreted-eval glue cleanly.
+ // - Unevaluable: unresolved plan markers. `isCodegenInertUnevaluable` lets the shim allow
+ // version-specific leaves that are `Unevaluable` but never touched by codegen (e.g.
+ // Spark 4.0's `ResolvedCollation` in `Collate.collation` as a type marker;
+ // `Collate.genCode` delegates to its child).
+ //
+ // HOFs are `CodegenFallback` but admitted. `CodegenFallback.doGenCode` emits one
+ // `((Expression) references[N]).eval(row)` call site; the kernel dispatches to the HOF's
+ // interpreted `eval`, which mutates `NamedLambdaVariable.value` per element and reads the
+ // input array through the kernel's typed Arrow getters. Correctness depends on per-task
+ // `boundExpr` isolation in `CometScalaUDFCodegen.kernelCache`: concurrent partitions get
+ // their own deserialized expression tree, so they cannot race on the lambda variable's
+ // `AtomicReference`. See `CometCodegenHOFSuite`.
+ //
+ // Nondeterministic / stateful expressions are accepted: per-partition kernel allocation
+ // in `CometScalaUDFCodegen.ensureKernel` plus a single `init(partitionIndex)` call at
+ // partition entry give `Rand` / `MonotonicallyIncreasingID` / etc. correct state across
+ // batches and a clean reset across partitions.
+ //
+ // `ExecSubqueryExpression` (`ScalarSubquery`, `InSubqueryExec`) is accepted via a chain:
+ // the surrounding Comet operator's inherited `SparkPlan.waitForSubqueries` populates the
+ // subquery's mutable `result` field before evaluation; the closure serializer captures
+ // that populated value into the arg-0 bytes; the dispatcher keys its compile cache on
+ // those exact bytes, so distinct subquery results produce distinct cache entries with no
+ // cross-query staleness. Comet operators that bypass `waitForSubqueries` would break this.
+ boundExpr.find {
+ case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true
+ case _: org.apache.spark.sql.catalyst.expressions.Generator => true
+ case _: HigherOrderFunction => false
+ case _: LambdaFunction => false
+ case _: NamedLambdaVariable => false
+ case _: CodegenFallback => true
+ case u: Unevaluable if isCodegenInertUnevaluable(u) => false
+ case _: Unevaluable => true
+ case _ => false
+ } match {
+ case Some(bad) =>
+ return Some(
+ s"codegen dispatch: expression ${bad.getClass.getSimpleName} not supported " +
+ "(aggregate, generator, codegen-fallback, or unevaluable)")
+ case None =>
+ }
+ val badRef = boundExpr.collectFirst {
+ case b: BoundReference if !isSupportedDataType(b.dataType) =>
+ b
+ }
+ badRef.map(b =>
+ s"codegen dispatch: unsupported input type ${b.dataType} at ordinal ${b.ordinal}")
+ }
+
+ /**
+ * Allocate an Arrow output vector matching the expression's `dataType`. Thin forwarder to
+ * [[CometBatchKernelCodegenOutput.allocateOutput]]. Kept on this object as part of the public
+ * API so external callers (`CometScalaUDFCodegen`) do not have to know about the internal
+ * split.
+ */
+ def allocateOutput(
+ dataType: DataType,
+ name: String,
+ numRows: Int,
+ estimatedBytes: Int = -1): FieldVector =
+ CometBatchKernelCodegenOutput.allocateOutput(dataType, name, numRows, estimatedBytes)
+
+ /** Variant that takes a pre-computed Arrow `Field`, letting hot-path callers cache it. */
+ def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector =
+ CometBatchKernelCodegenOutput.allocateOutput(field, numRows, estimatedBytes)
+
+ def compile(boundExpr: Expression, inputSchema: Seq[ArrowColumnSpec]): CompiledKernel = {
+ val src = generateSource(boundExpr, inputSchema)
+ val (clazz, _) =
+ try {
+ CodeGenerator.compile(src.code)
+ } catch {
+ case t: Throwable =>
+ logError(
+ s"CometBatchKernelCodegen: compile failed for ${boundExpr.getClass.getSimpleName}. " +
+ s"Generated source follows:\n${CodeFormatter.format(src.code)}",
+ t)
+ throw t
+ }
+ logInfo(
+ s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " +
+ s"-> ${boundExpr.dataType} inputs=" +
+ inputSchema
+ .map(s => s"${s.vectorClass.getSimpleName}${if (s.nullable) "?" else ""}")
+ .mkString(","))
+ // Freshen references per kernel allocation. See the `CompiledKernel` scaladoc for why.
+ // `generateSource` is pure with respect to its inputs (no hidden state) and produces a
+ // layout-compatible references array each time because the expression and schema are
+ // fixed.
+ val freshReferences: () => Array[Any] = () =>
+ generateSource(boundExpr, inputSchema).references
+ CompiledKernel(clazz, freshReferences)
+ }
+
+ /**
+ * Generate the Java source for a kernel without compiling it. Factored out of [[compile]] so
+ * tests can assert on the emitted source (null short-circuit present, non-nullable `isNullAt`
+ * returns literal `false`, etc.) without paying for Janino.
+ */
+ def generateSource(
+ boundExpr: Expression,
+ inputSchema: Seq[ArrowColumnSpec]): GeneratedSource = {
+ val ctx = new CodegenContext
+ // `BoundReference.genCode` emits `${ctx.INPUT_ROW}.getUTF8String(ord)`. We alias a local
+ // `row` to `this` at the top of `process` so those reads resolve to the kernel's own typed
+ // getters (virtual dispatch on a concrete final class, JIT devirtualizes + folds the
+ // switch). `row` rather than `this` because Spark's `splitExpressions` uses INPUT_ROW as the
+ // parameter name of any helper method it emits; `this` is a reserved keyword, so using it
+ // as a parameter name produces `private UTF8String helper(InternalRow this)` which Janino
+ // rejects.
+ ctx.INPUT_ROW = "row"
+
+ val baseClass = classOf[CometBatchKernel].getName
+ // Resolve shaded Arrow class names at compile time so generated source
+ // matches the abstract method signature after Maven relocation.
+ val valueVectorClass = classOf[ValueVector].getName
+ val fieldVectorClass = classOf[FieldVector].getName
+
+ // Build the per-row body via Spark's doGenCode.
+ //
+ // `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex
+ // output) that `emitOutputWriter` factors out of the per-row body so they do not repeat on
+ // every row. Scalar outputs return an empty string here.
+ //
+ // TODO(method-size): perRowBody is inlined inside process's for-loop and not split.
+ // Sufficiently deep trees can exceed Janino's 64KB method size; wrap in
+ // ctx.splitExpressionsWithCurrentInputs when hit.
+ val (concreteOutClass, outputSetup, perRowBody) = {
+ // Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the
+ // hood, which populates `ctx.subexprFunctions` with per-row helper calls that write
+ // common subexpression results into `addMutableState`-allocated fields; the returned
+ // `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated
+ // helper invocation block, spliced into the per-row body by `defaultBody` (inside the
+ // NullIntolerant else-branch when that short-circuit fires, otherwise before
+ // `ev.code`). See the "Subexpression elimination" section of the object-level
+ // Scaladoc for why we use this variant rather than the WSCG one.
+ val ev = if (SQLConf.get.subexpressionEliminationEnabled) {
+ ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head
+ } else {
+ boundExpr.genCode(ctx)
+ }
+ val subExprsCode = ctx.subexprFunctionsCode
+ val (cls, setup, snippet) =
+ CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx)
+ (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
+ }
+
+ val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema)
+ val typedInputCasts = CometBatchKernelCodegenInput.emitInputCasts(inputSchema)
+ val decimalTypeByOrdinal = CometBatchKernelCodegenInput.decimalPrecisionByOrdinal(boundExpr)
+ val getters =
+ CometBatchKernelCodegenInput.emitTypedGetters(inputSchema, decimalTypeByOrdinal)
+ val nested = CometBatchKernelCodegenInput.emitNestedClasses(inputSchema)
+ val getArrayMethod = CometBatchKernelCodegenInput.emitGetArrayMethod(inputSchema)
+ val getStructMethod = CometBatchKernelCodegenInput.emitGetStructMethod(inputSchema)
+ val getMapMethod = CometBatchKernelCodegenInput.emitGetMapMethod(inputSchema)
+
+ val codeBody =
+ s"""
+ |public java.lang.Object generate(Object[] references) {
+ | return new SpecificCometBatchKernel(references);
+ |}
+ |
+ |final class SpecificCometBatchKernel extends $baseClass {
+ |
+ | ${ctx.declareMutableStates()}
+ |
+ | $typedFieldDecls
+ | private int rowIdx;
+ |
+ | public SpecificCometBatchKernel(Object[] references) {
+ | super(references);
+ | ${ctx.initMutableStates()}
+ | }
+ |
+ | @Override
+ | public void init(int partitionIndex) {
+ | ${ctx.initPartition()}
+ | }
+ |
+ | $getters
+ | $getArrayMethod
+ | $getStructMethod
+ | $getMapMethod
+ |
+ | @Override
+ | public void process(
+ | $valueVectorClass[] inputs,
+ | $fieldVectorClass outRaw,
+ | int numRows) {
+ | $concreteOutClass output = ($concreteOutClass) outRaw;
+ | $typedInputCasts
+ | $outputSetup
+ | // Alias the kernel as `row` so Spark-generated `${ctx.INPUT_ROW}.method()` reads
+ | // resolve to the kernel's own typed getters. Helper methods that Spark splits off
+ | // via `splitExpressions` also take `InternalRow row` as a parameter; we pass `this`
+ | // implicitly since callers substitute INPUT_ROW which we've set to `row`.
+ | org.apache.spark.sql.catalyst.InternalRow row = this;
+ | for (int i = 0; i < numRows; i++) {
+ | this.rowIdx = i;
+ | $perRowBody
+ | }
+ | }
+ |
+ | ${ctx.declareAddedFunctions()}
+ |
+ |$nested
+ |}
+ """.stripMargin
+
+ val code = CodeFormatter.stripOverlappingComments(
+ new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+ GeneratedSource(code.body, code, ctx.references.toArray)
+ }
+
+ /**
+ * Per-row body for the default path. For `NullIntolerant` expressions (null in any input ->
+ * null output), prepends a short-circuit that skips expression evaluation entirely when any
+ * input column is null this row, saving the full `ev.code` cost. Otherwise the standard shape:
+ * run `ev.code`, then `setNull` or write based on `ev.isNull`.
+ *
+ * `subExprsCode` is the CSE helper-invocation block; it writes common subexpression results
+ * into class fields that `ev.code` reads, so it must run before `ev.code`. Inside the
+ * short-circuit it lives in the else branch, skipping CSE for null rows. Empty when CSE is
+ * disabled or the tree has none.
+ */
+ private def defaultBody(
+ boundExpr: Expression,
+ ev: ExprCode,
+ writeSnippet: String,
+ subExprsCode: String): String = {
+ boundExpr match {
+ case _ if isNullIntolerant(boundExpr) && allNullIntolerant(boundExpr) =>
+ // Every node from root to leaf is either NullIntolerant or a leaf. That transitively
+ // guarantees "any BoundReference null at this row -> whole expression null", so we can
+ // short-circuit on the union of input ordinals. Breaking the chain with a non-null-
+ // propagating node like `coalesce` or `if` produces the wrong result (coalesce(null,x)
+ // is x, not null), so the check above rejects those shapes and falls through to the
+ // default branch which runs Spark's own null-aware ev.code.
+ val inputOrdinals =
+ boundExpr.collect { case b: BoundReference => b.ordinal }.distinct
+ val nullCheck =
+ if (inputOrdinals.isEmpty) "false"
+ else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ")
+ s"""
+ |if ($nullCheck) {
+ | output.setNull(i);
+ |} else {
+ | $subExprsCode
+ | ${ev.code}
+ | $writeSnippet
+ |}
+ """.stripMargin
+ case _ =>
+ // Optimization: NonNullableOutputShortCircuit.
+ // When the bound expression declares `nullable = false`, the `if (ev.isNull)` branch is
+ // dead and HotSpot may or may not fold it (it depends on whether the expression's
+ // `doGenCode` made `ev.isNull` a `FalseLiteral` or a variable whose value is
+ // false-at-runtime but not a compile-time constant from Spark's side). Drop the guard
+ // at source level so we don't depend on JIT folding and keep the generated body
+ // minimal.
+ if (!boundExpr.nullable) {
+ s"""
+ |$subExprsCode
+ |${ev.code}
+ |$writeSnippet
+ """.stripMargin
+ } else {
+ s"""
+ |$subExprsCode
+ |${ev.code}
+ |if (${ev.isNull}) {
+ | output.setNull(i);
+ |} else {
+ | $writeSnippet
+ |}
+ """.stripMargin
+ }
+ }
+ }
+
+ /**
+ * True iff every node in the expression tree is either `NullIntolerant` or a leaf we can safely
+ * consider null-propagating (`BoundReference` and `Literal`). Used to gate the `NullIntolerant`
+ * short-circuit in [[defaultBody]]: the short-circuit collects `BoundReference` ordinals from
+ * the whole tree and skips `ev.code` when any of them is null, which is only correct when every
+ * path from a leaf to the root propagates nulls. A non-propagating node (`Coalesce`, `If`,
+ * `CaseWhen`, `Concat`, etc.) anywhere in the tree invalidates this assumption: `coalesce(null,
+ * x)` is `x`, not null, so pre-nulling on any input null would produce the wrong result.
+ */
+ private def allNullIntolerant(expr: Expression): Boolean =
+ !expr.exists {
+ case _: BoundReference | _: Literal => false
+ case other => !isNullIntolerant(other)
+ }
+
+ /**
+ * Per-column compile-time invariants. The concrete Arrow vector class and whether the column is
+ * nullable are baked into the generated kernel's typed fields and branches. Part of the cache
+ * key: different vector classes or nullability produce different kernels.
+ *
+ * Sealed hierarchy so that complex types (array/map/struct) can carry their nested element
+ * shape recursively. Today scalar, array, and struct specs exist; map cases will land as an
+ * additional subclass when the emitter covers them. A companion `apply` / `unapply` preserves
+ * the original scalar-only construction and extractor shape so existing callers don't need to
+ * change.
+ */
+ sealed trait ArrowColumnSpec {
+ def vectorClass: Class[_ <: ValueVector]
+
+ def nullable: Boolean
+ }
+
+ /** Scalar column: one Arrow vector class per row slot, no nested structure. */
+ final case class ScalarColumnSpec(vectorClass: Class[_ <: ValueVector], nullable: Boolean)
+ extends ArrowColumnSpec
+
+ /**
+ * Array column: an Arrow `ListVector` wrapping a child spec. `elementSparkType` is the Spark
+ * `DataType` of the element so the nested-class getter emitter can choose the right template
+ * (e.g. `getUTF8String` for `StringType`, `getInt` for `IntegerType`). The child spec carries
+ * the Arrow child vector class. Nested arrays (`Array<Array<_>>`) work by the child being
+ * itself an `ArrayColumnSpec`.
+ */
+ final case class ArrayColumnSpec(
+ nullable: Boolean,
+ elementSparkType: DataType,
+ element: ArrowColumnSpec)
+ extends ArrowColumnSpec {
+ override def vectorClass: Class[_ <: ValueVector] = classOf[ListVector]
+ }
+
+ /**
+ * Struct column: an Arrow `StructVector` wrapping N typed child specs. Each entry carries the
+ * Spark field name (for schema identification in the cache key), the Spark `DataType` of the
+ * field (so per-field emitters pick the right read/write template), the child `ArrowColumnSpec`
+ * (so nested shapes like `Struct>` compose by trait-level recursion), and the
+ * field's `nullable` bit (so non-nullable fields elide their per-row null check at source
+ * level). Nested structs (`Struct>`) work by the child being itself a
+ * `StructColumnSpec`.
+ */
+ final case class StructColumnSpec(nullable: Boolean, fields: Seq[StructFieldSpec])
+ extends ArrowColumnSpec {
+ override def vectorClass: Class[_ <: ValueVector] = classOf[StructVector]
+ }
+
+ /** One field entry on a [[StructColumnSpec]]. */
+ final case class StructFieldSpec(
+ name: String,
+ sparkType: DataType,
+ nullable: Boolean,
+ child: ArrowColumnSpec)
+
+ /**
+ * Map column: an Arrow `MapVector` (subclass of `ListVector`) whose data vector is a
+ * `StructVector` with a key field at ordinal 0 and a value field at ordinal 1. `key` and
+ * `value` are themselves `ArrowColumnSpec` so nested shapes (`Map<_, Array<_>>`,
+ * `Map