Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions search/query/custom_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestCustomFilterQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing
}

q := NewCustomFilterQueryWithFilter(NewMatchQuery("ipa"),
func(d *search.DocumentMatch) bool { return true }, []string{"abv"}, payload)
func(d *search.DocumentMatch) (bool, error) { return true, nil }, []string{"abv"}, payload)

out, err := q.MarshalJSON()
if err != nil {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestCustomScoreQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing.
}

q := NewCustomScoreQueryWithScorer(NewMatchQuery("ipa"),
func(d *search.DocumentMatch) float64 { return d.Score }, []string{"ibu"}, payload)
func(d *search.DocumentMatch) (float64, error) { return d.Score, nil }, []string{"ibu"}, payload)

out, err := q.MarshalJSON()
if err != nil {
Expand Down
18 changes: 14 additions & 4 deletions search/searcher/search_custom_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func init() {

// CustomFilterFunc decides whether a hit (with doc-value fields populated)
// should be kept. Unlike FilterFunc it does not receive a SearchContext since
// custom-query callbacks only need the DocumentMatch.
type CustomFilterFunc func(d *search.DocumentMatch) bool
// custom-query callbacks only need the DocumentMatch. A non-nil error aborts
// the search so the failure can be surfaced to the caller rather than silently
// dropping the hit.
type CustomFilterFunc func(d *search.DocumentMatch) (bool, error)

// CustomFilterSearcher wraps a child searcher, optionally loads doc values
// into each DocumentMatch, then applies a CustomFilterFunc to decide whether
Expand Down Expand Up @@ -71,7 +73,11 @@ func (f *CustomFilterSearcher) Next(ctx *search.SearchContext) (*search.Document
if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
if f.accept(next) {
keep, ferr := f.accept(next)
if ferr != nil {
return nil, ferr
}
if keep {
return next, nil
}
next, err = f.child.Next(ctx)
Comment thread
CascadingRadium marked this conversation as resolved.
Expand All @@ -90,7 +96,11 @@ func (f *CustomFilterSearcher) Advance(ctx *search.SearchContext, ID index.Index
if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
if f.accept(adv) {
keep, ferr := f.accept(adv)
if ferr != nil {
return nil, ferr
}
if keep {
return adv, nil
}
return f.Next(ctx)
Expand Down
26 changes: 19 additions & 7 deletions search/searcher/search_custom_score.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ func init() {
reflectStaticSizeCustomScoreSearcher = int(reflect.TypeOf(sfs).Size())
}

// CustomScoreFunc defines a function which can mutate document scores.
type CustomScoreFunc func(d *search.DocumentMatch) float64
// CustomScoreFunc defines a function which can mutate document scores. A
// non-nil error aborts the search so the failure can be surfaced to the caller
// rather than silently falling back to the original score.
type CustomScoreFunc func(d *search.DocumentMatch) (float64, error)

// CustomScoreSearcher wraps any other searcher, optionally loads doc values
// into each DocumentMatch, then mutates the score using the supplied
Expand Down Expand Up @@ -60,15 +62,21 @@ func NewCustomScoreSearcher(ctx context.Context, s search.Searcher, mutate Custo

// applyScore mutates the score on the hit and, when explain is enabled,
// replaces the explanation with a single node describing the custom score
// result.
func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) {
d.Score = f.mutate(d)
// result. A non-nil error from the score function is returned so the caller
// can abort the search.
func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) error {
score, err := f.mutate(d)
if err != nil {
return err
}
d.Score = score
if f.explain {
d.Expl = &search.Explanation{
Value: d.Score,
Message: "custom_score function result",
}
}
return nil
}

func (f *CustomScoreSearcher) Size() int {
Expand All @@ -85,7 +93,9 @@ func (f *CustomScoreSearcher) Next(ctx *search.SearchContext) (*search.DocumentM
if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
f.applyScore(next)
if err = f.applyScore(next); err != nil {
return nil, err
}
}
return next, nil
}
Expand All @@ -99,7 +109,9 @@ func (f *CustomScoreSearcher) Advance(ctx *search.SearchContext, ID index.IndexI
if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
f.applyScore(adv)
if err = f.applyScore(adv); err != nil {
return nil, err
}
}
return adv, nil
}
Expand Down
28 changes: 14 additions & 14 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4837,9 +4837,9 @@ func TestCustomFilterQuery(t *testing.T) {
"7": {},
}

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
_, ok := allowedIDs[d.ID]
return ok
return ok, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4884,8 +4884,8 @@ func TestCustomScoreQuery(t *testing.T) {
"0": 1.0,
}

q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 {
return d.Score + boosts[d.ID]
q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) {
return d.Score + boosts[d.ID], nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4927,9 +4927,9 @@ func TestCustomFilterQueryDocumentMatchIDWithoutFields(t *testing.T) {
"7": {},
}

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
_, ok := allowedIDs[d.ID]
return ok
return ok, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4967,12 +4967,12 @@ func TestCustomScoreQueryWithDocValues(t *testing.T) {
fictionQuery := NewTermQuery("fiction")
fictionQuery.SetField("genre")

q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 {
q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) {
rating, ok := d.Fields["rating"].(float64)
if ok && rating >= 9 {
return d.Score + 100
return d.Score + 100, nil
}
return d.Score
return d.Score, nil
}, []string{"rating"}, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -5002,8 +5002,8 @@ func TestCustomScoreQueryExplain(t *testing.T) {
titleQuery := NewMatchQuery("habit")
titleQuery.SetField("title")

q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) float64 {
return d.Score * 2
q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) (float64, error) {
return d.Score * 2, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -5172,12 +5172,12 @@ func TestCustomFilterQueryDateTimeDocValues(t *testing.T) {
// Datetime doc values are decoded as RFC 3339 strings.
cutoffStr := cutoff.UTC().Format(time.RFC3339Nano)

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
pubStr, ok := d.Fields["published"].(string)
if !ok {
return false
return false, nil
}
return pubStr >= cutoffStr
return pubStr >= cutoffStr, nil
}, []string{"published"}, nil)

req := NewSearchRequest(q)
Expand Down
Loading