@@ -19,22 +19,25 @@ package org.apache.spark.sql.execution.datasources.v2
19
19
20
20
import org .roaringbitmap .longlong .Roaring64Bitmap
21
21
22
+ import org .apache .spark .SparkUnsupportedOperationException
22
23
import org .apache .spark .rdd .RDD
23
24
import org .apache .spark .sql .AnalysisException
24
25
import org .apache .spark .sql .catalyst .InternalRow
25
26
import org .apache .spark .sql .catalyst .expressions .Attribute
26
27
import org .apache .spark .sql .catalyst .expressions .AttributeSet
27
28
import org .apache .spark .sql .catalyst .expressions .BasePredicate
29
+ import org .apache .spark .sql .catalyst .expressions .BindReferences
28
30
import org .apache .spark .sql .catalyst .expressions .Expression
29
31
import org .apache .spark .sql .catalyst .expressions .Projection
30
32
import org .apache .spark .sql .catalyst .expressions .UnsafeProjection
31
- import org .apache .spark .sql .catalyst .expressions .codegen .GeneratePredicate
33
+ import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode , FalseLiteral , GeneratePredicate , JavaCode }
34
+ import org .apache .spark .sql .catalyst .expressions .codegen .Block .BlockHelper
32
35
import org .apache .spark .sql .catalyst .plans .logical .MergeRows .{Context , Copy , Delete , Discard , Insert , Instruction , Keep , ROW_ID , Split , Update }
33
36
import org .apache .spark .sql .catalyst .util .truncatedString
34
37
import org .apache .spark .sql .errors .QueryExecutionErrors
35
- import org .apache .spark .sql .execution .SparkPlan
36
- import org .apache .spark .sql .execution .UnaryExecNode
38
+ import org .apache .spark .sql .execution .{CodegenSupport , SparkPlan , UnaryExecNode }
37
39
import org .apache .spark .sql .execution .metric .{SQLMetric , SQLMetrics }
40
+ import org .apache .spark .sql .types .BooleanType
38
41
39
42
case class MergeRowsExec (
40
43
isSourceRowPresent : Expression ,
@@ -44,7 +47,7 @@ case class MergeRowsExec(
44
47
notMatchedBySourceInstructions : Seq [Instruction ],
45
48
checkCardinality : Boolean ,
46
49
output : Seq [Attribute ],
47
- child : SparkPlan ) extends UnaryExecNode {
50
+ child : SparkPlan ) extends UnaryExecNode with CodegenSupport {
48
51
49
52
override lazy val metrics : Map [String , SQLMetric ] = Map (
50
53
" numTargetRowsCopied" -> SQLMetrics .createMetric(sparkContext,
@@ -92,6 +95,277 @@ case class MergeRowsExec(
92
95
child.execute().mapPartitions(processPartition)
93
96
}
94
97
98
/**
 * Returns the input RDDs of the child, which is cast to [[CodegenSupport]].
 */
override def inputRDDs(): Seq[RDD[InternalRow]] = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.inputRDDs()
}
101
+
102
/**
 * Delegates row production to the child (cast to [[CodegenSupport]]), registering this
 * operator as the parent that receives the produced rows.
 */
protected override def doProduce(ctx: CodegenContext): String = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.produce(ctx, this)
}
105
+
106
/**
 * Generates the per-row code of this operator by handing the incoming column variables
 * straight to the instruction-execution generator. The `row` argument is not used.
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String =
  generateInstructionExecutionCode(ctx, input)
113
+
114
+
115
/**
 * Generates the cardinality validation code.
 *
 * A `Roaring64Bitmap` is registered as mutable state of the generated class. The emitted
 * code evaluates the row ID variable, throws the merge cardinality violation error when
 * the bitmap already contains that ID, and otherwise adds the ID to the bitmap.
 *
 * @param ctx codegen context used to register the bitmap
 * @param rowIdOrdinal position of the row ID variable inside `input`
 * @param input variables of the current input row
 * @return the check as an [[ExprCode]] whose value refers to the bitmap variable
 */
private def generateCardinalityValidationCode(ctx: CodegenContext, rowIdOrdinal: Int,
    input: Seq[ExprCode]): ExprCode = {
  val bitmapType = classOf[Roaring64Bitmap]
  val bitmapTypeName = bitmapType.getName
  val seenRowIds = ctx.addMutableState(bitmapTypeName, "matchedRowIds",
    name => s"$name = new $bitmapTypeName();")

  val rowId = input(rowIdOrdinal)
  // Java expression that references the Scala singleton holding the error factories.
  val errorsSingleton = QueryExecutionErrors.getClass.getName + ".MODULE$"
  val validation =
    code"""
       |${rowId.code}
       |if ($seenRowIds.contains(${rowId.value})) {
       |  throw $errorsSingleton.mergeCardinalityViolationError();
       |}
       |$seenRowIds.add(${rowId.value});
     """.stripMargin
  ExprCode(validation, FalseLiteral, JavaCode.variable(seenRowIds, bitmapType))
}
136
+
137
/**
 * Generates the top-level dispatch for one input row.
 *
 * The emitted code evaluates the source-present and target-present predicates and then
 * selects exactly one branch:
 *  - both present: optional cardinality check followed by the matched instructions;
 *  - only the source present: the not-matched instructions;
 *  - only the target present: the not-matched-by-source instructions.
 *
 * @param ctx codegen context
 * @param inputExprs variables of the current input row
 * @return the generated Java source
 */
private def generateInstructionExecutionCode(ctx: CodegenContext,
    inputExprs: Seq[ExprCode]): String = {

  // Presence predicates are generated first, source before target.
  val sourcePresent = generatePredicateCode(ctx, isSourceRowPresent, child.output, inputExprs)
  val targetPresent = generatePredicateCode(ctx, isTargetRowPresent, child.output, inputExprs)

  // One code fragment per instruction group.
  val matchedBranch = generateInstructionsCode(
    ctx, matchedInstructions, "matched", inputExprs, sourcePresent = true)
  val notMatchedBranch = generateInstructionsCode(
    ctx, notMatchedInstructions, "notMatched", inputExprs, sourcePresent = true)
  val notMatchedBySourceBranch = generateInstructionsCode(
    ctx, notMatchedBySourceInstructions, "notMatchedBySource", inputExprs,
    sourcePresent = false)

  val cardinalityCheck =
    if (!checkCardinality) {
      ""
    } else {
      val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
    }

  s"""
     |${sourcePresent.code}
     |${targetPresent.code}
     |
     |if (${targetPresent.value} && ${sourcePresent.value}) {
     |  $cardinalityCheck
     |  $matchedBranch
     |} else if (${sourcePresent.value}) {
     |  $notMatchedBranch
     |} else if (${targetPresent.value}) {
     |  $notMatchedBySourceBranch
     |}
   """.stripMargin
}
177
+
178
/**
 * Generates the code for a group of instructions.
 *
 * The per-instruction fragments are joined with newlines and followed by a `return;`, so
 * the generated code returns when none of the instruction conditions holds. An empty
 * group yields an empty string.
 *
 * @param ctx codegen context
 * @param instructions instructions of one group, in evaluation order
 * @param instructionType label of the group; not used by the generated code
 * @param inputExprs variables of the current input row
 * @param sourcePresent forwarded to the per-instruction generator
 * @return the generated Java source, or "" when there are no instructions
 */
private def generateInstructionsCode(ctx: CodegenContext, instructions: Seq[Instruction],
    instructionType: String,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = {
  if (instructions.nonEmpty) {
    val fragments = instructions.map { instruction =>
      generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent)
    }
    val joined = fragments.mkString("\n")

    s"""
       |$joined
       |return;
     """.stripMargin
  } else {
    ""
  }
}
197
+
198
/**
 * Generates the Java code for a single MERGE instruction.
 *
 * Each supported instruction becomes a guarded block: the instruction condition is
 * evaluated against the input row and, when it holds, the relevant metrics are updated,
 * the projected rows (none for Discard, one for Keep, two for Split) are passed to
 * `consume`, and the generated code returns.
 *
 * @param ctx codegen context
 * @param instruction instruction to translate
 * @param inputExprs variables of the current input row
 * @param sourcePresent forwarded to the metric generators
 * @return the generated Java source for the instruction
 * @throws SparkUnsupportedOperationException for an instruction other than Keep, Discard
 *                                            or Split
 */
private def generateSingleInstructionCode(ctx: CodegenContext,
    instruction: Instruction,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = {
  instruction match {
    case Keep(context, condition, outputExprs) =>
      val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
      val code = generatePredicateCode(ctx, condition, child.output, inputExprs)

      // The metric to update is chosen from the Keep context.
      val metricUpdateCode = generateMetricUpdateCode(ctx, context, sourcePresent)

      s"""
         |${code.code}
         |if (${code.value}) {
         |  $metricUpdateCode
         |  ${consume(ctx, projectionExpr)}
         |  return;
         |}
       """.stripMargin

    case Discard(condition) =>
      val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
      val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, sourcePresent)

      // Nothing is passed to `consume`: the row is dropped after the metric update.
      s"""
         |${code.code}
         |if (${code.value}) {
         |  $metricUpdateCode
         |  return; // Discard row
         |}
       """.stripMargin

    case Split(condition, outputExprs, otherOutputExprs) =>
      val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
      val otherProjectionExpr = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
      val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
      val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, sourcePresent)

      // A split emits two output rows for one input row.
      s"""
         |${code.code}
         |if (${code.value}) {
         |  $metricUpdateCode
         |  ${consume(ctx, projectionExpr)}
         |  ${consume(ctx, otherProjectionExpr)}
         |  return;
         |}
       """.stripMargin

    case _ =>
      // Codegen not implemented
      throw new SparkUnsupportedOperationException(
        errorClass = "_LEGACY_ERROR_TEMP_3073",
        messageParameters = Map("instruction" -> instruction.toString))
  }
}
253
+
254
/**
 * Generates the metric update code for a Keep instruction, chosen by its context.
 *
 * Copy and Insert each increment a single metric; Update and Delete delegate to the
 * dedicated update/delete metric generators.
 *
 * @param ctx codegen context
 * @param context context carried by the Keep instruction
 * @param sourcePresent forwarded to the update/delete metric generators
 * @return Java statements that increment the metrics
 * @throws IllegalArgumentException for any other context
 */
private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
    sourcePresent: Boolean): String = {
  // Emits a single increment of the named metric.
  def increment(metricName: String): String = {
    val term = metricTerm(ctx, metricName)
    s"$term.add(1);"
  }

  context match {
    case Copy => increment("numTargetRowsCopied")
    case Insert => increment("numTargetRowsInserted")
    case Update => generateUpdateMetricUpdateCode(ctx, sourcePresent)
    case Delete => generateDeleteMetricUpdateCode(ctx, sourcePresent)
    case _ =>
      throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
  }
}
278
+
279
/**
 * Generates code that increments the overall update metric plus one breakdown metric:
 * the matched-update metric when `sourcePresent` is true, otherwise the
 * not-matched-by-source-update metric.
 *
 * @param ctx codegen context
 * @param sourcePresent selects which breakdown metric is incremented
 * @return Java statements that increment both metrics
 */
private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  // The overall metric term is requested before the breakdown term.
  val totalTerm = metricTerm(ctx, "numTargetRowsUpdated")
  val breakdownName =
    if (sourcePresent) {
      "numTargetRowsMatchedUpdated"
    } else {
      "numTargetRowsNotMatchedBySourceUpdated"
    }
  val breakdownTerm = metricTerm(ctx, breakdownName)

  s"""
     |$totalTerm.add(1);
     |$breakdownTerm.add(1);
   """.stripMargin
}
298
+
299
/**
 * Generates code that increments the overall delete metric plus one breakdown metric:
 * the matched-delete metric when `sourcePresent` is true, otherwise the
 * not-matched-by-source-delete metric.
 *
 * @param ctx codegen context
 * @param sourcePresent selects which breakdown metric is incremented
 * @return Java statements that increment both metrics
 */
private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  // The overall metric term is requested before the breakdown term.
  val totalTerm = metricTerm(ctx, "numTargetRowsDeleted")
  val breakdownName =
    if (sourcePresent) {
      "numTargetRowsMatchedDeleted"
    } else {
      "numTargetRowsNotMatchedBySourceDeleted"
    }
  val breakdownTerm = metricTerm(ctx, breakdownName)

  s"""
     |$totalTerm.add(1);
     |$breakdownTerm.add(1);
   """.stripMargin
}
318
+
319
/**
 * Runs `block` with `ctx.currentVars` pointing at the given input variables.
 *
 * The current values of `ctx.currentVars` and `ctx.INPUT_ROW` are captured up front and
 * written back once the block finishes, whether it returns normally or throws. Only
 * `currentVars` is overridden for the duration of the block; `INPUT_ROW` is merely
 * saved and restored.
 *
 * @param ctx codegen context whose state is temporarily changed
 * @param inputCurrentVars variables to expose as `currentVars` while `block` runs
 * @param block code generation to run under the adjusted context
 * @return the result of `block`
 */
private def withCodegenContext[T](
    ctx: CodegenContext,
    inputCurrentVars: Seq[ExprCode])(block: => T): T = {
  val savedCurrentVars = ctx.currentVars
  val savedInputRow = ctx.INPUT_ROW
  try {
    ctx.currentVars = inputCurrentVars
    block
  } finally {
    // Undo whatever the block (or nested code generation) did to the context.
    ctx.currentVars = savedCurrentVars
    ctx.INPUT_ROW = savedInputRow
  }
}
342
+
343
/**
 * Generates code that evaluates `predicate` into a fresh Java `boolean` variable.
 *
 * The predicate is bound to `inputAttrs` and its code is generated against the given
 * input variables. The resulting variable is true only when the predicate is non-null
 * and true, so the returned [[ExprCode]] is never null.
 *
 * @param ctx codegen context
 * @param predicate expression to evaluate
 * @param inputAttrs attributes the predicate is bound against
 * @param inputCurrentVars variables of the current input row
 * @return the evaluation code together with the boolean result variable
 */
private def generatePredicateCode(ctx: CodegenContext,
    predicate: Expression,
    inputAttrs: Seq[Attribute],
    inputCurrentVars: Seq[ExprCode]): ExprCode = {
  withCodegenContext(ctx, inputCurrentVars) {
    val evaluated = BindReferences.bindReference(predicate, inputAttrs).genCode(ctx)
    val resultVar = ctx.freshName("predicateResult")
    val evaluation =
      code"""
         |${evaluated.code}
         |boolean $resultVar = !${evaluated.isNull} && ${evaluated.value};
       """.stripMargin
    ExprCode(evaluation, FalseLiteral, JavaCode.variable(resultVar, BooleanType))
  }
}
359
+
360
/**
 * Generates code for each output expression of an instruction.
 *
 * All expressions are first bound to the child's output attributes; code is then
 * generated for each bound expression against the given input variables.
 *
 * @param ctx codegen context
 * @param outputExprs expressions that make up the projected row
 * @param inputCurrentVars variables of the current input row
 * @return one [[ExprCode]] per output expression, in the same order
 */
private def generateProjectionCode(ctx: CodegenContext,
    outputExprs: Seq[Expression],
    inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
  withCodegenContext(ctx, inputCurrentVars) {
    val bound = for (expr <- outputExprs) yield BindReferences.bindReference(expr, child.output)
    for (expr <- bound) yield expr.genCode(ctx)
  }
}
368
+
95
369
private def processPartition (rowIterator : Iterator [InternalRow ]): Iterator [InternalRow ] = {
96
370
val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
97
371
val isTargetRowPresentPred = createPredicate(isTargetRowPresent)
0 commit comments