diff --git a/clients/new-js/packages/ai-embeddings/chroma-bm25/src/index.ts b/clients/new-js/packages/ai-embeddings/chroma-bm25/src/index.ts index 6c21bf1f8ef..9ae56eafa67 100644 --- a/clients/new-js/packages/ai-embeddings/chroma-bm25/src/index.ts +++ b/clients/new-js/packages/ai-embeddings/chroma-bm25/src/index.ts @@ -190,18 +190,20 @@ export class ChromaBm25EmbeddingFunction implements SparseEmbeddingFunction { } private encode(text: string): SparseVector { - const tokens = this.tokenizer.tokenize(text); + const tokenList = this.tokenizer.tokenize(text); - if (tokens.length === 0) { + if (tokenList.length === 0) { return { indices: [], values: [] }; } - const docLen = tokens.length; + const docLen = tokenList.length; const counts = new Map(); + const tokenMap = new Map(); - for (const token of tokens) { + for (const token of tokenList) { const tokenId = this.hasher.hash(token); counts.set(tokenId, (counts.get(tokenId) ?? 0) + 1); + tokenMap.set(tokenId, token); } const indices = Array.from(counts.keys()).sort((a, b) => a - b); @@ -213,8 +215,15 @@ export class ChromaBm25EmbeddingFunction implements SparseEmbeddingFunction { (1 - this.b + (this.b * docLen) / this.avgDocLength); return (tf * (this.k + 1)) / denominator; }); + const tokens = indices.map((idx) => { + const token = tokenMap.get(idx); + if (!token) { + throw new Error(`Token not found for index ${idx}`); + } + return token; + }); - return { indices, values }; + return { indices, values, tokens }; } public async generate(texts: string[]): Promise { diff --git a/clients/new-js/packages/chromadb/src/api/types.gen.ts b/clients/new-js/packages/chromadb/src/api/types.gen.ts index 1cf1c789883..961bb7c8d6f 100644 --- a/clients/new-js/packages/chromadb/src/api/types.gen.ts +++ b/clients/new-js/packages/chromadb/src/api/types.gen.ts @@ -452,6 +452,10 @@ export type SparseVector = { * Dimension indices */ indices: Array; + /** + * Tokens corresponding to each index + */ + tokens?: Array | null; /** * Values corresponding to each index */ diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index 156dd441554..bf963bd4a86 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -28,6 +28,7 @@ message Vector { message SparseVector { repeated uint32 indices = 1; repeated float values = 2; + repeated string tokens = 3; } enum SegmentScope { diff --git a/rust/chroma/src/embed/bm25.rs b/rust/chroma/src/embed/bm25.rs index 4c7a2225f49..1099afdee27 100644 --- a/rust/chroma/src/embed/bm25.rs +++ b/rust/chroma/src/embed/bm25.rs @@ -30,6 +30,8 @@ where T: Tokenizer, H: TokenHasher, { + /// Whether to store tokens in the created sparse vectors. + pub store_tokens: bool, /// Tokenizer for converting text into tokens. pub tokenizer: T, /// Hasher for converting tokens into u32 identifiers. @@ -55,6 +57,7 @@ impl BM25SparseEmbeddingFunction { /// - hasher: Murmur3 with seed 0, abs() behavior pub fn default_murmur3_abs() -> Self { Self { + store_tokens: true, tokenizer: Bm25Tokenizer::default(), hasher: Murmur3AbsHasher::default(), k: 1.2, @@ -75,26 +78,50 @@ where let doc_len = tokens.len() as f32; - let mut token_ids = Vec::with_capacity(tokens.len()); - for token in tokens { - let id = self.hasher.hash(&token); - token_ids.push(id); - } + if self.store_tokens { + let mut token_ids = Vec::with_capacity(tokens.len()); + for token in tokens { + let id = self.hasher.hash(&token); + token_ids.push((id, token)); + } + + token_ids.sort_unstable(); + + let sparse_triples = token_ids.chunk_by(|a, b| a.0 == b.0).map(|chunk| { + let id = chunk[0].0; + let tk = chunk[0].1.clone(); + let tf = chunk.len() as f32; + + // BM25 formula + let score = tf * (self.k + 1.0) + / (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len)); - token_ids.sort_unstable(); + (tk, id, score) + }); - let sparse_pairs = token_ids.chunk_by(|a, b| a == b).map(|chunk| { - let id = chunk[0]; - let tf = chunk.len() as f32; + Ok(SparseVector::from_triples(sparse_triples)) + } else { + let mut token_ids = Vec::with_capacity(tokens.len()); + for token in tokens { + let id = self.hasher.hash(&token); + token_ids.push(id); + } - // BM25 formula - let score = tf * (self.k + 1.0) - / (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len)); + token_ids.sort_unstable(); - (id, score) - }); + let sparse_pairs = token_ids.chunk_by(|a, b| *a == *b).map(|chunk| { + let id = chunk[0]; + let tf = chunk.len() as f32; - Ok(SparseVector::from_pairs(sparse_pairs)) + // BM25 formula + let score = tf * (self.k + 1.0) + / (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len)); + + (id, score) + }); + + Ok(SparseVector::from_pairs(sparse_pairs)) + } } } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 6e8a6d9b1f4..0cc47464dd7 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1683,11 +1683,13 @@ async fn collection_delete( r#where, )?; - server - .frontend - .delete(request) - .meter(metering_context_container) - .await?; + Box::pin( + server + .frontend + .delete(request) + .meter(metering_context_container), + ) + .await?; Ok(Json(DeleteCollectionRecordsResponse {})) } diff --git a/rust/frontend/tests/proptest_helpers/frontend_under_test.rs b/rust/frontend/tests/proptest_helpers/frontend_under_test.rs index d4021eada64..b584312e119 100644 --- a/rust/frontend/tests/proptest_helpers/frontend_under_test.rs +++ b/rust/frontend/tests/proptest_helpers/frontend_under_test.rs @@ -138,7 +138,9 @@ impl StateMachineTest for FrontendUnderTest { } } - state.frontend.delete(request.clone()).await.unwrap(); + Box::pin(state.frontend.delete(request.clone())) + .await + .unwrap(); } CollectionRequest::Get(mut request) => { let expected_result = { diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index d4304b6a878..75f4c0af706 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -561,7 +561,7 @@ impl Bindings { let mut frontend_clone = self.frontend.clone(); self.runtime - .block_on(async { frontend_clone.delete(request).await })?; + .block_on(async { Box::pin(frontend_clone.delete(request)).await })?; Ok(()) } diff --git a/rust/segment/src/blockfile_metadata.rs b/rust/segment/src/blockfile_metadata.rs index 440b495ee7b..d223e946904 100644 --- a/rust/segment/src/blockfile_metadata.rs +++ b/rust/segment/src/blockfile_metadata.rs @@ -2666,10 +2666,10 @@ mod test { let mut update_metadata1 = HashMap::new(); update_metadata1.insert( String::from("sparse_vec"), - UpdateMetadataValue::SparseVector(chroma_types::SparseVector::new( - vec![0, 5, 10], - vec![0.1, 0.5, 0.9], - )), + UpdateMetadataValue::SparseVector( + chroma_types::SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]) + .expect("valid sparse vector"), + ), ); update_metadata1.insert( String::from("category"), diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 73940968a20..91fefe35980 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -2419,7 +2419,9 @@ mod test { // Add unsorted sparse vector - should fail validation metadata.insert( "sparse".to_string(), - MetadataValue::SparseVector(SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2])), + MetadataValue::SparseVector( + SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(), + ), ); let result = AddCollectionRecordsRequest::try_new( @@ -2443,10 +2445,9 @@ mod test { // Add unsorted sparse vector - should fail validation metadata.insert( "sparse".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - vec![3, 1, 2], - vec![0.3, 0.1, 0.2], - )), + UpdateMetadataValue::SparseVector( + SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(), + ), ); let result = UpdateCollectionRecordsRequest::try_new( @@ -2470,10 +2471,9 @@ mod test { // Add unsorted sparse vector - should fail validation metadata.insert( "sparse".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - vec![3, 1, 2], - vec![0.3, 0.1, 0.2], - )), + UpdateMetadataValue::SparseVector( + SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(), + ), ); let result = UpsertCollectionRecordsRequest::try_new( diff --git a/rust/types/src/collection_schema.rs b/rust/types/src/collection_schema.rs index 50df6fc634f..474b6177c96 100644 --- a/rust/types/src/collection_schema.rs +++ b/rust/types/src/collection_schema.rs @@ -3083,7 +3083,7 @@ mod tests { let schema = Schema::new_default(KnnIndex::Spann); let result = schema.is_knn_key_indexing_enabled( "custom_sparse", - &QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32])), + &QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32]).unwrap()), ); let err = result.expect_err("expected indexing disabled error"); @@ -3118,7 +3118,7 @@ mod tests { let result = schema.is_knn_key_indexing_enabled( "sparse_enabled", - &QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32])), + &QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32]).unwrap()), ); assert!(result.is_ok()); diff --git a/rust/types/src/execution/operator.rs b/rust/types/src/execution/operator.rs index d5eec319bde..8c00709c63c 100644 --- a/rust/types/src/execution/operator.rs +++ b/rust/types/src/execution/operator.rs @@ -802,8 +802,8 @@ impl TryFrom for chroma_proto::KnnBatchResult { /// /// let sparse = QueryVector::Sparse(SparseVector::new( /// vec![0, 5, 10, 50], // indices -/// vec![0.5, 0.3, 0.8, 0.2] // values -/// )); +/// vec![0.5, 0.3, 0.8, 0.2], // values +/// ).unwrap()); /// ``` /// /// # Examples @@ -831,8 +831,8 @@ impl TryFrom for chroma_proto::KnnBatchResult { /// let rank = RankExpr::Knn { /// query: QueryVector::Sparse(SparseVector::new( /// vec![1, 5, 10], -/// vec![0.5, 0.3, 0.8] -/// )), +/// vec![0.5, 0.3, 0.8], +/// ).unwrap()), /// key: Key::field("sparse_embedding"), /// limit: 100, /// default: None, @@ -856,7 +856,9 @@ impl TryFrom for QueryVector { Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?)) } chroma_proto::query_vector::Vector::Sparse(sparse) => { - Ok(QueryVector::Sparse(sparse.into())) + Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| { + QueryConversionError::validation("sparse vector length mismatch") + })?)) } } } @@ -2693,7 +2695,7 @@ mod tests { #[test] fn test_query_vector_sparse_proto_conversion() { - let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]); + let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap(); let query_vector = QueryVector::Sparse(sparse.clone()); // Convert to proto @@ -2979,7 +2981,8 @@ mod tests { assert_eq!(deserialized, dense); // Test sparse vector - let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9])); + let sparse = + QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap()); let json = serde_json::to_string(&sparse).unwrap(); let deserialized: QueryVector = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized, sparse); diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index e5800c6de77..3552e0c1dae 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -14,7 +14,7 @@ use thiserror::Error; use crate::chroma_proto; #[cfg(feature = "pyo3")] -use pyo3::types::PyAnyMethods; +use pyo3::types::{PyAnyMethods, PyDictMethods}; #[cfg(feature = "testing")] use proptest::prelude::*; @@ -25,6 +25,7 @@ struct SparseVectorSerdeHelper { type_tag: Option, indices: Vec, values: Vec, + tokens: Option>, } /// Represents a sparse vector using parallel arrays for indices and values. @@ -40,6 +41,8 @@ pub struct SparseVector { pub indices: Vec, /// Values corresponding to each index pub values: Vec, + /// Tokens corresponding to each index + pub tokens: Option>, } // Custom deserializer: accept both old and new formats @@ -63,6 +66,7 @@ impl<'de> Deserialize<'de> for SparseVector { Ok(SparseVector { indices: helper.indices, values: helper.values, + tokens: helper.tokens, }) } } @@ -77,21 +81,97 @@ impl Serialize for SparseVector { type_tag: Some("sparse_vector".to_string()), indices: self.indices.clone(), values: self.values.clone(), + tokens: self.tokens.clone(), }; helper.serialize(serializer) } } +/// Length mismatch between indices, values, and tokens in a sparse vector. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SparseVectorLengthMismatch; + +impl std::fmt::Display for SparseVectorLengthMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Sparse vector indices, values, and tokens (when present) must have the same length" + ) + } +} + +impl std::error::Error for SparseVectorLengthMismatch {} + +impl ChromaError for SparseVectorLengthMismatch { + fn code(&self) -> ErrorCodes { + ErrorCodes::InvalidArgument + } +} + impl SparseVector { /// Create a new sparse vector from parallel arrays. - pub fn new(indices: Vec, values: Vec) -> Self { - Self { indices, values } + pub fn new(indices: Vec, values: Vec) -> Result { + if indices.len() != values.len() { + return Err(SparseVectorLengthMismatch); + } + Ok(Self { + indices, + values, + tokens: None, + }) + } + + /// Create a new sparse vector from parallel arrays. + pub fn new_with_tokens( + indices: Vec, + values: Vec, + tokens: Vec, + ) -> Result { + if indices.len() != values.len() { + return Err(SparseVectorLengthMismatch); + } + if tokens.len() != indices.len() { + return Err(SparseVectorLengthMismatch); + } + Ok(Self { + indices, + values, + tokens: Some(tokens), + }) } /// Create a sparse vector from an iterator of (index, value) pairs. pub fn from_pairs(pairs: impl IntoIterator) -> Self { - let (indices, values) = pairs.into_iter().unzip(); - Self { indices, values } + let mut indices = vec![]; + let mut values = vec![]; + for (index, value) in pairs { + indices.push(index); + values.push(value); + } + let tokens = None; + Self { + indices, + values, + tokens, + } + } + + /// Create a sparse vector from an iterator of (string, index, value) pairs. + pub fn from_triples(triples: impl IntoIterator) -> Self { + let mut tokens = vec![]; + let mut indices = vec![]; + let mut values = vec![]; + for (token, index, value) in triples { + tokens.push(token); + indices.push(index); + values.push(value); + } + let tokens = Some(tokens); + Self { + indices, + values, + tokens, + } } /// Iterate over (index, value) pairs. @@ -109,6 +189,13 @@ impl SparseVector { return Err(MetadataValueConversionError::SparseVectorLengthMismatch); } + // Check that tokens (if present) align with indices + if let Some(tokens) = self.tokens.as_ref() { + if tokens.len() != self.indices.len() { + return Err(MetadataValueConversionError::SparseVectorLengthMismatch); + } + } + // Check that indices are sorted in strictly ascending order (no duplicates) for i in 1..self.indices.len() { if self.indices[i] <= self.indices[i - 1] { @@ -142,9 +229,15 @@ impl PartialOrd for SparseVector { } } -impl From for SparseVector { - fn from(proto: chroma_proto::SparseVector) -> Self { - SparseVector::new(proto.indices, proto.values) +impl TryFrom for SparseVector { + type Error = SparseVectorLengthMismatch; + + fn try_from(proto: chroma_proto::SparseVector) -> Result { + if proto.tokens.is_empty() { + SparseVector::new(proto.indices, proto.values) + } else { + SparseVector::new_with_tokens(proto.indices, proto.values, proto.tokens) + } } } @@ -153,6 +246,7 @@ impl From for chroma_proto::SparseVector { chroma_proto::SparseVector { indices: sparse.indices, values: sparse.values, + tokens: sparse.tokens.unwrap_or_default(), } } } @@ -182,9 +276,11 @@ impl<'py> pyo3::IntoPyObject<'py> for SparseVector { fn into_pyobject(self, py: pyo3::Python<'py>) -> Result { use pyo3::types::PyDict; + let dict = PyDict::new(py); dict.set_item("indices", self.indices)?; dict.set_item("values", self.values)?; + dict.set_item("tokens", self.tokens)?; Ok(dict.into_any()) } } @@ -196,12 +292,34 @@ impl<'py> pyo3::FromPyObject<'py> for SparseVector { let dict = ob.downcast::()?; let indices_obj = dict.get_item("indices")?; + if indices_obj.is_none() { + return Err(pyo3::exceptions::PyKeyError::new_err( + "missing 'indices' key", + )); + } + let indices: Vec = indices_obj.unwrap().extract()?; + let values_obj = dict.get_item("values")?; + if values_obj.is_none() { + return Err(pyo3::exceptions::PyKeyError::new_err( + "missing 'values' key", + )); + } + let values: Vec = values_obj.unwrap().extract()?; - let indices: Vec = indices_obj.extract()?; - let values: Vec = values_obj.extract()?; + let tokens_obj = dict.get_item("tokens")?; + let tokens = match tokens_obj { + Some(obj) if obj.is_none() => None, + Some(obj) => Some(obj.extract::>()?), + None => None, + }; - Ok(SparseVector::new(indices, values)) + let result = match tokens { + Some(tokens) => SparseVector::new_with_tokens(indices, values, tokens), + None => SparseVector::new(indices, values), + }; + + result.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } } @@ -328,7 +446,11 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue { Ok(UpdateMetadataValue::Str(value.clone())) } Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => { - Ok(UpdateMetadataValue::SparseVector(value.clone().into())) + let sparse = value + .clone() + .try_into() + .map_err(|_| UpdateMetadataValueConversionError::InvalidValue)?; + Ok(UpdateMetadataValue::SparseVector(sparse)) } // Used to communicate that the user wants to delete this key. None => Ok(UpdateMetadataValue::None), @@ -643,7 +765,7 @@ pub enum MetadataValueConversionError { InvalidValue, #[error("Metadata key cannot start with '#' or '$': {0}")] InvalidKey(String), - #[error("Sparse vector indices and values must have the same length")] + #[error("Sparse vector indices, values, and tokens (when present) must have the same length")] SparseVectorLengthMismatch, #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")] SparseVectorIndicesNotSorted, @@ -680,7 +802,11 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue { Ok(MetadataValue::Str(value.clone())) } Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => { - Ok(MetadataValue::SparseVector(value.clone().into())) + let sparse = value + .clone() + .try_into() + .map_err(|_| MetadataValueConversionError::SparseVectorLengthMismatch)?; + Ok(MetadataValue::SparseVector(sparse)) } _ => Err(MetadataValueConversionError::InvalidValue), } @@ -1739,6 +1865,15 @@ mod tests { use super::*; + // This is needed for the tests that round trip to the python world. + #[cfg(feature = "pyo3")] + fn ensure_python_interpreter() { + static PYTHON_INIT: std::sync::Once = std::sync::Once::new(); + PYTHON_INIT.call_once(|| { + pyo3::prepare_freethreaded_python(); + }); + } + #[test] fn test_update_metadata_try_from() { let mut proto_metadata = chroma_proto::UpdateMetadata { @@ -1773,6 +1908,7 @@ mod tests { chroma_proto::SparseVector { indices: vec![0, 5, 10], values: vec![0.1, 0.5, 0.9], + tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()], }, ), ), @@ -1794,10 +1930,14 @@ mod tests { ); assert_eq!( converted_metadata.get("sparse").unwrap(), - &UpdateMetadataValue::SparseVector(SparseVector::new( - vec![0, 5, 10], - vec![0.1, 0.5, 0.9] - )) + &UpdateMetadataValue::SparseVector( + SparseVector::new_with_tokens( + vec![0, 5, 10], + vec![0.1, 0.5, 0.9], + vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),], + ) + .unwrap() + ) ); } @@ -1835,6 +1975,7 @@ mod tests { chroma_proto::SparseVector { indices: vec![1, 10, 100], values: vec![0.2, 0.4, 0.6], + tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()], }, ), ), @@ -1856,7 +1997,14 @@ mod tests { ); assert_eq!( converted_metadata.get("sparse").unwrap(), - &MetadataValue::SparseVector(SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6])) + &MetadataValue::SparseVector( + SparseVector::new_with_tokens( + vec![1, 10, 100], + vec![0.2, 0.4, 0.6], + vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),], + ) + .unwrap() + ) ); } @@ -2010,7 +2158,7 @@ mod tests { fn test_sparse_vector_new() { let indices = vec![0, 5, 10]; let values = vec![0.1, 0.5, 0.9]; - let sparse = SparseVector::new(indices.clone(), values.clone()); + let sparse = SparseVector::new(indices.clone(), values.clone()).unwrap(); assert_eq!(sparse.indices, indices); assert_eq!(sparse.values, values); } @@ -2023,19 +2171,31 @@ mod tests { assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]); } + #[test] + fn test_sparse_vector_from_triples() { + let triples = vec![ + ("foo".to_string(), 0, 0.1), + ("bar".to_string(), 5, 0.5), + ("baz".to_string(), 10, 0.9), + ]; + let sparse = SparseVector::from_triples(triples.clone()); + assert_eq!(sparse.indices, vec![0, 5, 10]); + assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]); + } + #[test] fn test_sparse_vector_iter() { - let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]); + let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap(); let collected: Vec<(u32, f32)> = sparse.iter().collect(); assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]); } #[test] fn test_sparse_vector_ordering() { - let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]); - let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]); - let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]); - let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]); + let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap(); + let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap(); + let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]).unwrap(); + let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]).unwrap(); assert_eq!(sparse1, sparse2); assert!(sparse1 < sparse3); @@ -2044,23 +2204,44 @@ mod tests { #[test] fn test_sparse_vector_proto_conversion() { - let sparse = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]); + let tokens = vec![ + "token1".to_string(), + "token2".to_string(), + "token3".to_string(), + ]; + let sparse = + SparseVector::new_with_tokens(vec![1, 10, 100], vec![0.2, 0.4, 0.6], tokens.clone()) + .unwrap(); let proto: chroma_proto::SparseVector = sparse.clone().into(); assert_eq!(proto.indices, vec![1, 10, 100]); assert_eq!(proto.values, vec![0.2, 0.4, 0.6]); + assert_eq!(proto.tokens, tokens.clone()); - let converted: SparseVector = proto.into(); + let converted: SparseVector = proto.try_into().unwrap(); assert_eq!(converted, sparse); + assert_eq!(converted.tokens, Some(tokens)); + } + + #[test] + fn test_sparse_vector_proto_conversion_empty_tokens() { + let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap(); + let proto: chroma_proto::SparseVector = sparse.clone().into(); + assert_eq!(proto.indices, vec![0, 5, 10]); + assert_eq!(proto.values, vec![0.1, 0.5, 0.9]); + assert_eq!(proto.tokens, Vec::::new()); + + let converted: SparseVector = proto.try_into().unwrap(); + assert_eq!(converted, sparse); + assert_eq!(converted.tokens, None); } #[test] fn test_sparse_vector_logical_size() { let metadata = Metadata::from([( "sparse".to_string(), - MetadataValue::SparseVector(SparseVector::new( - vec![0, 1, 2, 3, 4], - vec![0.1, 0.2, 0.3, 0.4, 0.5], - )), + MetadataValue::SparseVector( + SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(), + ), )]); let size = logical_size_of_metadata(&metadata); @@ -2072,20 +2253,27 @@ mod tests { #[test] fn test_sparse_vector_validation() { // Valid sparse vector - let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]); + let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap(); assert!(sparse.validate().is_ok()); // Length mismatch let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]); - let result = sparse.validate(); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - MetadataValueConversionError::SparseVectorLengthMismatch - )); + assert!(sparse.is_err()); + let result = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]) + .unwrap() + .validate(); + assert!(result.is_ok()); + + // Tokens length mismatch with indices/values + let sparse = SparseVector::new_with_tokens( + vec![1, 2, 3], + vec![0.1, 0.2, 0.3], + vec!["a".to_string(), "b".to_string()], + ); + assert!(sparse.is_err()); // Unsorted indices (descending order) - let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]); + let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(); let result = sparse.validate(); assert!(result.is_err()); assert!(matches!( @@ -2094,7 +2282,7 @@ mod tests { )); // Duplicate indices (not strictly ascending) - let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]); + let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]).unwrap(); let result = sparse.validate(); assert!(result.is_err()); assert!(matches!( @@ -2103,7 +2291,7 @@ mod tests { )); // Descending at one point - let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]); + let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]).unwrap(); let result = sparse.validate(); assert!(result.is_err()); assert!(matches!( @@ -2153,7 +2341,7 @@ mod tests { #[test] fn test_sparse_vector_serialize_always_has_type() { // Serialization should always include #type field - let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]); + let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]).unwrap(); let json = serde_json::to_value(&sv).unwrap(); assert_eq!(json["#type"], "sparse_vector"); @@ -2164,7 +2352,7 @@ mod tests { #[test] fn test_sparse_vector_roundtrip_with_type() { // Test that serialize -> deserialize preserves the data - let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]); + let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]).unwrap(); let json = serde_json::to_string(&original).unwrap(); // Verify the serialized JSON contains #type @@ -2198,6 +2386,186 @@ mod tests { assert_eq!(sv.values, vec![1.0, 2.0]); } + #[test] + fn test_sparse_vector_tokens_roundtrip_old_to_new() { + // Old format without tokens field should deserialize with tokens=None + let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#; + let sv: SparseVector = serde_json::from_str(json).unwrap(); + assert_eq!(sv.indices, vec![0, 1, 2]); + assert_eq!(sv.values, vec![1.0, 2.0, 3.0]); + assert_eq!(sv.tokens, None); + + // Serialize and verify it includes #type but no tokens field when None + let serialized = serde_json::to_value(&sv).unwrap(); + assert_eq!(serialized["#type"], "sparse_vector"); + assert_eq!(serialized["indices"], serde_json::json!([0, 1, 2])); + assert_eq!(serialized["values"], serde_json::json!([1.0, 2.0, 3.0])); + assert_eq!(serialized["tokens"], serde_json::Value::Null); + } + + #[test] + fn test_sparse_vector_tokens_roundtrip_new_to_new() { + // New format with tokens field + let sv_with_tokens = SparseVector::new_with_tokens( + vec![0, 1, 2], + vec![1.0, 2.0, 3.0], + vec!["foo".to_string(), "bar".to_string(), "baz".to_string()], + ) + .unwrap(); + + // Serialize + let serialized = serde_json::to_string(&sv_with_tokens).unwrap(); + assert!(serialized.contains("\"#type\":\"sparse_vector\"")); + assert!(serialized.contains("\"tokens\"")); + + // Deserialize and verify tokens are preserved + let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.indices, vec![0, 1, 2]); + assert_eq!(deserialized.values, vec![1.0, 2.0, 3.0]); + assert_eq!( + deserialized.tokens, + Some(vec![ + "foo".to_string(), + "bar".to_string(), + "baz".to_string() + ]) + ); + } + + #[test] + fn test_sparse_vector_tokens_deserialize_with_tokens_field() { + // Test deserializing JSON that explicitly includes tokens field + let json = r##"{"#type": "sparse_vector", "indices": [5, 10], "values": [0.5, 1.0], "tokens": ["token1", "token2"]}"##; + let sv: SparseVector = serde_json::from_str(json).unwrap(); + assert_eq!(sv.indices, vec![5, 10]); + assert_eq!(sv.values, vec![0.5, 1.0]); + assert_eq!( + sv.tokens, + Some(vec!["token1".to_string(), "token2".to_string()]) + ); + } + + #[test] + fn test_sparse_vector_tokens_backward_compatibility() { + // Verify old format (no tokens, no #type) deserializes correctly + let old_json = r#"{"indices": [1, 2], "values": [0.1, 0.2]}"#; + let old_sv: SparseVector = serde_json::from_str(old_json).unwrap(); + + // Verify new format (with #type, with tokens) deserializes correctly + let new_json = r##"{"#type": "sparse_vector", "indices": [1, 2], "values": [0.1, 0.2], "tokens": ["a", "b"]}"##; + let new_sv: SparseVector = serde_json::from_str(new_json).unwrap(); + + // Both should have same indices and values + assert_eq!(old_sv.indices, new_sv.indices); + assert_eq!(old_sv.values, new_sv.values); + + // Old should have None tokens, new should have Some tokens + assert_eq!(old_sv.tokens, None); + assert_eq!(new_sv.tokens, Some(vec!["a".to_string(), "b".to_string()])); + } + + #[test] + fn test_sparse_vector_from_triples_preserves_tokens() { + let triples = vec![ + ("apple".to_string(), 10, 0.5), + ("banana".to_string(), 20, 0.7), + ("cherry".to_string(), 30, 0.9), + ]; + let sv = SparseVector::from_triples(triples.clone()); + + assert_eq!(sv.indices, vec![10, 20, 30]); + assert_eq!(sv.values, vec![0.5, 0.7, 0.9]); + assert_eq!( + sv.tokens, + Some(vec![ + "apple".to_string(), + "banana".to_string(), + "cherry".to_string() + ]) + ); + + // Roundtrip through serialization + let serialized = serde_json::to_string(&sv).unwrap(); + let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(deserialized.indices, sv.indices); + assert_eq!(deserialized.values, sv.values); + assert_eq!(deserialized.tokens, sv.tokens); + } + + #[cfg(feature = "pyo3")] + #[test] + fn test_sparse_vector_pyo3_roundtrip_with_tokens() { + ensure_python_interpreter(); + + pyo3::Python::with_gil(|py| { + use pyo3::types::PyDict; + use pyo3::IntoPyObject; + + let dict_in = PyDict::new(py); + dict_in.set_item("indices", vec![0u32, 1, 2]).unwrap(); + dict_in + .set_item("values", vec![0.1f32, 0.2f32, 0.3f32]) + .unwrap(); + dict_in + .set_item("tokens", vec!["foo", "bar", "baz"]) + .unwrap(); + + let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap(); + assert_eq!(sparse.indices, vec![0, 1, 2]); + assert_eq!(sparse.values, vec![0.1, 0.2, 0.3]); + assert_eq!( + sparse.tokens, + Some(vec![ + "foo".to_string(), + "bar".to_string(), + "baz".to_string() + ]) + ); + + let py_obj = sparse.clone().into_pyobject(py).unwrap(); + let dict_out = py_obj.downcast::().unwrap(); + let tokens_obj = dict_out.get_item("tokens").unwrap(); + let tokens: Vec = tokens_obj + .expect("expected tokens key in Python dict") + .extract() + .unwrap(); + assert_eq!( + tokens, + vec!["foo".to_string(), "bar".to_string(), "baz".to_string()] + ); + }); + } + + #[cfg(feature = "pyo3")] + #[test] + fn test_sparse_vector_pyo3_roundtrip_without_tokens() { + ensure_python_interpreter(); + + pyo3::Python::with_gil(|py| { + use pyo3::types::PyDict; + use pyo3::IntoPyObject; + + let dict_in = PyDict::new(py); + dict_in.set_item("indices", vec![5u32]).unwrap(); + dict_in.set_item("values", vec![1.5f32]).unwrap(); + + let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap(); + assert_eq!(sparse.indices, vec![5]); + assert_eq!(sparse.values, vec![1.5]); + assert!(sparse.tokens.is_none()); + + let py_obj = sparse.into_pyobject(py).unwrap(); + let dict_out = py_obj.downcast::().unwrap(); + let tokens_obj = dict_out.get_item("tokens").unwrap(); + let tokens_value = tokens_obj.expect("expected tokens key in Python dict"); + assert!( + tokens_value.is_none(), + "expected tokens value in Python dict to be None" + ); + }); + } + #[test] fn test_simplifies_identities() { let all: Where = true.into(); diff --git a/rust/types/src/validators.rs b/rust/types/src/validators.rs index 16e93a1dcf0..919b374106e 100644 --- a/rust/types/src/validators.rs +++ b/rust/types/src/validators.rs @@ -328,7 +328,7 @@ mod tests { // Valid metadata let mut metadata = Metadata::new(); metadata.insert("valid_key".to_string(), MetadataValue::Int(42)); - let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]); + let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap(); metadata.insert("embedding".to_string(), MetadataValue::SparseVector(sparse)); assert!(validate_metadata(&metadata).is_ok()); @@ -354,7 +354,11 @@ mod tests { // Invalid sparse vector (length mismatch) let mut metadata = Metadata::new(); - let invalid_sparse = SparseVector::new(vec![1, 2], vec![0.1, 0.2, 0.3]); + let invalid_sparse = SparseVector { + indices: vec![1, 2], + values: vec![0.1, 0.2, 0.3], + tokens: None, + }; metadata.insert( "embedding".to_string(), MetadataValue::SparseVector(invalid_sparse), diff --git a/rust/worker/src/execution/functions/statistics.rs b/rust/worker/src/execution/functions/statistics.rs index 5d5a2ea4bdb..f1b44ac2dc4 100644 --- a/rust/worker/src/execution/functions/statistics.rs +++ b/rust/worker/src/execution/functions/statistics.rs @@ -436,10 +436,10 @@ mod tests { ), ( "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - vec![1, 3], - vec![0.25, 0.75], - )), + UpdateMetadataValue::SparseVector( + SparseVector::new(vec![1, 3], vec![0.25, 0.75]) + .expect("sparse vector creation should succeed"), + ), ), ]), ); @@ -455,7 +455,10 @@ mod tests { ), ( "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new(vec![3], vec![0.5])), + UpdateMetadataValue::SparseVector( + SparseVector::new(vec![3], vec![0.5]) + .expect("sparse vector creation should succeed"), + ), ), ]), ); @@ -627,10 +630,10 @@ mod tests { "sparse-empty", HashMap::from([( "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - Vec::::new(), - Vec::::new(), - )), + UpdateMetadataValue::SparseVector( + SparseVector::new(Vec::::new(), Vec::::new()) + .expect("valid sparse vector"), + ), )]), ); diff --git a/rust/worker/src/execution/operators/idf.rs b/rust/worker/src/execution/operators/idf.rs index cff9a5cf3a6..d8998c8665e 100644 --- a/rust/worker/src/execution/operators/idf.rs +++ b/rust/worker/src/execution/operators/idf.rs @@ -56,6 +56,8 @@ pub enum IdfError { RecordReader(#[from] RecordSegmentReaderCreationError), #[error("Error using sparse reader: {0}")] SparseReader(#[from] SparseReaderError), + #[error("Query tokens length ({tokens}) does not match query indices length ({indices})")] + TokenLengthMismatch { tokens: usize, indices: usize }, } impl ChromaError for IdfError { @@ -66,6 +68,7 @@ impl ChromaError for IdfError { IdfError::MetadataReader(err) => err.code(), IdfError::RecordReader(err) => err.code(), IdfError::SparseReader(err) => err.code(), + IdfError::TokenLengthMismatch { .. } => chroma_error::ErrorCodes::InvalidArgument, } } } @@ -164,13 +167,33 @@ impl Operator for Idf { }; } - let scaled_query = SparseVector::from_pairs(self.query.iter().map(|(index, value)| { - let nt = nts.get(&index).cloned().unwrap_or_default() as f32; - let scale = ((n as f32 - nt + 0.5) / (nt + 0.5)).ln_1p(); - (index, scale * value) - })); + fn scale(n: f32, nt: f32) -> f32 { + ((n - nt + 0.5) / (nt + 0.5)).ln_1p() + } - Ok(IdfOutput { scaled_query }) + if let Some(tokens) = self.query.tokens.as_ref() { + if tokens.len() != self.query.indices.len() { + return Err(IdfError::TokenLengthMismatch { + tokens: tokens.len(), + indices: self.query.indices.len(), + }); + } + let scaled_query = SparseVector::from_triples(self.query.iter().enumerate().map( + |(token_position, (index, value))| { + let nt = nts.get(&index).cloned().unwrap_or_default() as f32; + let scale = scale(n as f32, nt); + (tokens[token_position].clone(), index, scale * value) + }, + )); + Ok(IdfOutput { scaled_query }) + } else { + let scaled_query = SparseVector::from_pairs(self.query.iter().map(|(index, value)| { + let nt = nts.get(&index).cloned().unwrap_or_default() as f32; + let scale = scale(n as f32, nt); + (index, scale * value) + })); + Ok(IdfOutput { scaled_query }) + } } } @@ -201,7 +224,11 @@ mod tests { metadata.insert( "sparse_embedding".to_string(), - UpdateMetadataValue::SparseVector(SparseVector { indices, values }), + UpdateMetadataValue::SparseVector(SparseVector { + indices, + values, + tokens: None, + }), ); // Add dummy embedding for materialization (required by TestDistributedSegment) @@ -258,6 +285,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 1, 2, 3, 4], values: vec![1.0, 1.0, 1.0, 1.0, 1.0], + tokens: None, }; let idf_operator = Idf { @@ -328,6 +356,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 1, 2], values: vec![1.0, 1.0, 1.0], + tokens: None, }; let idf_operator = Idf { @@ -383,6 +412,7 @@ mod tests { UpdateMetadataValue::SparseVector(SparseVector { indices: vec![1, 2], // Now has terms 1 and 2 instead values: vec![2.0, 3.0], + tokens: None, }), )])), document: None, @@ -397,6 +427,7 @@ mod tests { UpdateMetadataValue::SparseVector(SparseVector { indices: vec![0], // Now has term 0 instead values: vec![1.5], + tokens: None, }), )])), document: None, @@ -409,6 +440,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 1, 2, 3], values: vec![1.0, 1.0, 1.0, 1.0], + tokens: None, }; let idf_operator = Idf { @@ -456,6 +488,7 @@ mod tests { UpdateMetadataValue::SparseVector(SparseVector { indices: vec![0, 5], // New term 5 values: vec![1.0, 2.0], + tokens: None, }), )])), document: Some("Document 11".to_string()), @@ -470,6 +503,7 @@ mod tests { UpdateMetadataValue::SparseVector(SparseVector { indices: vec![5], // Another doc with term 5 values: vec![3.0], + tokens: None, }), )])), document: Some("Document 12".to_string()), @@ -482,6 +516,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 5], values: vec![1.0, 1.0], + tokens: None, }; let idf_operator = Idf { @@ -524,6 +559,7 @@ mod tests { let query_vector = SparseVector { indices: vec![], values: vec![], + tokens: None, }; let idf_operator = Idf { @@ -549,6 +585,7 @@ mod tests { let query_vector = SparseVector { indices: vec![99, 100], values: vec![1.0, 2.0], + tokens: None, }; let idf_operator = Idf { @@ -568,4 +605,29 @@ mod tests { assert!((scaled.values[0] - 3.091).abs() < 0.01); assert!((scaled.values[1] - 6.182).abs() < 0.01); // 2.0 * 3.091 } + + #[tokio::test] + async fn test_idf_tokens_length_mismatch_returns_error() { + let (_test_segment, input) = Box::pin(setup_idf_input(1, vec![])).await; + + let query_vector = SparseVector { + indices: vec![0, 1], + values: vec![1.0, 1.0], + tokens: Some(vec!["only_one_token".to_string()]), + }; + + let idf_operator = Idf { + query: query_vector, + key: "sparse_embedding".to_string(), + }; + + let result = idf_operator.run(&input).await; + assert!(matches!( + result, + Err(IdfError::TokenLengthMismatch { + tokens: 1, + indices: 2 + }) + )); + } } diff --git a/rust/worker/src/execution/operators/rank.rs b/rust/worker/src/execution/operators/rank.rs index beb9d6783b9..e4d209430d6 100644 --- a/rust/worker/src/execution/operators/rank.rs +++ b/rust/worker/src/execution/operators/rank.rs @@ -261,6 +261,7 @@ mod tests { query: chroma_types::operator::QueryVector::Sparse(chroma_types::SparseVector { indices: vec![0], values: vec![1.0], + tokens: None, }), key: Key::field("sparse"), limit: 2, diff --git a/rust/worker/src/execution/operators/sparse_log_knn.rs b/rust/worker/src/execution/operators/sparse_log_knn.rs index 24f0ae26b3a..3d1c5c5aa81 100644 --- a/rust/worker/src/execution/operators/sparse_log_knn.rs +++ b/rust/worker/src/execution/operators/sparse_log_knn.rs @@ -149,6 +149,7 @@ mod tests { 0.2 * offset as f32, 0.3 * offset as f32, ], + tokens: None, }), ); @@ -194,6 +195,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 2, 4], values: vec![1.0, 1.0, 1.0], + tokens: None, }; let sparse_knn_operator = SparseLogKnn { @@ -237,6 +239,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 2, 4], values: vec![1.0, 1.0, 1.0], + tokens: None, }; let sparse_knn_operator = SparseLogKnn { @@ -268,6 +271,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 1, 3], // Only index 0 overlaps with generated vectors values: vec![2.0, 1.0, 1.0], + tokens: None, }; let sparse_knn_operator = SparseLogKnn { @@ -302,6 +306,7 @@ mod tests { let query_vector = SparseVector { indices: vec![0, 2, 4], values: vec![1.0, 1.0, 1.0], + tokens: None, }; let sparse_knn_operator = SparseLogKnn { @@ -333,6 +338,7 @@ mod tests { let query_vector = SparseVector { indices: vec![1, 3, 5], // Generated vectors have indices [0, 2, 4] values: vec![1.0, 1.0, 1.0], + tokens: None, }; let sparse_knn_operator = SparseLogKnn {