diff --git a/src/searcher.rs b/src/searcher.rs index 2fa61396..5d3820ea 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -9,6 +9,7 @@ use tantivy as tv; use tantivy::aggregation::AggregationCollector; use tantivy::collector::{Count, MultiCollector, TopDocs}; use tantivy::TantivyDocument; + // Bring the trait into scope. This is required for the `to_named_doc` method. // However, tantivy-py declares its own `Document` class, so we need to avoid // introduce the `Document` trait into the namespace. @@ -151,11 +152,20 @@ impl Searcher { /// to be returned. /// order (Order, optional): The order in which the results /// should be sorted. If not specified, defaults to descending. + /// weight_by_field (Field, optional): A schema field that the results + /// should be weighted by. The field must be declared as a fast + /// field when building the schema. Note, this only works for + /// f64, i64 and u64 fields. The given field value is first + /// transformed using the formula `log2(2.0 + value)` and then + /// multiplied with the original score. This means that a weight field + /// value of 0.0 results in no change to the original score. + /// If the weight value is negative, it is treated as 0.0. /// /// Returns `SearchResult` object. /// /// Raises a ValueError if there was an error with the search. 
- #[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))] + #[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc, + weight_by_field = None))] #[allow(clippy::too_many_arguments)] fn search( &self, @@ -166,6 +176,7 @@ impl Searcher { order_by_field: Option<&str>, offset: usize, order: Order, + weight_by_field: Option<&str>, ) -> PyResult { py.detach(move || { let mut multicollector = MultiCollector::new(); @@ -177,10 +188,103 @@ impl Searcher { }; let (mut multifruit, hits) = { - if let Some(order_by) = order_by_field { - let collector = TopDocs::with_limit(limit) - .and_offset(offset) - .order_by_u64_field(order_by, order.into()); + let collector = TopDocs::with_limit(limit).and_offset(offset); + if let Some(weight_by_field) = weight_by_field { + let weight_by_field = weight_by_field.to_string(); + + // Get field type from schema + let schema = self.inner.schema(); + let field = crate::get_field(&schema, &weight_by_field) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + let field_entry = schema.get_field_entry(field); + let field_type = field_entry.field_type().value_type(); + + if !field_entry.is_fast() { + return Err(PyValueError::new_err(format!( + "Field '{}' is not a fast field. The field must be declared with fast=True in the schema.", + weight_by_field + ))); + } + + // Check if field type is supported + if !matches!(field_type, tv::schema::Type::F64 | tv::schema::Type::I64 | tv::schema::Type::U64) { + return Err(PyValueError::new_err(format!( + "Unsupported field type for weighting: {:?}. Only f64, i64, and u64 fastfields are supported.", + field_type + ))); + } + + let collector = collector.tweak_score( + move |segment_reader: &tv::SegmentReader| { + // Create all three readers upfront. Only one will succeed based on + // the actual field type, but we must create all three because: + // 1. 
Rust closures have a single concrete type - we can't return + // different closure types from different match arms + // 2. The alternative (Box) adds heap allocation per segment + // and virtual dispatch overhead per document + // 3. This approach enables monomorphization: the inner closure has + // a concrete type, allowing LLVM to inline get_val() calls + let f64_reader = segment_reader + .fast_fields() + .f64(&weight_by_field) + .ok() + .map(|r| r.first_or_default_col(0.0)); + let i64_reader = segment_reader + .fast_fields() + .i64(&weight_by_field) + .ok() + .map(|r| r.first_or_default_col(0)); + let u64_reader = segment_reader + .fast_fields() + .u64(&weight_by_field) + .ok() + .map(|r| r.first_or_default_col(0)); + + move |doc: tv::DocId, original_score: tv::Score| { + let value: f64 = match field_type { + // Runtime type dispatch is required here even though field_type + // was checked earlier because: + // 1. field_type is moved into this closure and can't be matched + // at compile time to select which reader to use + // 2. All three readers must exist at this point for the closure + // to have a single concrete type + // + // Use map_or(0.0, ...) instead of unwrap() because segments + // created before a schema change may lack this fast field. 
+ // Default value 0.0 results in neutral scoring: + // boost = log2(2.0 + 0.0) = 1.0, so score * 1.0 = score + tv::schema::Type::F64 => f64_reader.as_ref().map_or(0.0, |r| r.get_val(doc)), + tv::schema::Type::I64 => i64_reader.as_ref().map_or(0.0, |r| r.get_val(doc) as f64), + tv::schema::Type::U64 => u64_reader.as_ref().map_or(0.0, |r| r.get_val(doc) as f64), + _ => unreachable!(), + }; + let value = value.max(0.0); // Negative values are not allowed + let value_boost_score = ((2f64 + value) as tv::Score).log2(); + value_boost_score * original_score + } + }, + ); + let top_docs_handle = + multicollector.add_collector(collector); + let ret = self.inner.search(query.get(), &multicollector); + match ret { + Ok(mut r) => { + let top_docs = top_docs_handle.extract(&mut r); + let result: Vec<(Fruit, DocAddress)> = top_docs + .iter() + .map(|(f, d)| { + (Fruit::Score(*f), DocAddress::from(d)) + }) + .collect(); + (r, result) + } + Err(e) => { + return Err(PyValueError::new_err(e.to_string())) + } + } + } else if let Some(order_by) = order_by_field { + let collector = + collector.order_by_u64_field(order_by, order.into()); let top_docs_handle = multicollector.add_collector(collector); let ret = self.inner.search(query.get(), &multicollector); @@ -201,8 +305,6 @@ impl Searcher { } } } else { - let collector = - TopDocs::with_limit(limit).and_offset(offset); let top_docs_handle = multicollector.add_collector(collector); let ret = self.inner.search(query.get(), &multicollector); diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 6d4c9b87..eee0bc92 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -372,6 +372,7 @@ class Searcher: order_by_field: Optional[str] = None, offset: int = 0, order: Order = Order.Desc, + weight_by_field: str | None = None, ) -> SearchResult: pass diff --git a/tests/test_document_scoring.py b/tests/test_document_scoring.py new file mode 100644 index 00000000..274f9c99 --- /dev/null +++ b/tests/test_document_scoring.py @@ -0,0 
from tantivy.tantivy import Document, Index, SchemaBuilder
from tests.conftest import build_schema_numeric_fields  # NOTE(review): not referenced anywhere in this module
import pytest


# Fixture rows: (id, weight_f64, weight_i64, weight_u64, body).
# Each body mentions "banana" exactly once among four words, so presumably the
# unweighted scores tie and the first-added document ranks first -- TODO confirm.
_SCORING_ROWS = [
    (1, 0.1, 1, 1, "apple banana orange mango"),
    (2, 0.9, 10, 10, "pear lemon tomato banana"),
]


def _top_hit_id(index, searcher, **search_kwargs):
    """Search ``body:banana`` with ``limit=1`` and return the stored ``id`` of the only hit.

    Extra keyword arguments are forwarded unchanged to ``searcher.search``.
    The matching document is fetched through a fresh ``index.searcher()``.
    """
    query = index.parse_query("body:banana")
    results = searcher.search(query, limit=1, **search_kwargs)
    assert len(results.hits) == 1
    print(results)
    _, doc_address = results.hits[0]
    return index.searcher().doc(doc_address)["id"]


@pytest.mark.parametrize(
    "weight_by_field",
    ["weight_f64", "weight_i64", "weight_u64"],
)
def test_document_scoring(weight_by_field: str):
    """Check that ``weight_by_field`` changes which document ranks first.

    Two documents match ``body:banana``. Without weighting the top hit is
    expected to be id 1; when weighting by a field whose value is larger on
    the second document (0.9 vs 0.1, 10 vs 1) the top hit is expected to be
    id 2. The test runs once per supported numeric field type.
    """
    builder = SchemaBuilder()
    builder = builder.add_integer_field("id", stored=True, indexed=True, fast=True)
    builder = builder.add_float_field("weight_f64", stored=True, indexed=True, fast=True)
    builder = builder.add_integer_field("weight_i64", stored=True, indexed=True, fast=True)
    builder = builder.add_unsigned_field("weight_u64", stored=True, indexed=True, fast=True)
    builder = builder.add_text_field("body", stored=True, fast=True)
    index = Index(builder.build())
    writer = index.writer(15_000_000, 1)

    # presumably leaving the `with` block commits the batch -- confirm against IndexWriter
    with writer:
        for doc_id, as_f64, as_i64, as_u64, body in _SCORING_ROWS:
            doc = Document()
            doc.add_integer("id", doc_id)
            doc.add_float("weight_f64", as_f64)
            doc.add_integer("weight_i64", as_i64)
            doc.add_unsigned("weight_u64", as_u64)
            doc.add_text("body", body)
            writer.add_document(doc)

    index.reload()
    searcher = index.searcher()

    # Baseline: plain relevance ranking.
    assert _top_hit_id(index, searcher) == [1]

    # Weighted: the document with the larger weight value should now win.
    assert _top_hit_id(index, searcher, weight_by_field=weight_by_field) == [2]


def test_not_fastfield():
    """Weighting by a field declared with ``fast=False`` must raise ``ValueError``."""
    builder = SchemaBuilder()
    builder = builder.add_integer_field("id", stored=True, indexed=True, fast=True)
    builder = builder.add_float_field("weight_f64", stored=True, indexed=True, fast=False)
    builder = builder.add_text_field("body", stored=True, fast=True)
    index = Index(builder.build())
    index.reload()

    searcher = index.searcher()
    query = index.parse_query("body:banana")

    # The index holds no documents; the error is expected before any scoring happens.
    with pytest.raises(ValueError, match="not a fast field"):
        searcher.search(query, limit=1, weight_by_field="weight_f64")