Skip to content

Commit 30a452d

Browse files
committed
[ENH]: Make all functions incremental
1 parent 3646f99 commit 30a452d

File tree

3 files changed

+223
-31
lines changed

3 files changed

+223
-31
lines changed

rust/worker/src/execution/functions/statistics.rs

Lines changed: 154 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
//! The core idea is the following: For each key-value pair associated with a record, aggregate so
44
//! (key, value) -> count. This gives a count of how frequently each key appears.
55
//!
6-
//! For now it's not incremental.
6+
//! The statistics executor is incremental - it loads existing counts from the output_reader
7+
//! and updates them with new records.
78
89
use std::collections::{HashMap, HashSet};
910
use std::hash::{Hash, Hasher};
@@ -27,8 +28,10 @@ pub trait StatisticsFunctionFactory: std::fmt::Debug + Send + Sync {
2728

2829
/// Accumulate statistics. Must be associative and commutative over a sequence of `observe` calls.
2930
pub trait StatisticsFunction: std::fmt::Debug + Send {
30-
fn observe(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
31+
fn observe_insert(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
32+
fn observe_delete(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
3133
fn output(&self) -> UpdateMetadataValue;
34+
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
3235
}
3336

3437
#[derive(Debug, Default)]
@@ -45,14 +48,29 @@ pub struct CounterFunction {
4548
acc: i64,
4649
}
4750

51+
impl CounterFunction {
52+
/// Create a CounterFunction with an initial value.
53+
pub fn with_initial_value(value: i64) -> Self {
54+
Self { acc: value }
55+
}
56+
}
57+
4858
impl StatisticsFunction for CounterFunction {
49-
fn observe(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
59+
fn observe_insert(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
5060
self.acc = self.acc.saturating_add(1);
5161
}
5262

63+
fn observe_delete(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
64+
self.acc = self.acc.saturating_sub(1);
65+
}
66+
5367
fn output(&self) -> UpdateMetadataValue {
5468
UpdateMetadataValue::Int(self.acc)
5569
}
70+
71+
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
72+
self
73+
}
5674
}
5775

5876
/// Canonical representation of metadata values tracked by the statistics executor.
@@ -171,31 +189,153 @@ impl Hash for StatisticsValue {
171189
#[derive(Debug)]
172190
pub struct StatisticsFunctionExecutor(pub Box<dyn StatisticsFunctionFactory>);
173191

192+
impl StatisticsFunctionExecutor {
193+
/// Load existing statistics from the output reader.
194+
/// Returns a HashMap with the same structure as the counts HashMap.
195+
async fn load_existing_statistics(
196+
&self,
197+
output_reader: Option<&RecordSegmentReader<'_>>,
198+
) -> Result<
199+
HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>>,
200+
Box<dyn ChromaError>,
201+
> {
202+
let mut counts: HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>> =
203+
HashMap::default();
204+
205+
let Some(reader) = output_reader else {
206+
return Ok(counts);
207+
};
208+
209+
let max_offset_id = reader.get_max_offset_id();
210+
let mut stream = reader.get_data_stream(0..=max_offset_id).await;
211+
212+
while let Some(record_result) = stream.next().await {
213+
let (_, record) = record_result?;
214+
215+
// Parse the record to extract key, value, type, and count
216+
let Some(metadata) = &record.metadata else {
217+
continue;
218+
};
219+
220+
let key = match metadata.get("key") {
221+
Some(MetadataValue::Str(k)) => k.clone(),
222+
_ => continue,
223+
};
224+
225+
let value_type = match metadata.get("type") {
226+
Some(MetadataValue::Str(t)) => t.as_str(),
227+
_ => continue,
228+
};
229+
230+
let value_str = match metadata.get("value") {
231+
Some(MetadataValue::Str(v)) => v.as_str(),
232+
_ => continue,
233+
};
234+
235+
let count = match metadata.get("count") {
236+
Some(MetadataValue::Int(c)) => *c,
237+
_ => continue,
238+
};
239+
240+
// Reconstruct the StatisticsValue from type and value
241+
let stats_value = match value_type {
242+
"bool" => match value_str {
243+
"true" => StatisticsValue::Bool(true),
244+
"false" => StatisticsValue::Bool(false),
245+
_ => continue,
246+
},
247+
"int" => match value_str.parse::<i64>() {
248+
Ok(i) => StatisticsValue::Int(i),
249+
_ => continue,
250+
},
251+
"float" => match value_str.parse::<f64>() {
252+
Ok(f) => StatisticsValue::Float(f),
253+
_ => continue,
254+
},
255+
"str" => StatisticsValue::Str(value_str.to_string()),
256+
"sparse" => match value_str.parse::<u32>() {
257+
Ok(index) => StatisticsValue::SparseVector(index),
258+
_ => continue,
259+
},
260+
_ => continue,
261+
};
262+
263+
// Create a statistics function initialized with the existing count
264+
let stats_function =
265+
Box::new(CounterFunction::with_initial_value(count)) as Box<dyn StatisticsFunction>;
266+
267+
counts
268+
.entry(key)
269+
.or_default()
270+
.insert(stats_value, stats_function);
271+
}
272+
273+
Ok(counts)
274+
}
275+
}
276+
174277
#[async_trait]
175278
impl AttachedFunctionExecutor for StatisticsFunctionExecutor {
176279
async fn execute(
177280
&self,
178281
input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
179282
output_reader: Option<&RecordSegmentReader<'_>>,
180283
) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
181-
let mut counts: HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>> =
182-
HashMap::default();
284+
// Load existing statistics from output_reader if available
285+
let mut counts = self.load_existing_statistics(output_reader).await?;
286+
287+
// Process new input records and update counts
183288
for (hydrated_record, _index) in input_records.iter() {
184-
// This is only applicable for non-incremental statistics.
185-
// TODO(tanujnay112): Change this when we make incremental statistics work.
186-
if hydrated_record.get_operation() == MaterializedLogOperation::DeleteExisting {
187-
continue;
289+
let metadata_delta = hydrated_record.compute_metadata_delta();
290+
291+
// Decrement counts for deleted metadata
292+
for (key, old_value) in metadata_delta.metadata_to_delete {
293+
for stats_value in StatisticsValue::from_metadata_value(old_value) {
294+
let inner_map = counts.entry(key.to_string()).or_default();
295+
inner_map
296+
.entry(stats_value)
297+
.or_insert_with(|| self.0.create())
298+
.observe_delete(hydrated_record);
299+
}
300+
}
301+
302+
// Decrement counts for old values in updates
303+
for (key, (old_value, new_value)) in &metadata_delta.metadata_to_update {
304+
for stats_value in StatisticsValue::from_metadata_value(old_value) {
305+
let inner_map = counts.entry(key.to_string()).or_default();
306+
inner_map
307+
.entry(stats_value)
308+
.or_insert_with(|| self.0.create())
309+
.observe_delete(hydrated_record);
310+
}
311+
312+
for stats_value in StatisticsValue::from_metadata_value(new_value) {
313+
let inner_map = counts.entry(key.to_string()).or_default();
314+
inner_map
315+
.entry(stats_value)
316+
.or_insert_with(|| self.0.create())
317+
.observe_insert(hydrated_record);
318+
}
188319
}
189320

190-
// Use merged_metadata to get the metadata from the hydrated record
191-
let metadata = hydrated_record.merged_metadata();
192-
for (key, value) in metadata.iter() {
193-
let inner_map = counts.entry(key.clone()).or_default();
321+
// Increment counts for new values in both updates and inserts
322+
for (key, value) in metadata_delta
323+
.metadata_to_update
324+
.iter()
325+
.map(|(k, (_old, new))| (*k, *new))
326+
.chain(
327+
metadata_delta
328+
.metadata_to_insert
329+
.iter()
330+
.map(|(k, v)| (*k, *v)),
331+
)
332+
{
194333
for stats_value in StatisticsValue::from_metadata_value(value) {
334+
let inner_map = counts.entry(key.to_string()).or_default();
195335
inner_map
196336
.entry(stats_value)
197337
.or_insert_with(|| self.0.create())
198-
.observe(hydrated_record);
338+
.observe_insert(hydrated_record);
199339
}
200340
}
201341
}

rust/worker/src/execution/operators/execute_task.rs

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ use uuid::Uuid;
1616
use crate::execution::functions::{CounterFunctionFactory, StatisticsFunctionExecutor};
1717
use crate::execution::operators::materialize_logs::MaterializeLogOutput;
1818

19+
// Constants for CountAttachedFunction
20+
const COUNT_FUNCTION_OUTPUT_ID: &str = "function_output";
21+
const COUNT_METADATA_KEY: &str = "total_count";
22+
1923
/// Trait for attached function executors that process input records and produce output records.
2024
/// Implementors can read from the output collection to maintain state across executions.
2125
#[async_trait]
@@ -40,31 +44,72 @@ pub trait AttachedFunctionExecutor: Send + Sync + std::fmt::Debug {
4044
#[derive(Debug)]
4145
pub struct CountAttachedFunction;
4246

47+
impl CountAttachedFunction {
48+
/// Reads the existing count from the output reader.
49+
/// Returns 0 if no existing count is found.
50+
async fn get_existing_count(output_reader: Option<&RecordSegmentReader<'_>>) -> i64 {
51+
let Some(reader) = output_reader else {
52+
return 0;
53+
};
54+
55+
// Try to get the existing record with the function output ID
56+
let offset_id = match reader
57+
.get_offset_id_for_user_id(COUNT_FUNCTION_OUTPUT_ID)
58+
.await
59+
{
60+
Ok(Some(offset_id)) => offset_id,
61+
_ => return 0,
62+
};
63+
64+
// Get the data record for this offset id
65+
let data_record = match reader.get_data_for_offset_id(offset_id).await {
66+
Ok(Some(data_record)) => data_record,
67+
_ => return 0,
68+
};
69+
70+
// Extract total_count from metadata
71+
if let Some(metadata) = &data_record.metadata {
72+
if let Some(chroma_types::MetadataValue::Int(count)) = metadata.get(COUNT_METADATA_KEY)
73+
{
74+
return *count;
75+
}
76+
}
77+
78+
0
79+
}
80+
}
81+
4382
#[async_trait]
4483
impl AttachedFunctionExecutor for CountAttachedFunction {
4584
async fn execute(
4685
&self,
4786
input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
48-
_output_reader: Option<&RecordSegmentReader<'_>>,
87+
output_reader: Option<&RecordSegmentReader<'_>>,
4988
) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
5089
let records_count = input_records.len() as i64;
51-
let new_total_count = records_count;
90+
91+
// Read existing count from output_reader if available
92+
let existing_count = Self::get_existing_count(output_reader).await;
93+
let new_total_count = existing_count + records_count;
5294

5395
// Create output record with updated count
5496
let mut metadata = std::collections::HashMap::new();
5597
metadata.insert(
56-
"total_count".to_string(),
98+
COUNT_METADATA_KEY.to_string(),
5799
UpdateMetadataValue::Int(new_total_count),
58100
);
59101

60102
let output_record = LogRecord {
61103
log_offset: 0,
62104
record: OperationRecord {
63-
id: "function_output".to_string(),
105+
id: COUNT_FUNCTION_OUTPUT_ID.to_string(),
64106
embedding: Some(vec![0.0]),
65107
encoding: None,
66108
metadata: Some(metadata),
67-
document: Some(format!("Processed {} records", records_count)),
109+
document: Some(format!(
110+
"Last processed {} records (total: {})",
111+
records_count, new_total_count
112+
)),
68113
operation: Operation::Upsert,
69114
},
70115
};
@@ -126,6 +171,8 @@ pub struct ExecuteAttachedFunctionInput {
126171
pub output_record_segment: Segment,
127172
/// Blockfile provider for reading segments
128173
pub blockfile_provider: BlockfileProvider,
174+
175+
pub is_rebuild: bool,
129176
}
130177

131178
/// Output from the ExecuteAttachedFunction operator
@@ -190,19 +237,23 @@ impl Operator<ExecuteAttachedFunctionInput, ExecuteAttachedFunctionOutput>
190237
);
191238

192239
// Create record segment reader from the output collection's record segment
193-
let record_segment_reader = match Box::pin(RecordSegmentReader::from_segment(
194-
&input.output_record_segment,
195-
&input.blockfile_provider,
196-
))
197-
.await
198-
{
199-
Ok(reader) => Some(reader),
200-
Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => {
201-
// Output collection has no data yet - this is the first run
202-
tracing::info!("[ExecuteAttachedFunction]: Output segment uninitialized - first attached function run");
203-
None
240+
let record_segment_reader = if input.is_rebuild {
241+
None
242+
} else {
243+
match Box::pin(RecordSegmentReader::from_segment(
244+
&input.output_record_segment,
245+
&input.blockfile_provider,
246+
))
247+
.await
248+
{
249+
Ok(reader) => Some(reader),
250+
Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => {
251+
// Output collection has no data yet - this is the first run
252+
tracing::info!("[ExecuteAttachedFunction]: Output segment uninitialized - first attached function run");
253+
None
254+
}
255+
Err(e) => return Err((*e).into()),
204256
}
205-
Err(e) => return Err((*e).into()),
206257
};
207258

208259
// Process all materialized logs and hydrate the records

rust/worker/src/execution/orchestration/attached_function_orchestrator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,7 @@ impl Handler<TaskResult<CollectionAndSegments, GetCollectionAndSegmentsError>>
715715
completion_offset: collection_info.pulled_log_offset as u64, // Use the completion offset from input collection
716716
output_record_segment: message.record_segment.clone(),
717717
blockfile_provider: self.output_context.blockfile_provider.clone(),
718+
is_rebuild: self.output_context.is_rebuild,
718719
};
719720

720721
let task = wrap(

0 commit comments

Comments
 (0)