Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions clients/new-js/packages/ai-embeddings/chroma-bm25/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,20 @@ export class ChromaBm25EmbeddingFunction implements SparseEmbeddingFunction {
}

private encode(text: string): SparseVector {
const tokens = this.tokenizer.tokenize(text);
const tokenList = this.tokenizer.tokenize(text);

if (tokens.length === 0) {
if (tokenList.length === 0) {
return { indices: [], values: [] };
}

const docLen = tokens.length;
const docLen = tokenList.length;
const counts = new Map<number, number>();
const tokenMap = new Map<number, string>();

for (const token of tokens) {
for (const token of tokenList) {
const tokenId = this.hasher.hash(token);
counts.set(tokenId, (counts.get(tokenId) ?? 0) + 1);
tokenMap.set(tokenId, token);
}

const indices = Array.from(counts.keys()).sort((a, b) => a - b);
Expand All @@ -213,8 +215,15 @@ export class ChromaBm25EmbeddingFunction implements SparseEmbeddingFunction {
(1 - this.b + (this.b * docLen) / this.avgDocLength);
return (tf * (this.k + 1)) / denominator;
});
const tokens = indices.map((idx) => {
const token = tokenMap.get(idx);
if (!token) {
throw new Error(`Token not found for index ${idx}`);
}
return token;
});

return { indices, values };
return { indices, values, tokens };
}

public async generate(texts: string[]): Promise<SparseVector[]> {
Expand Down
4 changes: 4 additions & 0 deletions clients/new-js/packages/chromadb/src/api/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ export type SparseVector = {
* Dimension indices
*/
indices: Array<number>;
/**
* Tokens corresponding to each index
*/
tokens?: Array<string> | null;
/**
* Values corresponding to each index
*/
Expand Down
1 change: 1 addition & 0 deletions idl/chromadb/proto/chroma.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ message Vector {
// Sparse vector representation: parallel arrays where values[i] is the
// weight at dimension indices[i]. tokens, when present, carries the source
// token for each index (parallel to indices) — see the client-side type
// definition, which marks it optional.
message SparseVector {
  // Dimension indices.
  repeated uint32 indices = 1;
  // Values corresponding to each index.
  repeated float values = 2;
  // Tokens corresponding to each index; may be empty when the
  // producer does not store tokens.
  repeated string tokens = 3;
}

enum SegmentScope {
Expand Down
57 changes: 42 additions & 15 deletions rust/chroma/src/embed/bm25.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ where
T: Tokenizer,
H: TokenHasher,
{
/// Whether to store tokens in the created sparse vectors.
pub store_tokens: bool,
/// Tokenizer for converting text into tokens.
pub tokenizer: T,
/// Hasher for converting tokens into u32 identifiers.
Expand All @@ -55,6 +57,7 @@ impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
/// - hasher: Murmur3 with seed 0, abs() behavior
pub fn default_murmur3_abs() -> Self {
Self {
store_tokens: true,
tokenizer: Bm25Tokenizer::default(),
hasher: Murmur3AbsHasher::default(),
k: 1.2,
Expand All @@ -75,26 +78,50 @@ where

let doc_len = tokens.len() as f32;

let mut token_ids = Vec::with_capacity(tokens.len());
for token in tokens {
let id = self.hasher.hash(&token);
token_ids.push(id);
}
if self.store_tokens {
let mut token_ids = Vec::with_capacity(tokens.len());
for token in tokens {
let id = self.hasher.hash(&token);
token_ids.push((id, token));
}

token_ids.sort_unstable();

let sparse_triples = token_ids.chunk_by(|a, b| a.0 == b.0).map(|chunk| {
let id = chunk[0].0;
let tk = chunk[0].1.clone();
Comment on lines +90 to +92
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[CriticalError]

Logic error in token storage: When store_tokens is true, you're storing chunk[0].1.clone() which is the first token, but for duplicate token IDs after hashing, this loses information about which specific token variant was used. If "running" and "run" both hash to ID 42, you'll only store the first one encountered.

Consider if this is the intended behavior or if you need to store a representative token (e.g., most frequent variant).

Context for Agents
**Logic error in token storage**: When `store_tokens` is true, you're storing `chunk[0].1.clone()` which is the first token, but for duplicate token IDs after hashing, this loses information about which specific token variant was used. If "running" and "run" both hash to ID 42, you'll only store the first one encountered.

Consider if this is the intended behavior or if you need to store a representative token (e.g., most frequent variant).

File: rust/chroma/src/embed/bm25.rs
Line: 92

let tf = chunk.len() as f32;

// BM25 formula
let score = tf * (self.k + 1.0)
/ (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len));

token_ids.sort_unstable();
(tk, id, score)
});

let sparse_pairs = token_ids.chunk_by(|a, b| a == b).map(|chunk| {
let id = chunk[0];
let tf = chunk.len() as f32;
Ok(SparseVector::from_triples(sparse_triples))
} else {
let mut token_ids = Vec::with_capacity(tokens.len());
for token in tokens {
let id = self.hasher.hash(&token);
token_ids.push(id);
}

// BM25 formula
let score = tf * (self.k + 1.0)
/ (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len));
token_ids.sort_unstable();

(id, score)
});
let sparse_pairs = token_ids.chunk_by(|a, b| *a == *b).map(|chunk| {
let id = chunk[0];
let tf = chunk.len() as f32;

Ok(SparseVector::from_pairs(sparse_pairs))
// BM25 formula
let score = tf * (self.k + 1.0)
/ (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len));

(id, score)
});

Ok(SparseVector::from_pairs(sparse_pairs))
}
}
}

Expand Down
12 changes: 7 additions & 5 deletions rust/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1683,11 +1683,13 @@ async fn collection_delete(
r#where,
)?;

server
.frontend
.delete(request)
.meter(metering_context_container)
.await?;
Box::pin(
server
.frontend
.delete(request)
.meter(metering_context_container),
)
.await?;

Ok(Json(DeleteCollectionRecordsResponse {}))
}
Expand Down
4 changes: 3 additions & 1 deletion rust/frontend/tests/proptest_helpers/frontend_under_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ impl StateMachineTest for FrontendUnderTest {
}
}

state.frontend.delete(request.clone()).await.unwrap();
Box::pin(state.frontend.delete(request.clone()))
.await
.unwrap();
}
CollectionRequest::Get(mut request) => {
let expected_result = {
Expand Down
2 changes: 1 addition & 1 deletion rust/python_bindings/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ impl Bindings {

let mut frontend_clone = self.frontend.clone();
self.runtime
.block_on(async { frontend_clone.delete(request).await })?;
.block_on(async { Box::pin(frontend_clone.delete(request)).await })?;
Ok(())
}

Expand Down
8 changes: 4 additions & 4 deletions rust/segment/src/blockfile_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2666,10 +2666,10 @@ mod test {
let mut update_metadata1 = HashMap::new();
update_metadata1.insert(
String::from("sparse_vec"),
UpdateMetadataValue::SparseVector(chroma_types::SparseVector::new(
vec![0, 5, 10],
vec![0.1, 0.5, 0.9],
)),
UpdateMetadataValue::SparseVector(
chroma_types::SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9])
.expect("valid sparse vector"),
),
);
update_metadata1.insert(
String::from("category"),
Expand Down
18 changes: 9 additions & 9 deletions rust/types/src/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2419,7 +2419,9 @@ mod test {
// Add unsorted sparse vector - should fail validation
metadata.insert(
"sparse".to_string(),
MetadataValue::SparseVector(SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2])),
MetadataValue::SparseVector(
SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(),
),
);

let result = AddCollectionRecordsRequest::try_new(
Expand All @@ -2443,10 +2445,9 @@ mod test {
// Add unsorted sparse vector - should fail validation
metadata.insert(
"sparse".to_string(),
UpdateMetadataValue::SparseVector(SparseVector::new(
vec![3, 1, 2],
vec![0.3, 0.1, 0.2],
)),
UpdateMetadataValue::SparseVector(
SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(),
),
);

let result = UpdateCollectionRecordsRequest::try_new(
Expand All @@ -2470,10 +2471,9 @@ mod test {
// Add unsorted sparse vector - should fail validation
metadata.insert(
"sparse".to_string(),
UpdateMetadataValue::SparseVector(SparseVector::new(
vec![3, 1, 2],
vec![0.3, 0.1, 0.2],
)),
UpdateMetadataValue::SparseVector(
SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap(),
),
);

let result = UpsertCollectionRecordsRequest::try_new(
Expand Down
4 changes: 2 additions & 2 deletions rust/types/src/collection_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3083,7 +3083,7 @@ mod tests {
let schema = Schema::new_default(KnnIndex::Spann);
let result = schema.is_knn_key_indexing_enabled(
"custom_sparse",
&QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32])),
&QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32]).unwrap()),
);

let err = result.expect_err("expected indexing disabled error");
Expand Down Expand Up @@ -3118,7 +3118,7 @@ mod tests {

let result = schema.is_knn_key_indexing_enabled(
"sparse_enabled",
&QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32])),
&QueryVector::Sparse(SparseVector::new(vec![0_u32], vec![1.0_f32]).unwrap()),
);

assert!(result.is_ok());
Expand Down
17 changes: 10 additions & 7 deletions rust/types/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,8 @@ impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
///
/// let sparse = QueryVector::Sparse(SparseVector::new(
/// vec![0, 5, 10, 50], // indices
/// vec![0.5, 0.3, 0.8, 0.2] // values
/// ));
/// vec![0.5, 0.3, 0.8, 0.2], // values
/// ).unwrap());
/// ```
///
/// # Examples
Expand Down Expand Up @@ -831,8 +831,8 @@ impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
/// let rank = RankExpr::Knn {
/// query: QueryVector::Sparse(SparseVector::new(
/// vec![1, 5, 10],
/// vec![0.5, 0.3, 0.8]
/// )),
/// vec![0.5, 0.3, 0.8],
/// ).unwrap()),
/// key: Key::field("sparse_embedding"),
/// limit: 100,
/// default: None,
Expand All @@ -856,7 +856,9 @@ impl TryFrom<chroma_proto::QueryVector> for QueryVector {
Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
}
chroma_proto::query_vector::Vector::Sparse(sparse) => {
Ok(QueryVector::Sparse(sparse.into()))
Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
QueryConversionError::validation("sparse vector length mismatch")
})?))
}
}
}
Expand Down Expand Up @@ -2693,7 +2695,7 @@ mod tests {

#[test]
fn test_query_vector_sparse_proto_conversion() {
let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
let query_vector = QueryVector::Sparse(sparse.clone());

// Convert to proto
Expand Down Expand Up @@ -2979,7 +2981,8 @@ mod tests {
assert_eq!(deserialized, dense);

// Test sparse vector
let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]));
let sparse =
QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
let json = serde_json::to_string(&sparse).unwrap();
let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized, sparse);
Expand Down
Loading
Loading