diff --git a/search/query/custom_query_test.go b/search/query/custom_query_test.go index 5d963941b..7e3196507 100644 --- a/search/query/custom_query_test.go +++ b/search/query/custom_query_test.go @@ -104,7 +104,7 @@ func TestCustomFilterQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing } q := NewCustomFilterQueryWithFilter(NewMatchQuery("ipa"), - func(d *search.DocumentMatch) bool { return true }, []string{"abv"}, payload) + func(d *search.DocumentMatch) (bool, error) { return true, nil }, []string{"abv"}, payload) out, err := q.MarshalJSON() if err != nil { @@ -154,7 +154,7 @@ func TestCustomScoreQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing. } q := NewCustomScoreQueryWithScorer(NewMatchQuery("ipa"), - func(d *search.DocumentMatch) float64 { return d.Score }, []string{"ibu"}, payload) + func(d *search.DocumentMatch) (float64, error) { return d.Score, nil }, []string{"ibu"}, payload) out, err := q.MarshalJSON() if err != nil { diff --git a/search/searcher/search_custom_filter.go b/search/searcher/search_custom_filter.go index dcf1f7c05..2d02e1968 100644 --- a/search/searcher/search_custom_filter.go +++ b/search/searcher/search_custom_filter.go @@ -32,8 +32,10 @@ func init() { // CustomFilterFunc decides whether a hit (with doc-value fields populated) // should be kept. Unlike FilterFunc it does not receive a SearchContext since -// custom-query callbacks only need the DocumentMatch. -type CustomFilterFunc func(d *search.DocumentMatch) bool +// custom-query callbacks only need the DocumentMatch. A non-nil error aborts +// the search so the failure can be surfaced to the caller rather than silently +// dropping the hit. +type CustomFilterFunc func(d *search.DocumentMatch) (bool, error) // CustomFilterSearcher wraps a child searcher, optionally loads doc values // into each DocumentMatch, then applies a CustomFilterFunc to decide whether @@ -71,9 +73,14 @@ func (f *CustomFilterSearcher) Next(ctx *search.SearchContext) (*search.Document if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil { return nil, err } - if f.accept(next) { + keep, ferr := f.accept(next) + if ferr != nil { + return nil, ferr + } + if keep { return next, nil } + ctx.DocumentMatchPool.Put(next) next, err = f.child.Next(ctx) } return nil, err @@ -90,9 +97,14 @@ func (f *CustomFilterSearcher) Advance(ctx *search.SearchContext, ID index.Index if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil { return nil, err } - if f.accept(adv) { + keep, ferr := f.accept(adv) + if ferr != nil { + return nil, ferr + } + if keep { return adv, nil } + ctx.DocumentMatchPool.Put(adv) return f.Next(ctx) } diff --git a/search/searcher/search_custom_score.go b/search/searcher/search_custom_score.go index d2feb6d9b..16685e889 100644 --- a/search/searcher/search_custom_score.go +++ b/search/searcher/search_custom_score.go @@ -30,8 +30,10 @@ func init() { reflectStaticSizeCustomScoreSearcher = int(reflect.TypeOf(sfs).Size()) } -// CustomScoreFunc defines a function which can mutate document scores. -type CustomScoreFunc func(d *search.DocumentMatch) float64 +// CustomScoreFunc defines a function which can mutate document scores. A +// non-nil error aborts the search so the failure can be surfaced to the caller +// rather than silently falling back to the original score. +type CustomScoreFunc func(d *search.DocumentMatch) (float64, error) // CustomScoreSearcher wraps any other searcher, optionally loads doc values // into each DocumentMatch, then mutates the score using the supplied @@ -60,15 +62,21 @@ func NewCustomScoreSearcher(ctx context.Context, s search.Searcher, mutate Custo // applyScore mutates the score on the hit and, when explain is enabled, // replaces the explanation with a single node describing the custom score -// result. -func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) { - d.Score = f.mutate(d) +// result. A non-nil error from the score function is returned so the caller +// can abort the search. +func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) error { + score, err := f.mutate(d) + if err != nil { + return err + } + d.Score = score if f.explain { d.Expl = &search.Explanation{ Value: d.Score, Message: "custom_score function result", } } + return nil } func (f *CustomScoreSearcher) Size() int { @@ -85,7 +93,9 @@ func (f *CustomScoreSearcher) Next(ctx *search.SearchContext) (*search.DocumentM if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil { return nil, err } - f.applyScore(next) + if err = f.applyScore(next); err != nil { + return nil, err + } } return next, nil } @@ -99,7 +109,9 @@ func (f *CustomScoreSearcher) Advance(ctx *search.SearchContext, ID index.IndexI if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil { return nil, err } - f.applyScore(adv) + if err = f.applyScore(adv); err != nil { + return nil, err + } } return adv, nil } diff --git a/search_test.go b/search_test.go index ca7f8bdd9..813a1133e 100644 --- a/search_test.go +++ b/search_test.go @@ -4837,9 +4837,9 @@ func TestCustomFilterQuery(t *testing.T) { "7": {}, } - q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool { + q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) { _, ok := allowedIDs[d.ID] - return ok + return ok, nil }, nil, nil) req := NewSearchRequest(q) @@ -4884,8 +4884,8 @@ func TestCustomScoreQuery(t *testing.T) { "0": 1.0, } - q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 { - return d.Score + boosts[d.ID] + q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) { + return d.Score + boosts[d.ID], nil }, nil, nil) req := NewSearchRequest(q) @@ -4927,9 +4927,9 @@ func TestCustomFilterQueryDocumentMatchIDWithoutFields(t *testing.T) { "7": {}, } - q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool { + q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) { _, ok := allowedIDs[d.ID] - return ok + return ok, nil }, nil, nil) req := NewSearchRequest(q) @@ -4967,12 +4967,12 @@ func TestCustomScoreQueryWithDocValues(t *testing.T) { fictionQuery := NewTermQuery("fiction") fictionQuery.SetField("genre") - q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 { + q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) { rating, ok := d.Fields["rating"].(float64) if ok && rating >= 9 { - return d.Score + 100 + return d.Score + 100, nil } - return d.Score + return d.Score, nil }, []string{"rating"}, nil) req := NewSearchRequest(q) @@ -5002,8 +5002,8 @@ func TestCustomScoreQueryExplain(t *testing.T) { titleQuery := NewMatchQuery("habit") titleQuery.SetField("title") - q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) float64 { - return d.Score * 2 + q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) (float64, error) { + return d.Score * 2, nil }, nil, nil) req := NewSearchRequest(q) @@ -5172,12 +5172,12 @@ func TestCustomFilterQueryDateTimeDocValues(t *testing.T) { // Datetime doc values are decoded as RFC 3339 strings. cutoffStr := cutoff.UTC().Format(time.RFC3339Nano) - q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool { + q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) { pubStr, ok := d.Fields["published"].(string) if !ok { - return false + return false, nil } - return pubStr >= cutoffStr + return pubStr >= cutoffStr, nil }, []string{"published"}, nil) req := NewSearchRequest(q)