Skip to content

Commit 1780f2e

Browse files
committed
[ENH]: Make all functions incremental
1 parent 2bfcc0b commit 1780f2e

File tree

3 files changed

+224
-33
lines changed

3 files changed

+224
-33
lines changed

rust/worker/src/execution/functions/statistics.rs

Lines changed: 155 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
//! The core idea is the following: For each key-value pair associated with a record, aggregate so that
44
//! (key, value) -> count. This gives a count of how frequently each (key, value) pair appears.
55
//!
6-
//! For now it's not incremental.
6+
//! The statistics executor is incremental - it loads existing counts from the output_reader
7+
//! and updates them with new records.
78
89
use std::collections::{HashMap, HashSet};
910
use std::hash::{Hash, Hasher};
@@ -13,8 +14,7 @@ use chroma_error::ChromaError;
1314
use chroma_segment::blockfile_record::RecordSegmentReader;
1415
use chroma_segment::types::HydratedMaterializedLogRecord;
1516
use chroma_types::{
16-
Chunk, LogRecord, MaterializedLogOperation, MetadataValue, Operation, OperationRecord,
17-
UpdateMetadataValue,
17+
Chunk, LogRecord, MetadataValue, Operation, OperationRecord, UpdateMetadataValue,
1818
};
1919
use futures::StreamExt;
2020

@@ -27,8 +27,10 @@ pub trait StatisticsFunctionFactory: std::fmt::Debug + Send + Sync {
2727

2828
/// Accumulate statistics. Must be associative and commutative over a sequence of `observe` calls.
2929
pub trait StatisticsFunction: std::fmt::Debug + Send {
30-
fn observe(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
30+
fn observe_insert(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
31+
fn observe_delete(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>);
3132
fn output(&self) -> UpdateMetadataValue;
33+
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
3234
}
3335

3436
#[derive(Debug, Default)]
@@ -45,14 +47,29 @@ pub struct CounterFunction {
4547
acc: i64,
4648
}
4749

50+
impl CounterFunction {
51+
/// Create a CounterFunction with an initial value.
52+
pub fn with_initial_value(value: i64) -> Self {
53+
Self { acc: value }
54+
}
55+
}
56+
4857
impl StatisticsFunction for CounterFunction {
49-
fn observe(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
58+
fn observe_insert(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
5059
self.acc = self.acc.saturating_add(1);
5160
}
5261

62+
fn observe_delete(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) {
63+
self.acc = self.acc.saturating_sub(1);
64+
}
65+
5366
fn output(&self) -> UpdateMetadataValue {
5467
UpdateMetadataValue::Int(self.acc)
5568
}
69+
70+
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
71+
self
72+
}
5673
}
5774

5875
/// Canonical representation of metadata values tracked by the statistics executor.
@@ -171,31 +188,153 @@ impl Hash for StatisticsValue {
171188
#[derive(Debug)]
172189
pub struct StatisticsFunctionExecutor(pub Box<dyn StatisticsFunctionFactory>);
173190

191+
impl StatisticsFunctionExecutor {
192+
/// Load existing statistics from the output reader.
193+
/// Returns a HashMap with the same structure as the counts HashMap.
194+
async fn load_existing_statistics(
195+
&self,
196+
output_reader: Option<&RecordSegmentReader<'_>>,
197+
) -> Result<
198+
HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>>,
199+
Box<dyn ChromaError>,
200+
> {
201+
let mut counts: HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>> =
202+
HashMap::default();
203+
204+
let Some(reader) = output_reader else {
205+
return Ok(counts);
206+
};
207+
208+
let max_offset_id = reader.get_max_offset_id();
209+
let mut stream = reader.get_data_stream(0..=max_offset_id).await;
210+
211+
while let Some(record_result) = stream.next().await {
212+
let (_, record) = record_result?;
213+
214+
// Parse the record to extract key, value, type, and count
215+
let Some(metadata) = &record.metadata else {
216+
continue;
217+
};
218+
219+
let key = match metadata.get("key") {
220+
Some(MetadataValue::Str(k)) => k.clone(),
221+
_ => continue,
222+
};
223+
224+
let value_type = match metadata.get("type") {
225+
Some(MetadataValue::Str(t)) => t.as_str(),
226+
_ => continue,
227+
};
228+
229+
let value_str = match metadata.get("value") {
230+
Some(MetadataValue::Str(v)) => v.as_str(),
231+
_ => continue,
232+
};
233+
234+
let count = match metadata.get("count") {
235+
Some(MetadataValue::Int(c)) => *c,
236+
_ => continue,
237+
};
238+
239+
// Reconstruct the StatisticsValue from type and value
240+
let stats_value = match value_type {
241+
"bool" => match value_str {
242+
"true" => StatisticsValue::Bool(true),
243+
"false" => StatisticsValue::Bool(false),
244+
_ => continue,
245+
},
246+
"int" => match value_str.parse::<i64>() {
247+
Ok(i) => StatisticsValue::Int(i),
248+
_ => continue,
249+
},
250+
"float" => match value_str.parse::<f64>() {
251+
Ok(f) => StatisticsValue::Float(f),
252+
_ => continue,
253+
},
254+
"str" => StatisticsValue::Str(value_str.to_string()),
255+
"sparse" => match value_str.parse::<u32>() {
256+
Ok(index) => StatisticsValue::SparseVector(index),
257+
_ => continue,
258+
},
259+
_ => continue,
260+
};
261+
262+
// Create a statistics function initialized with the existing count
263+
let stats_function =
264+
Box::new(CounterFunction::with_initial_value(count)) as Box<dyn StatisticsFunction>;
265+
266+
counts
267+
.entry(key)
268+
.or_default()
269+
.insert(stats_value, stats_function);
270+
}
271+
272+
Ok(counts)
273+
}
274+
}
275+
174276
#[async_trait]
175277
impl AttachedFunctionExecutor for StatisticsFunctionExecutor {
176278
async fn execute(
177279
&self,
178280
input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
179281
output_reader: Option<&RecordSegmentReader<'_>>,
180282
) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
181-
let mut counts: HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>> =
182-
HashMap::default();
283+
// Load existing statistics from output_reader if available
284+
let mut counts = self.load_existing_statistics(output_reader).await?;
285+
286+
// Process new input records and update counts
183287
for (hydrated_record, _index) in input_records.iter() {
184-
// This is only applicable for non-incremental statistics.
185-
// TODO(tanujnay112): Change this when we make incremental statistics work.
186-
if hydrated_record.get_operation() == MaterializedLogOperation::DeleteExisting {
187-
continue;
288+
let metadata_delta = hydrated_record.compute_metadata_delta();
289+
290+
// Decrement counts for deleted metadata
291+
for (key, old_value) in metadata_delta.metadata_to_delete {
292+
for stats_value in StatisticsValue::from_metadata_value(old_value) {
293+
let inner_map = counts.entry(key.to_string()).or_default();
294+
inner_map
295+
.entry(stats_value)
296+
.or_insert_with(|| self.0.create())
297+
.observe_delete(hydrated_record);
298+
}
299+
}
300+
301+
// Decrement counts for old values in updates
302+
for (key, (old_value, new_value)) in &metadata_delta.metadata_to_update {
303+
for stats_value in StatisticsValue::from_metadata_value(old_value) {
304+
let inner_map = counts.entry(key.to_string()).or_default();
305+
inner_map
306+
.entry(stats_value)
307+
.or_insert_with(|| self.0.create())
308+
.observe_delete(hydrated_record);
309+
}
310+
311+
for stats_value in StatisticsValue::from_metadata_value(new_value) {
312+
let inner_map = counts.entry(key.to_string()).or_default();
313+
inner_map
314+
.entry(stats_value)
315+
.or_insert_with(|| self.0.create())
316+
.observe_insert(hydrated_record);
317+
}
188318
}
189319

190-
// Use merged_metadata to get the metadata from the hydrated record
191-
let metadata = hydrated_record.merged_metadata();
192-
for (key, value) in metadata.iter() {
193-
let inner_map = counts.entry(key.clone()).or_default();
320+
// Increment counts for new values in both updates and inserts
321+
for (key, value) in metadata_delta
322+
.metadata_to_update
323+
.iter()
324+
.map(|(k, (_old, new))| (*k, *new))
325+
.chain(
326+
metadata_delta
327+
.metadata_to_insert
328+
.iter()
329+
.map(|(k, v)| (*k, *v)),
330+
)
331+
{
194332
for stats_value in StatisticsValue::from_metadata_value(value) {
333+
let inner_map = counts.entry(key.to_string()).or_default();
195334
inner_map
196335
.entry(stats_value)
197336
.or_insert_with(|| self.0.create())
198-
.observe(hydrated_record);
337+
.observe_insert(hydrated_record);
199338
}
200339
}
201340
}

rust/worker/src/execution/operators/execute_task.rs

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ use uuid::Uuid;
1616
use crate::execution::functions::{CounterFunctionFactory, StatisticsFunctionExecutor};
1717
use crate::execution::operators::materialize_logs::MaterializeLogOutput;
1818

19+
// Constants for CountAttachedFunction
20+
const COUNT_FUNCTION_OUTPUT_ID: &str = "function_output";
21+
const COUNT_METADATA_KEY: &str = "total_count";
22+
1923
/// Trait for attached function executors that process input records and produce output records.
2024
/// Implementors can read from the output collection to maintain state across executions.
2125
#[async_trait]
@@ -40,31 +44,72 @@ pub trait AttachedFunctionExecutor: Send + Sync + std::fmt::Debug {
4044
#[derive(Debug)]
4145
pub struct CountAttachedFunction;
4246

47+
impl CountAttachedFunction {
48+
/// Reads the existing count from the output reader.
49+
/// Returns 0 if no existing count is found.
50+
async fn get_existing_count(output_reader: Option<&RecordSegmentReader<'_>>) -> i64 {
51+
let Some(reader) = output_reader else {
52+
return 0;
53+
};
54+
55+
// Try to get the existing record with the function output ID
56+
let offset_id = match reader
57+
.get_offset_id_for_user_id(COUNT_FUNCTION_OUTPUT_ID)
58+
.await
59+
{
60+
Ok(Some(offset_id)) => offset_id,
61+
_ => return 0,
62+
};
63+
64+
// Get the data record for this offset id
65+
let data_record = match reader.get_data_for_offset_id(offset_id).await {
66+
Ok(Some(data_record)) => data_record,
67+
_ => return 0,
68+
};
69+
70+
// Extract total_count from metadata
71+
if let Some(metadata) = &data_record.metadata {
72+
if let Some(chroma_types::MetadataValue::Int(count)) = metadata.get(COUNT_METADATA_KEY)
73+
{
74+
return *count;
75+
}
76+
}
77+
78+
0
79+
}
80+
}
81+
4382
#[async_trait]
4483
impl AttachedFunctionExecutor for CountAttachedFunction {
4584
async fn execute(
4685
&self,
4786
input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
48-
_output_reader: Option<&RecordSegmentReader<'_>>,
87+
output_reader: Option<&RecordSegmentReader<'_>>,
4988
) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
5089
let records_count = input_records.len() as i64;
51-
let new_total_count = records_count;
90+
91+
// Read existing count from output_reader if available
92+
let existing_count = Self::get_existing_count(output_reader).await;
93+
let new_total_count = existing_count + records_count;
5294

5395
// Create output record with updated count
5496
let mut metadata = std::collections::HashMap::new();
5597
metadata.insert(
56-
"total_count".to_string(),
98+
COUNT_METADATA_KEY.to_string(),
5799
UpdateMetadataValue::Int(new_total_count),
58100
);
59101

60102
let output_record = LogRecord {
61103
log_offset: 0,
62104
record: OperationRecord {
63-
id: "function_output".to_string(),
105+
id: COUNT_FUNCTION_OUTPUT_ID.to_string(),
64106
embedding: Some(vec![0.0]),
65107
encoding: None,
66108
metadata: Some(metadata),
67-
document: Some(format!("Processed {} records", records_count)),
109+
document: Some(format!(
110+
"Last processed {} records (total: {})",
111+
records_count, new_total_count
112+
)),
68113
operation: Operation::Upsert,
69114
},
70115
};
@@ -126,6 +171,8 @@ pub struct ExecuteAttachedFunctionInput {
126171
pub output_record_segment: Segment,
127172
/// Blockfile provider for reading segments
128173
pub blockfile_provider: BlockfileProvider,
174+
175+
pub is_rebuild: bool,
129176
}
130177

131178
/// Output from the ExecuteAttachedFunction operator
@@ -190,19 +237,23 @@ impl Operator<ExecuteAttachedFunctionInput, ExecuteAttachedFunctionOutput>
190237
);
191238

192239
// Create record segment reader from the output collection's record segment
193-
let record_segment_reader = match Box::pin(RecordSegmentReader::from_segment(
194-
&input.output_record_segment,
195-
&input.blockfile_provider,
196-
))
197-
.await
198-
{
199-
Ok(reader) => Some(reader),
200-
Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => {
201-
// Output collection has no data yet - this is the first run
202-
tracing::info!("[ExecuteAttachedFunction]: Output segment uninitialized - first attached function run");
203-
None
240+
let record_segment_reader = if input.is_rebuild {
241+
None
242+
} else {
243+
match Box::pin(RecordSegmentReader::from_segment(
244+
&input.output_record_segment,
245+
&input.blockfile_provider,
246+
))
247+
.await
248+
{
249+
Ok(reader) => Some(reader),
250+
Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => {
251+
// Output collection has no data yet - this is the first run
252+
tracing::info!("[ExecuteAttachedFunction]: Output segment uninitialized - first attached function run");
253+
None
254+
}
255+
Err(e) => return Err((*e).into()),
204256
}
205-
Err(e) => return Err((*e).into()),
206257
};
207258

208259
// Process all materialized logs and hydrate the records

rust/worker/src/execution/orchestration/attached_function_orchestrator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,7 @@ impl Handler<TaskResult<CollectionAndSegments, GetCollectionAndSegmentsError>>
715715
completion_offset: collection_info.pulled_log_offset as u64, // Use the completion offset from input collection
716716
output_record_segment: message.record_segment.clone(),
717717
blockfile_provider: self.output_context.blockfile_provider.clone(),
718+
is_rebuild: self.output_context.is_rebuild,
718719
};
719720

720721
let task = wrap(

0 commit comments

Comments
 (0)