Commit 803617d

refactor: Only create one native plan for a query on an executor
1 parent 5d2c909 commit 803617d
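
This change moves native plan creation out of each `CometExecIterator` instance and into its companion object, which caches the native plan handle per serialized query plan, so an executor builds the DataFusion plan once and reuses it for every partition of that query. The input `CometBatchIterator`s are correspondingly passed to `Native.executePlan` on each call instead of being bound once in `Native.createPlan`. Below is a rough, self-contained sketch of the caching idea only; all names are invented for illustration and are not Comet's API (the real implementation is in `CometExecIterator.scala` further down):

    import java.util.concurrent.ConcurrentHashMap

    // Hypothetical stand-in for the per-executor plan cache this commit adds.
    // Idea: key a per-JVM cache by the serialized query plan and build the
    // expensive native plan only once, reusing the handle for every partition.
    // Note: this sketch keys by Seq[Byte] for value equality; the commit itself
    // keys its map by the Array[Byte] plan instance.
    object NativePlanCache {
      private val planMap = new ConcurrentHashMap[Seq[Byte], Long]()
      private var nextHandle = 0L

      // Stand-in for the JNI Native.createPlan call that returns a native plan address.
      private def createNativePlan(serializedPlan: Seq[Byte]): Long = {
        nextHandle += 1
        nextHandle
      }

      def getOrCreate(serializedPlan: Seq[Byte]): Long = synchronized {
        if (planMap.containsKey(serializedPlan)) {
          planMap.get(serializedPlan) // cache hit: reuse the existing native plan
        } else {
          val handle = createNativePlan(serializedPlan)
          planMap.put(serializedPlan, handle)
          handle
        }
      }

      // Stand-in for Native.releasePlan: drop the entry (a real cache would free the plan).
      def release(serializedPlan: Seq[Byte]): Unit = synchronized {
        planMap.remove(serializedPlan)
      }
    }

    object NativePlanCacheDemo extends App {
      val plan: Seq[Byte] = Array[Byte](1, 2, 3).toSeq
      val first = NativePlanCache.getOrCreate(plan)
      val second = NativePlanCache.getOrCreate(plan) // same handle: the plan is created once
      assert(first == second)
      NativePlanCache.release(plan)
    }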

File tree

6 files changed: +81 -60 lines

    native/core/src/execution/jni_api.rs
    native/core/src/execution/operators/scan.rs
    spark/src/main/scala/org/apache/comet/CometExecIterator.scala
    spark/src/main/scala/org/apache/comet/Native.scala
    spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
    spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

native/core/src/execution/jni_api.rs (+13 -14)

@@ -71,8 +71,6 @@ struct ExecutionContext {
     pub root_op: Option<Arc<SparkPlan>>,
     /// The input sources for the DataFusion plan
     pub scans: Vec<ScanExec>,
-    /// The global reference of input sources for the DataFusion plan
-    pub input_sources: Vec<Arc<GlobalRef>>,
     /// The record batch stream to pull results from
     pub stream: Option<SendableRecordBatchStream>,
     /// The Tokio runtime used for async.

@@ -99,7 +97,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
     _class: JClass,
     id: jlong,
-    iterators: jobjectArray,
     serialized_query: jbyteArray,
     metrics_node: JObject,
     comet_task_memory_manager_obj: JObject,

@@ -133,15 +130,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(

         let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);

-        // Get the global references of input sources
-        let mut input_sources = vec![];
-        let iter_array = JObjectArray::from_raw(iterators);
-        let num_inputs = env.get_array_length(&iter_array)?;
-        for i in 0..num_inputs {
-            let input_source = env.get_object_array_element(&iter_array, i)?;
-            let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
-            input_sources.push(input_source);
-        }
         let task_memory_manager =
             Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);

@@ -163,7 +151,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
            spark_plan,
            root_op: None,
            scans: vec![],
-            input_sources,
            stream: None,
            runtime,
            metrics,

@@ -302,6 +289,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
     stage_id: jint,
     partition: jint,
     exec_context: jlong,
+    iterators: jobjectArray,
     array_addrs: jlongArray,
     schema_addrs: jlongArray,
 ) -> jlong {

@@ -318,9 +306,20 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
            let start = Instant::now();
            let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
                .with_exec_id(exec_context_id);
+
+            // Get the global references of input sources
+            let mut input_sources = vec![];
+            let iter_array = JObjectArray::from_raw(iterators);
+            let num_inputs = env.get_array_length(&iter_array)?;
+            for i in 0..num_inputs {
+                let input_source = env.get_object_array_element(&iter_array, i)?;
+                let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
+                input_sources.push(input_source);
+            }
+
            let (scans, root_op) = planner.create_plan(
                &exec_context.spark_plan,
-                &mut exec_context.input_sources.clone(),
+                &mut input_sources,
            )?;
            let physical_plan_time = start.elapsed();
native/core/src/execution/operators/scan.rs (+1 -2)

@@ -57,8 +57,7 @@ use std::{
 /// Native.executePlan, it passes in the memory addresses of the input batches.
 #[derive(Debug, Clone)]
 pub struct ScanExec {
-    /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM
-    /// environment `JNIEnv` from the execution context.
+    /// The ID of the execution context that owns this scan.
     pub exec_context_id: i64,
     /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object.
    pub input_source: Option<Arc<GlobalRef>>,

spark/src/main/scala/org/apache/comet/CometExecIterator.scala (+51 -29)

@@ -35,50 +35,28 @@ import org.apache.comet.vector.NativeUtil
  * `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is
  * done).
  *
+ * @param id
+ *   The unique id of the query plan behind this native execution.
  * @param inputs
  *   The input iterators producing sequence of batches of Arrow Arrays.
- * @param protobufQueryPlan
- *   The serialized bytes of Spark execution plan.
  * @param numParts
  *   The number of partitions.
  * @param partitionIndex
  *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
+    nativePlan: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
-    protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode,
     numParts: Int,
     partitionIndex: Int)
     extends Iterator[ColumnarBatch] {
+  import CometExecIterator._

-  private val nativeLib = new Native()
-  private val nativeUtil = new NativeUtil()
   private val cometBatchIterators = inputs.map { iterator =>
     new CometBatchIterator(iterator, nativeUtil)
   }.toArray
-  private val plan = {
-    val conf = SparkEnv.get.conf
-    // Only enable unified memory manager when off-heap mode is enabled. Otherwise,
-    // we'll use the built-in memory pool from DF, and initializes with `memory_limit`
-    // and `memory_fraction` below.
-    nativeLib.createPlan(
-      id,
-      cometBatchIterators,
-      protobufQueryPlan,
-      nativeMetrics,
-      new CometTaskMemoryManager(id),
-      batchSize = COMET_BATCH_SIZE.get(),
-      use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
-      memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
-      memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
-      debug = COMET_DEBUG_ENABLED.get(),
-      explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
-      workerThreads = COMET_WORKER_THREADS.get(),
-      blockingThreads = COMET_BLOCKING_THREADS.get())
-  }

   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null

@@ -91,7 +69,13 @@ class CometExecIterator(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
         val ctx = TaskContext.get()
-        nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(
+          ctx.stageId(),
+          partitionIndex,
+          nativePlan,
+          cometBatchIterators,
+          arrayAddrs,
+          schemaAddrs)
       })
   }

@@ -134,8 +118,6 @@ class CometExecIterator(
       currentBatch.close()
       currentBatch = null
     }
-    nativeUtil.close()
-    nativeLib.releasePlan(plan)

     // The allocator thoughts the exported ArrowArray and ArrowSchema structs are not released,
     // so it will report:

@@ -160,3 +142,43 @@ class CometExecIterator(
     }
   }
 }
+
+object CometExecIterator {
+  val nativeLib = new Native()
+  val nativeUtil = new NativeUtil()
+
+  val planMap = new java.util.concurrent.ConcurrentHashMap[Array[Byte], Long]()
+
+  def createPlan(id: Long, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode): Long =
+    synchronized {
+      if (planMap.containsKey(protobufQueryPlan)) {
+        planMap.get(protobufQueryPlan)
+      } else {
+        val conf = SparkEnv.get.conf
+
+        val plan = nativeLib.createPlan(
+          id,
+          protobufQueryPlan,
+          nativeMetrics,
+          new CometTaskMemoryManager(id),
+          batchSize = COMET_BATCH_SIZE.get(),
+          use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
+          memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
+          memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
+          debug = COMET_DEBUG_ENABLED.get(),
+          explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+          workerThreads = COMET_WORKER_THREADS.get(),
+          blockingThreads = COMET_BLOCKING_THREADS.get())
+        planMap.put(protobufQueryPlan, plan)
+        plan
+      }
+    }
+
+  def releasePlan(protobufQueryPlan: Array[Byte]): Unit = synchronized {
+    if (planMap.containsKey(protobufQueryPlan)) {
+      val plan = planMap.get(protobufQueryPlan)
+      nativeLib.releasePlan(plan)
+      planMap.remove(protobufQueryPlan)
+    }
+  }
+}
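
A usage note on the companion object above: callers now obtain the native plan handle once via `createPlan`, pass it to the `CometExecIterator` constructor, and release it from a task-completion listener rather than from the iterator's `close()`. Roughly, restating the `operators.scala` changes further down (identifiers such as `serializedPlanCopy`, `inputs`, and `nativeMetrics` are the ones in scope in that file, not new API):

    val planId = CometExec.newIterId
    val nativePlan = CometExecIterator.createPlan(planId, serializedPlanCopy, nativeMetrics)
    val it = new CometExecIterator(planId, nativePlan, inputs, output.length, numParts, partitionIndex)

    // Released when the task finishes, not when the iterator closes.
    context.addTaskCompletionListener[Unit] { _ =>
      it.close()
      CometExecIterator.releasePlan(serializedPlanCopy)
    }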

spark/src/main/scala/org/apache/comet/Native.scala (+4 -4)

@@ -30,9 +30,6 @@ class Native extends NativeBase {
   *   The id of the query plan.
   * @param configMap
   *   The Java Map object for the configs of native engine.
-  * @param iterators
-  *   the input iterators to the native query plan. It should be the same number as the number of
-  *   scan nodes in the SparkPlan.
   * @param plan
   *   the bytes of serialized SparkPlan.
   * @param metrics

@@ -46,7 +43,6 @@ class Native extends NativeBase {
   // scalastyle:off
   @native def createPlan(
       id: Long,
-      iterators: Array[CometBatchIterator],
       plan: Array[Byte],
       metrics: CometMetricNode,
       taskMemoryManager: CometTaskMemoryManager,

@@ -69,6 +65,9 @@ class Native extends NativeBase {
   *   the partition ID, for informational purposes
   * @param plan
   *   the address to native query plan.
+  * @param iterators
+  *   the input iterators to the native query plan. It should be the same number as the number of
+  *   scan nodes in the SparkPlan.
   * @param arrayAddrs
   *   the addresses of Arrow Array structures
   * @param schemaAddrs

@@ -80,6 +79,7 @@ class Native extends NativeBase {
      stage: Int,
      partition: Int,
      plan: Long,
+      iterators: Array[CometBatchIterator],
      arrayAddrs: Array[Long],
      schemaAddrs: Array[Long]): Long

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala (+1 -0)

@@ -499,6 +499,7 @@ class CometShuffleWriteProcessor(
     // Getting rid of the fake partitionId
     val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)

+    context.taskAttemptId()
     val cometIter = CometExec.getCometIterator(
       Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
       outputAttributes.length,

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala (+11 -11)

@@ -129,14 +129,11 @@ object CometExec {
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(
-      newIterId,
-      inputs,
-      numOutputCols,
-      bytes,
-      nativeMetrics,
-      numParts,
-      partitionIdx)
+
+    val planId = CometExec.newIterId
+    val nativePlanId = CometExecIterator.createPlan(planId, bytes, nativeMetrics)
+
+    new CometExecIterator(newIterId, nativePlanId, inputs, numOutputCols, numParts, partitionIdx)
   }

   /**

@@ -206,12 +203,14 @@ abstract class CometNativeExec extends CometExec {
       inputs: Seq[Iterator[ColumnarBatch]],
       numParts: Int,
       partitionIndex: Int): CometExecIterator = {
+    val planId = CometExec.newIterId
+    val nativePlan = CometExecIterator.createPlan(planId, serializedPlanCopy, nativeMetrics)
+
     val it = new CometExecIterator(
-      CometExec.newIterId,
+      planId,
+      nativePlan,
       inputs,
       output.length,
-      serializedPlanCopy,
-      nativeMetrics,
       numParts,
       partitionIndex)

@@ -221,6 +220,7 @@ abstract class CometNativeExec extends CometExec {
     context.addTaskCompletionListener[Unit] { _ =>
       it.close()
       cleanSubqueries(it.id, this)
+      CometExecIterator.releasePlan(serializedPlanCopy)
     }
   }
