Skip to content

Commit 1ee970a

Browse files
authored
feat: collector automatically merge and align multiple collect() called with different schema (#1153)
1 parent b5a6a1b commit 1ee970a

File tree

3 files changed

+185
-27
lines changed

3 files changed

+185
-27
lines changed

src/builder/analyzer.rs

Lines changed: 143 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,88 @@ fn try_merge_collector_schemas(
255255
schema1: &CollectorSchema,
256256
schema2: &CollectorSchema,
257257
) -> Result<CollectorSchema> {
258-
let fields = try_merge_fields_schemas(&schema1.fields, &schema2.fields)?;
258+
let schema1_fields = &schema1.fields;
259+
let schema2_fields = &schema2.fields;
260+
261+
// Create a map from field name to index in schema1
262+
let field_map: HashMap<FieldName, usize> = schema1_fields
263+
.iter()
264+
.enumerate()
265+
.map(|(i, f)| (f.name.clone(), i))
266+
.collect();
267+
268+
let mut output_fields = Vec::new();
269+
let mut next_field_id_1 = 0;
270+
let mut next_field_id_2 = 0;
271+
272+
for (idx, field) in schema2_fields.iter().enumerate() {
273+
if let Some(&idx1) = field_map.get(&field.name) {
274+
if idx1 < next_field_id_1 {
275+
api_bail!(
276+
"Common fields are expected to have consistent order across different `collect()` calls, but got different orders between fields '{}' and '{}'",
277+
field.name,
278+
schema1_fields[next_field_id_1 - 1].name
279+
);
280+
}
281+
// Add intervening fields from schema1
282+
for i in next_field_id_1..idx1 {
283+
output_fields.push(schema1_fields[i].clone());
284+
}
285+
// Add intervening fields from schema2
286+
for i in next_field_id_2..idx {
287+
output_fields.push(schema2_fields[i].clone());
288+
}
289+
// Merge the field
290+
let merged_type =
291+
try_make_common_value_type(&schema1_fields[idx1].value_type, &field.value_type)?;
292+
output_fields.push(FieldSchema {
293+
name: field.name.clone(),
294+
value_type: merged_type,
295+
description: None,
296+
});
297+
next_field_id_1 = idx1 + 1;
298+
next_field_id_2 = idx + 1;
299+
// Fields not in schema1 and not UUID are added at the end
300+
}
301+
}
302+
303+
// Add remaining fields from schema1
304+
for i in next_field_id_1..schema1_fields.len() {
305+
output_fields.push(schema1_fields[i].clone());
306+
}
307+
308+
// Add remaining fields from schema2
309+
for i in next_field_id_2..schema2_fields.len() {
310+
output_fields.push(schema2_fields[i].clone());
311+
}
312+
313+
// Handle auto_uuid_field_idx
314+
let auto_uuid_field_idx = match (schema1.auto_uuid_field_idx, schema2.auto_uuid_field_idx) {
315+
(Some(idx1), Some(idx2)) => {
316+
let name1 = &schema1_fields[idx1].name;
317+
let name2 = &schema2_fields[idx2].name;
318+
if name1 == name2 {
319+
// Find the position of the auto_uuid field in the merged output
320+
output_fields.iter().position(|f| &f.name == name1)
321+
} else {
322+
api_bail!(
323+
"Generated UUID fields must have the same name across different `collect()` calls, got different names: '{}' vs '{}'",
324+
name1,
325+
name2
326+
);
327+
}
328+
}
329+
(Some(_), None) | (None, Some(_)) => {
330+
api_bail!(
331+
"The generated UUID field, once present for one `collect()`, must be consistently present for other `collect()` calls for the same collector"
332+
);
333+
}
334+
(None, None) => None,
335+
};
336+
259337
Ok(CollectorSchema {
260-
fields,
261-
auto_uuid_field_idx: if schema1.auto_uuid_field_idx == schema2.auto_uuid_field_idx {
262-
schema1.auto_uuid_field_idx
263-
} else {
264-
None
265-
},
338+
fields: output_fields,
339+
auto_uuid_field_idx,
266340
})
267341
}
268342

@@ -704,11 +778,14 @@ impl AnalyzerContext {
704778
op_scope: &Arc<OpScope>,
705779
reactive_op: &NamedSpec<ReactiveOpSpec>,
706780
) -> Result<BoxFuture<'static, Result<AnalyzedReactiveOp>>> {
707-
let result_fut = match &reactive_op.spec {
781+
let op_scope_clone = op_scope.clone();
782+
let reactive_op_clone = reactive_op.clone();
783+
let reactive_op_name = reactive_op.name.clone();
784+
let result_fut = match reactive_op_clone.spec {
708785
ReactiveOpSpec::Transform(op) => {
709786
let input_field_schemas =
710787
analyze_input_fields(&op.inputs, op_scope).with_context(|| {
711-
format!("Preparing inputs for transform op: {}", reactive_op.name)
788+
format!("Preparing inputs for transform op: {}", reactive_op_name)
712789
})?;
713790
let spec = serde_json::Value::Object(op.op.spec.clone());
714791

@@ -725,8 +802,8 @@ impl AnalyzerContext {
725802
.with(&output_enriched_type.without_attrs())?;
726803
let output_type = output_enriched_type.typ.clone();
727804
let output =
728-
op_scope.add_op_output(reactive_op.name.clone(), output_enriched_type)?;
729-
let op_name = reactive_op.name.clone();
805+
op_scope.add_op_output(reactive_op_name.clone(), output_enriched_type)?;
806+
let op_name = reactive_op_name.clone();
730807
async move {
731808
trace!("Start building executor for transform op `{op_name}`");
732809
let executor = executor.await.with_context(|| {
@@ -777,10 +854,10 @@ impl AnalyzerContext {
777854
.lock()
778855
.unwrap()
779856
.sub_scopes
780-
.insert(reactive_op.name.clone(), Arc::new(sub_op_scope_schema));
857+
.insert(reactive_op_name.clone(), Arc::new(sub_op_scope_schema));
781858
analyzed_op_scope_fut
782859
};
783-
let op_name = reactive_op.name.clone();
860+
let op_name = reactive_op_name.clone();
784861

785862
let concur_control_options =
786863
foreach_op.execution_options.get_concur_control_options();
@@ -800,22 +877,61 @@ impl AnalyzerContext {
800877
}
801878

802879
ReactiveOpSpec::Collect(op) => {
803-
let (struct_mapping, fields_schema) = analyze_struct_mapping(&op.input, op_scope)?;
880+
let (struct_mapping, fields_schema) =
881+
analyze_struct_mapping(&op.input, &op_scope_clone)?;
804882
let has_auto_uuid_field = op.auto_uuid_field.is_some();
805883
let fingerprinter = Fingerprinter::default().with(&fields_schema)?;
806-
let collect_op = AnalyzedReactiveOp::Collect(AnalyzedCollectOp {
807-
name: reactive_op.name.clone(),
808-
has_auto_uuid_field,
809-
input: struct_mapping,
810-
collector_ref: add_collector(
811-
&op.scope_name,
812-
op.collector_name.clone(),
813-
CollectorSchema::from_fields(fields_schema, op.auto_uuid_field.clone()),
814-
op_scope,
815-
)?,
816-
fingerprinter,
817-
});
818-
async move { Ok(collect_op) }.boxed()
884+
let input_field_names: Vec<FieldName> =
885+
fields_schema.iter().map(|f| f.name.clone()).collect();
886+
let collector_ref = add_collector(
887+
&op.scope_name,
888+
op.collector_name.clone(),
889+
CollectorSchema::from_fields(fields_schema, op.auto_uuid_field.clone()),
890+
&op_scope_clone,
891+
)?;
892+
async move {
893+
// Get the merged collector schema after adding
894+
let collector_schema: Arc<CollectorSchema> = {
895+
let scope = find_scope(&op.scope_name, &op_scope_clone)?.1;
896+
let states = scope.states.lock().unwrap();
897+
let collector = states.collectors.get(&op.collector_name).unwrap();
898+
collector.schema.clone()
899+
};
900+
901+
// Pre-compute field index mappings for efficient evaluation
902+
let field_name_to_index: HashMap<&FieldName, usize> = collector_schema
903+
.fields
904+
.iter()
905+
.enumerate()
906+
.map(|(i, f)| (&f.name, i))
907+
.collect();
908+
let mut field_index_mapping: HashMap<usize, usize> = HashMap::new();
909+
for (input_idx, field_name) in input_field_names.iter().enumerate() {
910+
let collector_idx = field_name_to_index
911+
.get(field_name)
912+
.copied()
913+
.ok_or_else(|| {
914+
anyhow!(
915+
"field `{}` not found in merged collector schema",
916+
field_name
917+
)
918+
})?;
919+
field_index_mapping.insert(collector_idx, input_idx);
920+
}
921+
922+
let collect_op = AnalyzedReactiveOp::Collect(AnalyzedCollectOp {
923+
name: reactive_op_name,
924+
has_auto_uuid_field,
925+
input: struct_mapping,
926+
input_field_names,
927+
collector_schema,
928+
collector_ref,
929+
field_index_mapping,
930+
fingerprinter,
931+
});
932+
Ok(collect_op)
933+
}
934+
.boxed()
819935
}
820936
};
821937
Ok(result_fut)

src/builder/plan.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::base::schema::FieldSchema;
2+
use crate::base::spec::FieldName;
23
use crate::prelude::*;
34

5+
use std::collections::HashMap;
6+
47
use crate::ops::interface::*;
58
use crate::utils::fingerprint::{Fingerprint, Fingerprinter};
69

@@ -90,7 +93,11 @@ pub struct AnalyzedCollectOp {
9093
pub name: String,
9194
pub has_auto_uuid_field: bool,
9295
pub input: AnalyzedStructMapping,
96+
pub input_field_names: Vec<FieldName>,
97+
pub collector_schema: Arc<schema::CollectorSchema>,
9398
pub collector_ref: AnalyzedCollectorReference,
99+
/// Pre-computed mapping from collector field index to input field index.
100+
pub field_index_mapping: HashMap<usize, usize>,
94101
/// Fingerprinter of the collector's schema. Used to decide when to reuse auto-generated UUIDs.
95102
pub fingerprinter: Fingerprinter,
96103
}

src/execution/evaluator.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,41 @@ async fn evaluate_op_scope(
515515
let collector_entry = scoped_entries
516516
.headn(op.collector_ref.scope_up_level as usize)
517517
.ok_or_else(|| anyhow::anyhow!("Collector level out of bound"))?;
518+
519+
// Assemble input values
520+
let input_values: Vec<value::Value> =
521+
assemble_input_values(&op.input.fields, scoped_entries)
522+
.collect::<Result<Vec<_>>>()?;
523+
524+
// Create field_values vector for all fields in the merged schema
525+
let mut field_values: Vec<value::Value> =
526+
vec![value::Value::Null; op.collector_schema.fields.len()];
527+
528+
// Use pre-computed field index mappings for O(1) field placement
529+
for (&collector_idx, &input_idx) in op.field_index_mapping.iter() {
530+
field_values[collector_idx] = input_values[input_idx].clone();
531+
}
532+
533+
// Handle auto_uuid_field (assumed to be at position 0 for efficiency)
534+
if op.has_auto_uuid_field {
535+
if let Some(uuid_idx) = op.collector_schema.auto_uuid_field_idx {
536+
let uuid = memory.next_uuid(
537+
op.fingerprinter
538+
.clone()
539+
.with(
540+
&field_values
541+
.iter()
542+
.enumerate()
543+
.filter(|(i, _)| *i != uuid_idx)
544+
.map(|(_, v)| v)
545+
.collect::<Vec<_>>(),
546+
)?
547+
.into_fingerprint(),
548+
)?;
549+
field_values[uuid_idx] = value::Value::Basic(value::BasicValue::Uuid(uuid));
550+
}
551+
}
552+
518553
{
519554
let mut collected_records = collector_entry.collected_values
520555
[op.collector_ref.local.collector_idx as usize]

0 commit comments

Comments
 (0)