Skip to content

Commit 25ae60d

Browse files
committed
wip: add scorer and addscores
1 parent 064c83f commit 25ae60d

File tree

2 files changed

+90
-8
lines changed

2 files changed

+90
-8
lines changed

Diff for: search_commands.go

+32-7
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,19 @@ type FTAggregateWithCursor struct {
240240
}
241241

242242
type FTAggregateOptions struct {
243-
Verbatim bool
244-
LoadAll bool
245-
Load []FTAggregateLoad
246-
Timeout int
247-
GroupBy []FTAggregateGroupBy
248-
SortBy []FTAggregateSortBy
249-
SortByMax int
243+
Verbatim bool
244+
LoadAll bool
245+
Load []FTAggregateLoad
246+
Timeout int
247+
GroupBy []FTAggregateGroupBy
248+
SortBy []FTAggregateSortBy
249+
SortByMax int
250+
// Scorer is used to set scoring function, if not set passed, a default will be used.
251+
// The default scorer depends on the Redis version:
252+
// - `BM25` for Redis >= 8
253+
// - `TFIDF` for Redis < 8
254+
Scorer string
255+
AddScores bool
250256
Apply []FTAggregateApply
251257
LimitOffset int
252258
Limit int
@@ -505,15 +511,26 @@ func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery
505511
}
506512
}
507513
}
514+
508515
if options.Timeout > 0 {
509516
queryArgs = append(queryArgs, "TIMEOUT", options.Timeout)
510517
}
518+
511519
for _, apply := range options.Apply {
512520
queryArgs = append(queryArgs, "APPLY", apply.Field)
513521
if apply.As != "" {
514522
queryArgs = append(queryArgs, "AS", apply.As)
515523
}
516524
}
525+
526+
if options.Scorer != "" {
527+
queryArgs = append(queryArgs, "SCORER", options.Scorer)
528+
}
529+
530+
if options.AddScores == true {
531+
queryArgs = append(queryArgs, "ADDSCORES")
532+
}
533+
517534
if options.GroupBy != nil {
518535
for _, groupBy := range options.GroupBy {
519536
queryArgs = append(queryArgs, "GROUPBY", len(groupBy.Fields))
@@ -581,6 +598,7 @@ func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery
581598
queryArgs = append(queryArgs, key, value)
582599
}
583600
}
601+
584602
if options.DialectVersion > 0 {
585603
queryArgs = append(queryArgs, "DIALECT", options.DialectVersion)
586604
}
@@ -706,6 +724,12 @@ func (c cmdable) FTAggregateWithArgs(ctx context.Context, index string, query st
706724
args = append(args, "AS", apply.As)
707725
}
708726
}
727+
if options.Scorer != "" {
728+
args = append(args, "SCORER", options.Scorer)
729+
}
730+
if options.AddScores == true {
731+
args = append(args, "ADDSCORES")
732+
}
709733
if options.GroupBy != nil {
710734
for _, groupBy := range options.GroupBy {
711735
args = append(args, "GROUPBY", len(groupBy.Fields))
@@ -779,6 +803,7 @@ func (c cmdable) FTAggregateWithArgs(ctx context.Context, index string, query st
779803
}
780804

781805
cmd := NewAggregateCmd(ctx, args...)
806+
cmd.err = cmd.err
782807
_ = c(ctx, cmd)
783808
return cmd
784809
}

Diff for: search_test.go

+58-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package redis_test
22

33
import (
44
"context"
5+
"fmt"
6+
"strconv"
57
"time"
68

79
. "github.com/bsm/ginkgo/v2"
@@ -640,6 +642,62 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
640642
Expect(res.Rows[0].Fields["t2"]).To(BeEquivalentTo("world"))
641643
})
642644

645+
FIt("should FTAggregate with scorer and addscores", Label("search", "ftaggregate"), func() {
646+
title := &redis.FieldSchema{FieldName: "title", FieldType: redis.SearchFieldTypeText, Sortable: false}
647+
description := &redis.FieldSchema{FieldName: "description", FieldType: redis.SearchFieldTypeText, Sortable: false}
648+
val, err := client.FTCreate(ctx, "idx1", &redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"product:"}}, title, description).Result()
649+
Expect(err).NotTo(HaveOccurred())
650+
Expect(val).To(BeEquivalentTo("OK"))
651+
WaitForIndexing(client, "idx1")
652+
653+
client.HSet(ctx, "product:1", "title", "New Gaming Laptop", "description", "this is not a desktop")
654+
client.HSet(ctx, "product:2", "title", "Super Old Not Gaming Laptop", "description", "this laptop is not a new laptop but it is a laptop")
655+
client.HSet(ctx, "product:3", "title", "Office PC", "description", "office desktop pc")
656+
657+
options := &redis.FTAggregateOptions{
658+
AddScores: true,
659+
Scorer: "BM25",
660+
SortBy: []redis.FTAggregateSortBy{{
661+
FieldName: "@__score",
662+
Desc: true,
663+
}},
664+
}
665+
666+
res, err := client.FTAggregateWithArgs(ctx, "idx1", "laptop", options).Result()
667+
Expect(err).NotTo(HaveOccurred())
668+
Expect(res).ToNot(BeNil())
669+
Expect(len(res.Rows)).To(BeEquivalentTo(2))
670+
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
671+
Expect(err).NotTo(HaveOccurred())
672+
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
673+
Expect(err).NotTo(HaveOccurred())
674+
Expect(score1).To(BeNumerically(">", score2))
675+
676+
optionsDM := &redis.FTAggregateOptions{
677+
AddScores: true,
678+
Scorer: "DISMAX",
679+
SortBy: []redis.FTAggregateSortBy{{
680+
FieldName: "@__score",
681+
Desc: true,
682+
}},
683+
}
684+
685+
resDM, err := client.FTAggregateWithArgs(ctx, "idx1", "laptop", optionsDM).Result()
686+
Expect(err).NotTo(HaveOccurred())
687+
Expect(resDM).ToNot(BeNil())
688+
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
689+
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
690+
Expect(err).NotTo(HaveOccurred())
691+
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
692+
Expect(err).NotTo(HaveOccurred())
693+
Expect(score1DM).To(BeNumerically(">", score2DM))
694+
695+
Expect(score1DM).To(BeEquivalentTo(float64(4)))
696+
Expect(score2DM).To(BeEquivalentTo(float64(1)))
697+
Expect(score1).NotTo(BeEquivalentTo(score1DM))
698+
Expect(score2).NotTo(BeEquivalentTo(score2DM))
699+
})
700+
643701
It("should FTAggregate apply and groupby", Label("search", "ftaggregate"), func() {
644702
text1 := &redis.FieldSchema{FieldName: "PrimaryKey", FieldType: redis.SearchFieldTypeText, Sortable: true}
645703
num1 := &redis.FieldSchema{FieldName: "CreatedDateTimeUTC", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}
@@ -721,7 +779,6 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
721779
Expect(res.Rows[0].Fields["age"]).To(BeEquivalentTo("19"))
722780
Expect(res.Rows[1].Fields["age"]).To(BeEquivalentTo("25"))
723781
}
724-
725782
})
726783

727784
It("should FTSearch SkipInitialScan", Label("search", "ftsearch"), func() {

0 commit comments

Comments
 (0)