116 changes: 109 additions & 7 deletions src/searcher.rs
@@ -9,6 +9,7 @@ use tantivy as tv;
use tantivy::aggregation::AggregationCollector;
use tantivy::collector::{Count, MultiCollector, TopDocs};
use tantivy::TantivyDocument;

// Bring the trait into scope. This is required for the `to_named_doc` method.
// However, tantivy-py declares its own `Document` class, so we need to avoid
// introducing the `Document` trait into the namespace.
@@ -151,11 +152,20 @@ impl Searcher
/// to be returned.
/// order (Order, optional): The order in which the results
/// should be sorted. If not specified, defaults to descending.
/// weight_by_field (Field, optional): A schema field that the results
/// should be weighted by. The field must be declared as a fast
/// field when building the schema. Note that this only works for
/// f64, i64, and u64 fields. The given field value is first
/// transformed using the formula `log2(2.0 + value)` and then
/// multiplied by the original score. This means that a weight field
/// value of 0.0 results in no change to the original score.
/// If the weight value is negative, it is treated as 0.0.
///
/// Returns `SearchResult` object.
///
/// Raises a ValueError if there was an error with the search.
#[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))]
#[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc,
weight_by_field = None))]
#[allow(clippy::too_many_arguments)]
fn search(
&self,
@@ -166,6 +176,7 @@
order_by_field: Option<&str>,
offset: usize,
order: Order,
weight_by_field: Option<&str>,
) -> PyResult<SearchResult> {
py.detach(move || {
let mut multicollector = MultiCollector::new();
@@ -177,10 +188,103 @@
};

let (mut multifruit, hits) = {
if let Some(order_by) = order_by_field {
let collector = TopDocs::with_limit(limit)
.and_offset(offset)
.order_by_u64_field(order_by, order.into());
let collector = TopDocs::with_limit(limit).and_offset(offset);
if let Some(weight_by_field) = weight_by_field {
let weight_by_field = weight_by_field.to_string();

// Get field type from schema
let schema = self.inner.schema();
let field = crate::get_field(&schema, &weight_by_field)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
let field_entry = schema.get_field_entry(field);
let field_type = field_entry.field_type().value_type();

if !field_entry.is_fast() {
return Err(PyValueError::new_err(format!(
"Field '{}' is not a fast field. The field must be declared with fast=True in the schema.",
weight_by_field
)));
}

// Check if field type is supported
if !matches!(field_type, tv::schema::Type::F64 | tv::schema::Type::I64 | tv::schema::Type::U64) {
return Err(PyValueError::new_err(format!(
"Unsupported field type for weighting: {:?}. Only f64, i64, and u64 fastfields are supported.",
field_type
)));
}

let collector = collector.tweak_score(
move |segment_reader: &tv::SegmentReader| {
// Create all three readers upfront. Only one will succeed based on
// the actual field type, but we must create all three because:
// 1. Rust closures have a single concrete type - we can't return
// different closure types from different match arms
// 2. The alternative (Box<dyn Fn>) adds heap allocation per segment
// and virtual dispatch overhead per document
// 3. This approach enables monomorphization: the inner closure has
// a concrete type, allowing LLVM to inline get_val() calls
let f64_reader = segment_reader
.fast_fields()
.f64(&weight_by_field)
.ok()
.map(|r| r.first_or_default_col(0.0));
let i64_reader = segment_reader
.fast_fields()
.i64(&weight_by_field)
.ok()
.map(|r| r.first_or_default_col(0));
let u64_reader = segment_reader
.fast_fields()
.u64(&weight_by_field)
.ok()
.map(|r| r.first_or_default_col(0));

move |doc: tv::DocId, original_score: tv::Score| {
let value: f64 = match field_type {
// Runtime type dispatch is required here even though field_type
// was checked earlier because:
// 1. field_type is moved into this closure and can't be matched
// at compile time to select which reader to use
// 2. All three readers must exist at this point for the closure
// to have a single concrete type
//
// Use map_or(0.0, ...) instead of unwrap() because segments
// created before a schema change may lack this fast field.
// Default value 0.0 results in neutral scoring:
// boost = log2(2.0 + 0.0) = 1.0, so score * 1.0 = score
tv::schema::Type::F64 => f64_reader.as_ref().map_or(0.0, |r| r.get_val(doc)),
tv::schema::Type::I64 => i64_reader.as_ref().map_or(0.0, |r| r.get_val(doc) as f64),
tv::schema::Type::U64 => u64_reader.as_ref().map_or(0.0, |r| r.get_val(doc) as f64),
_ => unreachable!(),
};
let value = value.max(0.0); // Negative values are not allowed
let value_boost_score = ((2f64 + value) as tv::Score).log2();
value_boost_score * original_score
}
},
);
let top_docs_handle =
multicollector.add_collector(collector);
let ret = self.inner.search(query.get(), &multicollector);
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(Fruit, DocAddress)> = top_docs
.iter()
.map(|(f, d)| {
(Fruit::Score(*f), DocAddress::from(d))
})
.collect();
(r, result)
}
Err(e) => {
return Err(PyValueError::new_err(e.to_string()))
}
}
} else if let Some(order_by) = order_by_field {
let collector =
collector.order_by_u64_field(order_by, order.into());
let top_docs_handle =
multicollector.add_collector(collector);
let ret = self.inner.search(query.get(), &multicollector);
@@ -201,8 +305,6 @@
}
}
} else {
let collector =
TopDocs::with_limit(limit).and_offset(offset);
let top_docs_handle =
multicollector.add_collector(collector);
let ret = self.inner.search(query.get(), &multicollector);
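
Editor's note: a minimal Python sketch of the boost applied by the new `weight_by_field` option, as described in the docstring above. This is not part of the diff; the function name is ours. The fast-field value is clamped at 0.0, transformed with log2(2.0 + value), and multiplied by the BM25 score, so a weight of 0.0 is neutral.

import math

def weighted_score(original_score: float, weight: float) -> float:
    # Mirrors the Rust tweak_score closure: clamp, log-transform, multiply.
    weight = max(weight, 0.0)           # negative weights are treated as 0.0
    boost = math.log2(2.0 + weight)     # weight 0.0 -> boost 1.0 (neutral)
    return boost * original_score

assert weighted_score(1.5, 0.0) == 1.5   # score unchanged
assert weighted_score(1.5, 2.0) == 3.0   # log2(4.0) = 2.0 doubles the score
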
1 change: 1 addition & 0 deletions tantivy/tantivy.pyi
@@ -372,6 +372,7 @@ class Searcher:
order_by_field: Optional[str] = None,
offset: int = 0,
order: Order = Order.Desc,
weight_by_field: str | None = None,
) -> SearchResult:
pass

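
Editor's note: a hedged end-to-end usage sketch of the keyword argument declared in the stub above. The schema, the "popularity" field name, and the document contents are assumptions for illustration, not part of the PR.

from tantivy import Document, Index, SchemaBuilder

schema = (
    SchemaBuilder()
    .add_text_field("body", stored=True)
    .add_unsigned_field("popularity", stored=True, fast=True)  # fast=True is required for weighting
    .build()
)
index = Index(schema)
writer = index.writer()

doc = Document()
doc.add_text("body", "banana smoothie recipe")
doc.add_unsigned("popularity", 10)
writer.add_document(doc)
writer.commit()
index.reload()

searcher = index.searcher()
query = index.parse_query("body:banana")
# New in this PR: weight BM25 scores by the "popularity" fast field.
results = searcher.search(query, limit=10, weight_by_field="popularity")
for score, doc_address in results.hits:
    print(score, searcher.doc(doc_address))
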
81 changes: 81 additions & 0 deletions tests/test_document_scoring.py
@@ -0,0 +1,81 @@
from tantivy.tantivy import Document, Index, SchemaBuilder
import pytest


@pytest.mark.parametrize("weight_by_field", [
"weight_f64",
"weight_i64",
"weight_u64"
])
def test_document_scoring(weight_by_field: str):
schema = (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True, fast=True)
.add_float_field("weight_f64", stored=True, indexed=True, fast=True)
.add_integer_field("weight_i64", stored=True, indexed=True, fast=True)
.add_unsigned_field("weight_u64", stored=True, indexed=True, fast=True)
.add_text_field("body", stored=True, fast=True)
.build()
)
index = Index(schema)
writer = index.writer(15_000_000, 1)

with writer:
doc = Document()
doc.add_integer("id", 1)
doc.add_float("weight_f64", 0.1)
doc.add_integer("weight_i64", 1)
doc.add_unsigned("weight_u64", 1)
doc.add_text("body", "apple banana orange mango")
_ = writer.add_document(doc)

doc = Document()
doc.add_integer("id", 2)
doc.add_float("weight_f64", 0.9)
doc.add_integer("weight_i64", 10)
doc.add_unsigned("weight_u64", 10)
doc.add_text("body", "pear lemon tomato banana")
_ = writer.add_document(doc)

index.reload()

searcher = index.searcher()

query_text = "body:banana"
query = index.parse_query(query_text)
results = searcher.search(query, limit=1)
assert len(results.hits) == 1
print(results)
_, doc_address = results.hits[0]
d = index.searcher().doc(doc_address)
assert d["id"] == [1]

Review comment: Pardon my ignorance, but does that mean the document with the lesser weight comes before the one with higher?

Review comment: Gawd! I see now. This was the query without the weight_by_field keyword argument.

Collaborator (author): yep

query_text = "body:banana"
query = index.parse_query(query_text)
results = searcher.search(query, limit=1, weight_by_field=weight_by_field)
assert len(results.hits) == 1
print(results)
_, doc_address = results.hits[0]
d = index.searcher().doc(doc_address)
assert d["id"] == [2]


def test_not_fastfield():
schema = (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True, fast=True)
.add_float_field("weight_f64", stored=True, indexed=True, fast=False)
.add_text_field("body", stored=True, fast=True)
.build()
)
index = Index(schema)
index.reload()

searcher = index.searcher()

query_text = "body:banana"
query = index.parse_query(query_text)
with pytest.raises(ValueError, match="not a fast field"):
_ = searcher.search(query, limit=1, weight_by_field="weight_f64")
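
Editor's note: one more test that could accompany this change, sketched under the same assumptions as test_not_fastfield above and not part of the PR: the unsupported-field-type branch should reject weighting by a text fast field.

def test_unsupported_field_type():
    # Text fields can be fast, but only f64, i64, and u64 are accepted for weighting.
    schema = (
        SchemaBuilder()
        .add_text_field("body", stored=True, fast=True)
        .build()
    )
    index = Index(schema)
    index.reload()

    searcher = index.searcher()
    query = index.parse_query("body:banana")
    with pytest.raises(ValueError, match="Unsupported field type"):
        _ = searcher.search(query, limit=1, weight_by_field="body")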