Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions search/query/custom_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestCustomFilterQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing
}

q := NewCustomFilterQueryWithFilter(NewMatchQuery("ipa"),
func(d *search.DocumentMatch) bool { return true }, []string{"abv"}, payload)
func(d *search.DocumentMatch) (bool, error) { return true, nil }, []string{"abv"}, payload)

out, err := q.MarshalJSON()
if err != nil {
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestCustomScoreQueryMarshalJSONPreservesPayloadAndRewritesChild(t *testing.
}

q := NewCustomScoreQueryWithScorer(NewMatchQuery("ipa"),
func(d *search.DocumentMatch) float64 { return d.Score }, []string{"ibu"}, payload)
func(d *search.DocumentMatch) (float64, error) { return d.Score, nil }, []string{"ibu"}, payload)

out, err := q.MarshalJSON()
if err != nil {
Expand Down
20 changes: 16 additions & 4 deletions search/searcher/search_custom_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ func init() {

// CustomFilterFunc decides whether a hit (with doc-value fields populated)
// should be kept. Unlike FilterFunc it does not receive a SearchContext since
// custom-query callbacks only need the DocumentMatch.
type CustomFilterFunc func(d *search.DocumentMatch) bool
// custom-query callbacks only need the DocumentMatch. A non-nil error aborts
// the search so the failure can be surfaced to the caller rather than silently
// dropping the hit.
type CustomFilterFunc func(d *search.DocumentMatch) (bool, error)

// CustomFilterSearcher wraps a child searcher, optionally loads doc values
// into each DocumentMatch, then applies a CustomFilterFunc to decide whether
Expand Down Expand Up @@ -71,9 +73,14 @@ func (f *CustomFilterSearcher) Next(ctx *search.SearchContext) (*search.Document
if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
if f.accept(next) {
keep, ferr := f.accept(next)
if ferr != nil {
return nil, ferr
}
if keep {
return next, nil
}
ctx.DocumentMatchPool.Put(next)
next, err = f.child.Next(ctx)
Comment thread
CascadingRadium marked this conversation as resolved.
}
return nil, err
Expand All @@ -90,9 +97,14 @@ func (f *CustomFilterSearcher) Advance(ctx *search.SearchContext, ID index.Index
if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
if f.accept(adv) {
keep, ferr := f.accept(adv)
if ferr != nil {
return nil, ferr
}
if keep {
return adv, nil
}
ctx.DocumentMatchPool.Put(adv)
return f.Next(ctx)
}

Expand Down
26 changes: 19 additions & 7 deletions search/searcher/search_custom_score.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ func init() {
reflectStaticSizeCustomScoreSearcher = int(reflect.TypeOf(sfs).Size())
}

// CustomScoreFunc defines a function which can mutate document scores.
type CustomScoreFunc func(d *search.DocumentMatch) float64
// CustomScoreFunc defines a function which can mutate document scores. A
// non-nil error aborts the search so the failure can be surfaced to the caller
// rather than silently falling back to the original score.
type CustomScoreFunc func(d *search.DocumentMatch) (float64, error)

// CustomScoreSearcher wraps any other searcher, optionally loads doc values
// into each DocumentMatch, then mutates the score using the supplied
Expand Down Expand Up @@ -60,15 +62,21 @@ func NewCustomScoreSearcher(ctx context.Context, s search.Searcher, mutate Custo

// applyScore mutates the score on the hit and, when explain is enabled,
// replaces the explanation with a single node describing the custom score
// result.
func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) {
d.Score = f.mutate(d)
// result. A non-nil error from the score function is returned so the caller
// can abort the search.
func (f *CustomScoreSearcher) applyScore(d *search.DocumentMatch) error {
score, err := f.mutate(d)
if err != nil {
return err
}
d.Score = score
if f.explain {
d.Expl = &search.Explanation{
Value: d.Score,
Message: "custom_score function result",
}
}
return nil
}

func (f *CustomScoreSearcher) Size() int {
Expand All @@ -85,7 +93,9 @@ func (f *CustomScoreSearcher) Next(ctx *search.SearchContext) (*search.DocumentM
if err = loadDocValuesOnHitWithTypes(next, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
f.applyScore(next)
if err = f.applyScore(next); err != nil {
return nil, err
}
}
return next, nil
}
Expand All @@ -99,7 +109,9 @@ func (f *CustomScoreSearcher) Advance(ctx *search.SearchContext, ID index.IndexI
if err = loadDocValuesOnHitWithTypes(adv, f.dvReader, f.indexReader, f.fieldTypes); err != nil {
return nil, err
}
f.applyScore(adv)
if err = f.applyScore(adv); err != nil {
return nil, err
}
}
return adv, nil
}
Expand Down
28 changes: 14 additions & 14 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4837,9 +4837,9 @@ func TestCustomFilterQuery(t *testing.T) {
"7": {},
}

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
_, ok := allowedIDs[d.ID]
return ok
return ok, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4884,8 +4884,8 @@ func TestCustomScoreQuery(t *testing.T) {
"0": 1.0,
}

q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 {
return d.Score + boosts[d.ID]
q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) {
return d.Score + boosts[d.ID], nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4927,9 +4927,9 @@ func TestCustomFilterQueryDocumentMatchIDWithoutFields(t *testing.T) {
"7": {},
}

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
_, ok := allowedIDs[d.ID]
return ok
return ok, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -4967,12 +4967,12 @@ func TestCustomScoreQueryWithDocValues(t *testing.T) {
fictionQuery := NewTermQuery("fiction")
fictionQuery.SetField("genre")

q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) float64 {
q := query.NewCustomScoreQueryWithScorer(fictionQuery, func(d *search.DocumentMatch) (float64, error) {
rating, ok := d.Fields["rating"].(float64)
if ok && rating >= 9 {
return d.Score + 100
return d.Score + 100, nil
}
return d.Score
return d.Score, nil
}, []string{"rating"}, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -5002,8 +5002,8 @@ func TestCustomScoreQueryExplain(t *testing.T) {
titleQuery := NewMatchQuery("habit")
titleQuery.SetField("title")

q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) float64 {
return d.Score * 2
q := query.NewCustomScoreQueryWithScorer(titleQuery, func(d *search.DocumentMatch) (float64, error) {
return d.Score * 2, nil
}, nil, nil)

req := NewSearchRequest(q)
Expand Down Expand Up @@ -5172,12 +5172,12 @@ func TestCustomFilterQueryDateTimeDocValues(t *testing.T) {
// Datetime doc values are decoded as RFC 3339 strings.
cutoffStr := cutoff.UTC().Format(time.RFC3339Nano)

q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) bool {
q := query.NewCustomFilterQueryWithFilter(fictionQuery, func(d *search.DocumentMatch) (bool, error) {
pubStr, ok := d.Fields["published"].(string)
if !ok {
return false
return false, nil
}
return pubStr >= cutoffStr
return pubStr >= cutoffStr, nil
}, []string{"published"}, nil)

req := NewSearchRequest(q)
Expand Down
Loading