Skip to content

Commit 8285941

Browse files
author
Karuppayya Rajendran
committed
Codegen for MergeRowExec
1 parent 686d844 commit 8285941

File tree

1 file changed

+278
-4
lines changed

1 file changed

+278
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala

Lines changed: 278 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,25 @@ package org.apache.spark.sql.execution.datasources.v2
1919

2020
import org.roaringbitmap.longlong.Roaring64Bitmap
2121

22+
import org.apache.spark.SparkUnsupportedOperationException
2223
import org.apache.spark.rdd.RDD
2324
import org.apache.spark.sql.AnalysisException
2425
import org.apache.spark.sql.catalyst.InternalRow
2526
import org.apache.spark.sql.catalyst.expressions.Attribute
2627
import org.apache.spark.sql.catalyst.expressions.AttributeSet
2728
import org.apache.spark.sql.catalyst.expressions.BasePredicate
29+
import org.apache.spark.sql.catalyst.expressions.BindReferences
2830
import org.apache.spark.sql.catalyst.expressions.Expression
2931
import org.apache.spark.sql.catalyst.expressions.Projection
3032
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
31-
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
33+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, GeneratePredicate, JavaCode}
34+
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
3235
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
3336
import org.apache.spark.sql.catalyst.util.truncatedString
3437
import org.apache.spark.sql.errors.QueryExecutionErrors
35-
import org.apache.spark.sql.execution.SparkPlan
36-
import org.apache.spark.sql.execution.UnaryExecNode
38+
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
3739
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
40+
import org.apache.spark.sql.types.BooleanType
3841

3942
case class MergeRowsExec(
4043
isSourceRowPresent: Expression,
@@ -44,7 +47,7 @@ case class MergeRowsExec(
4447
notMatchedBySourceInstructions: Seq[Instruction],
4548
checkCardinality: Boolean,
4649
output: Seq[Attribute],
47-
child: SparkPlan) extends UnaryExecNode {
50+
child: SparkPlan) extends UnaryExecNode with CodegenSupport {
4851

4952
override lazy val metrics: Map[String, SQLMetric] = Map(
5053
"numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
@@ -92,6 +95,277 @@ case class MergeRowsExec(
9295
child.execute().mapPartitions(processPartition)
9396
}
9497

98+
override def inputRDDs(): Seq[RDD[InternalRow]] = {
99+
child.asInstanceOf[CodegenSupport].inputRDDs()
100+
}
101+
102+
protected override def doProduce(ctx: CodegenContext): String = {
103+
child.asInstanceOf[CodegenSupport].produce(ctx, this)
104+
}
105+
106+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
107+
// Save the input variables that were passed to doConsume
108+
val inputCurrentVars = input
109+
110+
// code for instruction execution code
111+
generateInstructionExecutionCode(ctx, inputCurrentVars)
112+
}
113+
114+
115+
/**
116+
* code for cardinality validation
117+
*/
118+
private def generateCardinalityValidationCode(ctx: CodegenContext, rowIdOrdinal: Int,
119+
input: Seq[ExprCode]): ExprCode = {
120+
val bitmapClass = classOf[Roaring64Bitmap]
121+
val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
122+
v => s"$v = new ${bitmapClass.getName}();")
123+
124+
val currentRowId = input(rowIdOrdinal)
125+
val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + ".MODULE$"
126+
val code =
127+
code"""
128+
|${currentRowId.code}
129+
|if ($rowIdBitmap.contains(${currentRowId.value})) {
130+
| throw $queryExecutionErrorsClass.mergeCardinalityViolationError();
131+
|}
132+
|$rowIdBitmap.add(${currentRowId.value});
133+
""".stripMargin
134+
ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
135+
}
136+
137+
/**
138+
* Generate code for instruction execution based on row presence conditions
139+
*/
140+
private def generateInstructionExecutionCode(ctx: CodegenContext,
141+
inputExprs: Seq[ExprCode]): String = {
142+
143+
// code for evaluating src/tgt presence conditions
144+
val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, child.output, inputExprs)
145+
val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, child.output, inputExprs)
146+
147+
// code for each instruction type
148+
val matchedInstructionsCode = generateInstructionsCode(ctx, matchedInstructions,
149+
"matched", inputExprs, sourcePresent = true)
150+
val notMatchedInstructionsCode = generateInstructionsCode(ctx, notMatchedInstructions,
151+
"notMatched", inputExprs, sourcePresent = true)
152+
val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
153+
notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, sourcePresent = false)
154+
155+
val cardinalityValidationCode = if (checkCardinality) {
156+
val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
157+
assert(rowIdOrdinal != -1, "Cannot find row ID attr")
158+
generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
159+
} else {
160+
""
161+
}
162+
163+
s"""
164+
|${sourcePresentExpr.code}
165+
|${targetPresentExpr.code}
166+
|
167+
|if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
168+
| $cardinalityValidationCode
169+
| $matchedInstructionsCode
170+
|} else if (${sourcePresentExpr.value}) {
171+
| $notMatchedInstructionsCode
172+
|} else if (${targetPresentExpr.value}) {
173+
| $notMatchedBySourceInstructionsCode
174+
|}
175+
""".stripMargin
176+
}
177+
178+
/**
179+
* Generate code for executing a sequence of instructions
180+
*/
181+
private def generateInstructionsCode(ctx: CodegenContext, instructions: Seq[Instruction],
182+
instructionType: String,
183+
inputExprs: Seq[ExprCode],
184+
sourcePresent: Boolean): String = {
185+
if (instructions.isEmpty) {
186+
""
187+
} else {
188+
val instructionCodes = instructions.map(instruction =>
189+
generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent))
190+
191+
s"""
192+
|${instructionCodes.mkString("\n")}
193+
|return;
194+
""".stripMargin
195+
}
196+
}
197+
198+
private def generateSingleInstructionCode(ctx: CodegenContext,
199+
instruction: Instruction,
200+
inputExprs: Seq[ExprCode],
201+
sourcePresent: Boolean): String = {
202+
instruction match {
203+
case Keep(context, condition, outputExprs) =>
204+
val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
205+
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
206+
207+
// Generate metric updates based on context
208+
val metricUpdateCode = generateMetricUpdateCode(ctx, context, sourcePresent)
209+
210+
s"""
211+
|${code.code}
212+
|if (${code.value}) {
213+
| $metricUpdateCode
214+
| ${consume(ctx, projectionExpr)}
215+
| return;
216+
|}
217+
""".stripMargin
218+
219+
case Discard(condition) =>
220+
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
221+
val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, sourcePresent)
222+
223+
s"""
224+
|${code.code}
225+
|if (${code.value}) {
226+
| $metricUpdateCode
227+
| return; // Discar row
228+
|}
229+
""".stripMargin
230+
231+
case Split(condition, outputExprs, otherOutputExprs) =>
232+
val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
233+
val otherProjectionExpr = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
234+
val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
235+
val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, sourcePresent)
236+
237+
s"""
238+
|${code.code}
239+
|if (${code.value}) {
240+
| $metricUpdateCode
241+
| ${consume(ctx, projectionExpr)}
242+
| ${consume(ctx, otherProjectionExpr)}
243+
| return;
244+
|}
245+
""".stripMargin
246+
case _ =>
247+
// Codegen not implemented
248+
throw new SparkUnsupportedOperationException(
249+
errorClass = "_LEGACY_ERROR_TEMP_3073",
250+
messageParameters = Map("instruction" -> instruction.toString))
251+
}
252+
}
253+
254+
/**
255+
* metric update code based on Keep's context
256+
*/
257+
private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
258+
sourcePresent: Boolean): String = {
259+
context match {
260+
case Copy =>
261+
val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
262+
s"$copyMetric.add(1);"
263+
264+
case Insert =>
265+
val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
266+
s"$insertMetric.add(1);"
267+
268+
case Update =>
269+
generateUpdateMetricUpdateCode(ctx, sourcePresent)
270+
271+
case Delete =>
272+
generateDeleteMetricUpdateCode(ctx, sourcePresent)
273+
274+
case _ =>
275+
throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
276+
}
277+
}
278+
279+
private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
280+
sourcePresent: Boolean): String = {
281+
val updateMetric = metricTerm(ctx, "numTargetRowsUpdated")
282+
if (sourcePresent) {
283+
val matchedUpdateMetric = metricTerm(ctx, "numTargetRowsMatchedUpdated")
284+
285+
s"""
286+
|$updateMetric.add(1);
287+
|$matchedUpdateMetric.add(1);
288+
""".stripMargin
289+
} else {
290+
val notMatchedBySourceUpdateMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceUpdated")
291+
292+
s"""
293+
|$updateMetric.add(1);
294+
|$notMatchedBySourceUpdateMetric.add(1);
295+
""".stripMargin
296+
}
297+
}
298+
299+
private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
300+
sourcePresent: Boolean): String = {
301+
val deleteMetric = metricTerm(ctx, "numTargetRowsDeleted")
302+
if (sourcePresent) {
303+
val matchedDeleteMetric = metricTerm(ctx, "numTargetRowsMatchedDeleted")
304+
305+
s"""
306+
|$deleteMetric.add(1);
307+
|$matchedDeleteMetric.add(1);
308+
""".stripMargin
309+
} else {
310+
val notMatchedBySourceDeleteMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceDeleted")
311+
312+
s"""
313+
|$deleteMetric.add(1);
314+
|$notMatchedBySourceDeleteMetric.add(1);
315+
""".stripMargin
316+
}
317+
}
318+
319+
/**
320+
* Helper method to save and restore CodegenContext state for code generation.
321+
*
322+
* This is needed because when generating code for expressions, the CodegenContext
323+
* state (currentVars and INPUT_ROW) gets modified during expression evaluation.
324+
* This method temporarily sets the context to the input variables from doConsume
325+
* and restores the original state after the block completes.
326+
*/
327+
private def withCodegenContext[T](
328+
ctx: CodegenContext,
329+
inputCurrentVars: Seq[ExprCode])(block: => T): T = {
330+
val originalCurrentVars = ctx.currentVars
331+
val originalInputRow = ctx.INPUT_ROW
332+
try {
333+
// Set to the input variables saved in doConsume
334+
ctx.currentVars = inputCurrentVars
335+
block
336+
} finally {
337+
// Restore original context
338+
ctx.currentVars = originalCurrentVars
339+
ctx.INPUT_ROW = originalInputRow
340+
}
341+
}
342+
343+
private def generatePredicateCode(ctx: CodegenContext,
344+
predicate: Expression,
345+
inputAttrs: Seq[Attribute],
346+
inputCurrentVars: Seq[ExprCode]): ExprCode = {
347+
withCodegenContext(ctx, inputCurrentVars) {
348+
val boundPredicate = BindReferences.bindReference(predicate, inputAttrs)
349+
val ev = boundPredicate.genCode(ctx)
350+
val predicateVar = ctx.freshName("predicateResult")
351+
val code = code"""
352+
|${ev.code}
353+
|boolean $predicateVar = !${ev.isNull} && ${ev.value};
354+
""".stripMargin
355+
ExprCode(code, FalseLiteral,
356+
JavaCode.variable(predicateVar, BooleanType))
357+
}
358+
}
359+
360+
private def generateProjectionCode(ctx: CodegenContext,
361+
outputExprs: Seq[Expression],
362+
inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
363+
withCodegenContext(ctx, inputCurrentVars) {
364+
val boundExprs = outputExprs.map(BindReferences.bindReference(_, child.output))
365+
boundExprs.map(_.genCode(ctx))
366+
}
367+
}
368+
95369
private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
96370
val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
97371
val isTargetRowPresentPred = createPredicate(isTargetRowPresent)

0 commit comments

Comments
 (0)