diff --git a/cmd/call_cached_check/check.go b/cmd/call_cached_check/check.go index f343808c5..a07309a47 100644 --- a/cmd/call_cached_check/check.go +++ b/cmd/call_cached_check/check.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company and Greenhouse contributors +// SPDX-FileCopyrightText: 2026 SAP SE or an SAP affiliate company and Greenhouse contributors // SPDX-License-Identifier: Apache-2.0 // This is a helper command intended for use within Heureka. @@ -29,6 +29,27 @@ type result struct { want, got string } +// extractFuncName unwraps expressions like wrappers and returns the underlying function name. +func extractFuncName(expr ast.Expr) (string, bool) { + switch fn := expr.(type) { + case *ast.Ident: + return fn.Name, true + + case *ast.SelectorExpr: + return fn.Sel.Name, true + + case *ast.CallExpr: + // Assume the actual function is the last argument + if len(fn.Args) == 0 { + return "", false + } + + return extractFuncName(fn.Args[len(fn.Args)-1]) + } + + return "", false +} + func main() { dir := flag.String("dir", ".", "directory to scan (walked recursively)") @@ -79,7 +100,7 @@ func main() { if len(call.Args) < 4 { mismatches = append(mismatches, result{ - pos: fset.Position(call.Lparen), // position of '(' for the call + pos: fset.Position(call.Lparen), want: "at least 4 arguments", got: fmt.Sprintf("%d arguments", len(call.Args)), }) @@ -87,6 +108,7 @@ func main() { return true } + // 3rd argument: function name string cacheFuncNameArg := call.Args[2] strLit, ok := cacheFuncNameArg.(*ast.BasicLit) @@ -104,18 +126,12 @@ func main() { cacheFunc := call.Args[3] - var funcName string - - switch fn := cacheFunc.(type) { - case *ast.Ident: - funcName = fn.Name - case *ast.SelectorExpr: - funcName = fn.Sel.Name - default: + funcName, ok := extractFuncName(cacheFunc) + if !ok { mismatches = append(mismatches, result{ pos: fset.Position(cacheFunc.Pos()), - want: "unknown type", - got: fmt.Sprintf("%T", fn), + want: "callable function", + got: fmt.Sprintf("%T", cacheFunc), }) return true @@ -140,7 +156,7 @@ func main() { } for _, m := range mismatches { - fmt.Fprintf(os.Stderr, "%s: mismatch – want %q, got %q\n", + fmt.Fprintf(os.Stderr, "%s: mismatch - want %q, got %q\n", m.pos, m.want, m.got) } diff --git a/cmd/context_wrapper/wrap.go b/cmd/context_wrapper/wrap.go new file mode 100644 index 000000000..7b827e591 --- /dev/null +++ b/cmd/context_wrapper/wrap.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2026 SAP SE or an SAP affiliate company and Greenhouse contributors +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "os" + "text/template" +) + +const maxArgs = 10 + +const tpl = ` +{{range .}} +func WrapContext{{.N}}[{{.TypeParams}}, R any]( + ctx context.Context, + f func(context.Context, {{.ArgsTypes}}) (R, error), +) func({{.ArgsTypes}}) (R, error) { + return func({{.ArgsWithNames}}) (R, error) { + return f(ctx, {{.ArgsNames}}) + } +} +{{end}} +` + +type Data struct { + N int + TypeParams string + ArgsTypes string + ArgsWithNames string + ArgsNames string +} + +//nolint:errcheck +func main() { + var data []Data + + for n := 1; n <= maxArgs; n++ { + var typeParams, argsTypes, argsWithNames, argsNames string + + for i := 1; i <= n; i++ { + if i > 1 { + typeParams += ", " + argsTypes += ", " + argsWithNames += ", " + argsNames += ", " + } + + typeParams += fmt.Sprintf("A%d any", i) + argsTypes += fmt.Sprintf("A%d", i) + argsWithNames += fmt.Sprintf("a%d A%d", i, i) + argsNames += fmt.Sprintf("a%d", i) + } + + data = append(data, Data{ + N: n, + TypeParams: typeParams, + ArgsTypes: argsTypes, + ArgsWithNames: argsWithNames, + ArgsNames: argsNames, + }) + } + + f, err := os.Create("context_wrapper_gen.go") + if err != nil { + panic(err) + } + defer f.Close() + + fmt.Fprintln(f, "// Code generated by go generate; DO NOT EDIT.") + fmt.Fprintln(f, "package cache") + fmt.Fprintln(f, "import \"context\"") + + t := template.Must(template.New("").Parse(tpl)) + if err := t.Execute(f, data); err != nil { + panic(err) + } +} diff --git a/internal/api/graphql/graph/baseResolver/component.go b/internal/api/graphql/graph/baseResolver/component.go index 7c4ae7346..4647f4660 100644 --- a/internal/api/graphql/graph/baseResolver/component.go +++ b/internal/api/graphql/graph/baseResolver/component.go @@ -140,7 +140,7 @@ func ComponentCcrnBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListComponentCcrns(f, opt) + names, err := app.ListComponentCcrns(ctx, f, opt) if err != nil { return nil, NewResolverError("ComponentCcrnBaseReolver", err.Error()) } @@ -211,7 +211,7 @@ func ComponentIssueCountsBaseResolver( var severityCounts model.SeverityCounts - counts, err := app.GetComponentVulnerabilityCounts(f) + counts, err := app.GetComponentVulnerabilityCounts(ctx, f) if err != nil { return nil, ToGraphQLError(err) } diff --git a/internal/api/graphql/graph/baseResolver/component_instance.go b/internal/api/graphql/graph/baseResolver/component_instance.go index b9c1e179b..bc700f909 100644 --- a/internal/api/graphql/graph/baseResolver/component_instance.go +++ b/internal/api/graphql/graph/baseResolver/component_instance.go @@ -337,7 +337,7 @@ func ContextBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListContexts(f, opt) + names, err := app.ListContexts(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } @@ -357,7 +357,7 @@ func ContextBaseResolver( } func ComponentInstanceFilterBaseResolver( - appCall func(filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error), + appCall func(ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error), ctx context.Context, filter *model.ComponentInstanceFilter, filterDisplay *string, @@ -390,7 +390,7 @@ func ComponentInstanceFilterBaseResolver( opt := GetListOptions(requestedFields) - names, err := appCall(f, opt) + names, err := appCall(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } diff --git a/internal/api/graphql/graph/baseResolver/image.go b/internal/api/graphql/graph/baseResolver/image.go index d199248ee..51302881f 100644 --- a/internal/api/graphql/graph/baseResolver/image.go +++ b/internal/api/graphql/graph/baseResolver/image.go @@ -84,7 +84,7 @@ func ImageBaseResolver( Repository: filter.Repository, } - counts, err := app.GetComponentVulnerabilityCounts(icFilter) + counts, err := app.GetComponentVulnerabilityCounts(ctx, icFilter) if err != nil { return nil, NewResolverError("ImageBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/image_version.go b/internal/api/graphql/graph/baseResolver/image_version.go index 33848d8d1..0fd14d30b 100644 --- a/internal/api/graphql/graph/baseResolver/image_version.go +++ b/internal/api/graphql/graph/baseResolver/image_version.go @@ -134,7 +134,7 @@ func ImageVersionBaseResolver( ComponentVersionId: cvIds, } - counts, err := app.GetIssueSeverityCounts(icFilter) + counts, err := app.GetIssueSeverityCounts(ctx, icFilter) if err != nil { return nil, NewResolverError("ImageVersionBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/issue.go b/internal/api/graphql/graph/baseResolver/issue.go index 440541f0a..939201676 100644 --- a/internal/api/graphql/graph/baseResolver/issue.go +++ b/internal/api/graphql/graph/baseResolver/issue.go @@ -53,7 +53,7 @@ func SingleIssueBaseResolver( opt := &entity.IssueListOptions{} - issues, err := app.ListIssues(f, opt) + issues, err := app.ListIssues(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } @@ -175,7 +175,7 @@ func IssueBaseResolver( } } - issues, err := app.ListIssues(f, opt) + issues, err := app.ListIssues(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } @@ -267,7 +267,7 @@ func IssueNameBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListIssueNames(f, opt) + names, err := app.ListIssueNames(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } @@ -370,7 +370,7 @@ func IssueCountsBaseResolver( Unique: unique, } - counts, err := app.GetIssueSeverityCounts(f) + counts, err := app.GetIssueSeverityCounts(ctx, f) if err != nil { return nil, ToGraphQLError(err) } diff --git a/internal/api/graphql/graph/baseResolver/issue_repository.go b/internal/api/graphql/graph/baseResolver/issue_repository.go index 6bb0f5598..27bdc3c1f 100644 --- a/internal/api/graphql/graph/baseResolver/issue_repository.go +++ b/internal/api/graphql/graph/baseResolver/issue_repository.go @@ -37,7 +37,7 @@ func SingleIssueRepositoryBaseResolver( opt := &entity.ListOptions{} - issueRepositories, err := app.ListIssueRepositories(f, opt) + issueRepositories, err := app.ListIssueRepositories(ctx, f, opt) // error while fetching if err != nil { return nil, NewResolverError("SingleIssueRepositoryBaseResolver", err.Error()) @@ -113,7 +113,7 @@ func IssueRepositoryBaseResolver( opt := GetListOptions(requestedFields) - issueRepositories, err := app.ListIssueRepositories(f, opt) + issueRepositories, err := app.ListIssueRepositories(ctx, f, opt) if err != nil { return nil, NewResolverError("IssueRepositoryBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/issue_variant.go b/internal/api/graphql/graph/baseResolver/issue_variant.go index 982010cb4..2a6926096 100644 --- a/internal/api/graphql/graph/baseResolver/issue_variant.go +++ b/internal/api/graphql/graph/baseResolver/issue_variant.go @@ -36,7 +36,7 @@ func SingleIssueVariantBaseResolver( opt := &entity.ListOptions{} - variants, err := app.ListIssueVariants(f, opt) + variants, err := app.ListIssueVariants(ctx, f, opt) // error while fetching if err != nil { return nil, NewResolverError("SingleIssueVariantBaseResolver", err.Error()) @@ -117,7 +117,7 @@ func IssueVariantBaseResolver( opt := GetListOptions(requestedFields) - variants, err := app.ListIssueVariants(f, opt) + variants, err := app.ListIssueVariants(ctx, f, opt) if err != nil { return nil, NewResolverError("IssueVariantBaseResolver", err.Error()) } @@ -195,7 +195,7 @@ func EffectiveIssueVariantBaseResolver( opt := GetListOptions(requestedFields) - variants, err := app.ListEffectiveIssueVariants(f, opt) + variants, err := app.ListEffectiveIssueVariants(ctx, f, opt) if err != nil { return nil, NewResolverError("EffectiveIssueVariantBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/patch.go b/internal/api/graphql/graph/baseResolver/patch.go index c9be03e59..2d94b751e 100644 --- a/internal/api/graphql/graph/baseResolver/patch.go +++ b/internal/api/graphql/graph/baseResolver/patch.go @@ -65,7 +65,7 @@ func PatchBaseResolver( opt := GetListOptions(requestedFields) - patches, err := app.ListPatches(f, opt) + patches, err := app.ListPatches(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } diff --git a/internal/api/graphql/graph/baseResolver/remediation.go b/internal/api/graphql/graph/baseResolver/remediation.go index 7fea75e6b..c0ca5b7e2 100644 --- a/internal/api/graphql/graph/baseResolver/remediation.go +++ b/internal/api/graphql/graph/baseResolver/remediation.go @@ -84,7 +84,7 @@ func RemediationBaseResolver( ) } - remediations, err := app.ListRemediations(f, opt) + remediations, err := app.ListRemediations(ctx, f, opt) if err != nil { return nil, ToGraphQLError(err) } diff --git a/internal/api/graphql/graph/baseResolver/service.go b/internal/api/graphql/graph/baseResolver/service.go index dc0171b55..100cf3468 100644 --- a/internal/api/graphql/graph/baseResolver/service.go +++ b/internal/api/graphql/graph/baseResolver/service.go @@ -248,7 +248,7 @@ func ServiceRegionBaseResolver( } func ServiceFilterBaseResolver( - appCall func(filter *entity.ServiceFilter, opt *entity.ListOptions) ([]string, error), + appCall func(ctx context.Context, filter *entity.ServiceFilter, opt *entity.ListOptions) ([]string, error), ctx context.Context, filter *model.ServiceFilter, filterDisplay *string, @@ -274,7 +274,7 @@ func ServiceFilterBaseResolver( opt := GetListOptions(requestedFields) - names, err := appCall(f, opt) + names, err := appCall(ctx, f, opt) if err != nil { return nil, NewResolverError("ServiceFilterBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/support_group.go b/internal/api/graphql/graph/baseResolver/support_group.go index 1db0b0974..f7aabc4f1 100644 --- a/internal/api/graphql/graph/baseResolver/support_group.go +++ b/internal/api/graphql/graph/baseResolver/support_group.go @@ -146,7 +146,7 @@ func SupportGroupCcrnBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListSupportGroupCcrns(f, opt) + names, err := app.ListSupportGroupCcrns(ctx, f, opt) if err != nil { return nil, NewResolverError("SupportGroupCcrnBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/user.go b/internal/api/graphql/graph/baseResolver/user.go index aafe5538f..bcb677a7a 100644 --- a/internal/api/graphql/graph/baseResolver/user.go +++ b/internal/api/graphql/graph/baseResolver/user.go @@ -169,7 +169,7 @@ func UserNameBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListUserNames(f, opt) + names, err := app.ListUserNames(ctx, f, opt) if err != nil { return nil, NewResolverError("UserNameBaseResolver", err.Error()) } @@ -211,7 +211,7 @@ func UniqueUserIDBaseResolver( opt := GetListOptions(requestedFields) - names, err := app.ListUniqueUserIDs(f, opt) + names, err := app.ListUniqueUserIDs(ctx, f, opt) if err != nil { return nil, NewResolverError("UniqueUserIDBaseResolver", err.Error()) } @@ -253,7 +253,7 @@ func UserNameWithIdBaseResolver( opt := GetListOptions(requestedFields) - names, ids, err := app.ListUserNamesAndIds(f, opt) + names, ids, err := app.ListUserNamesAndIds(ctx, f, opt) if err != nil { return nil, NewResolverError("UserNameWithIdBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/baseResolver/vulnerability.go b/internal/api/graphql/graph/baseResolver/vulnerability.go index b467d4e25..d67d27249 100644 --- a/internal/api/graphql/graph/baseResolver/vulnerability.go +++ b/internal/api/graphql/graph/baseResolver/vulnerability.go @@ -76,12 +76,13 @@ func VulnerabilityBaseResolver(app app.Heureka, ctx context.Context, By: entity.IssueVariantRating, Direction: entity.OrderDirectionDesc, }) + opt.Order = append(opt.Order, entity.Order{ By: entity.IssuePrimaryName, Direction: entity.OrderDirectionAsc, }) - issues, err := app.ListIssues(f, opt) + issues, err := app.ListIssues(ctx, f, opt) if err != nil { return nil, NewResolverError("VulnerabilityBaseResolver", err.Error()) } diff --git a/internal/api/graphql/graph/resolver/mutation.go b/internal/api/graphql/graph/resolver/mutation.go index 2aa55ea21..aca755e54 100644 --- a/internal/api/graphql/graph/resolver/mutation.go +++ b/internal/api/graphql/graph/resolver/mutation.go @@ -809,7 +809,7 @@ func (r *mutationResolver) AddComponentVersionToIssue(ctx context.Context, issue ) } - issue, err := r.App.AddComponentVersionToIssue(*issueIdInt, *componentVersionIdInt) + issue, err := r.App.AddComponentVersionToIssue(ctx, *issueIdInt, *componentVersionIdInt) if err != nil { return nil, baseResolver.NewResolverError( "AddComponentVersionToIssueMutationResolver", @@ -840,7 +840,7 @@ func (r *mutationResolver) RemoveComponentVersionFromIssue(ctx context.Context, ) } - issue, err := r.App.RemoveComponentVersionFromIssue(*issueIdInt, *componentVersionIdInt) + issue, err := r.App.RemoveComponentVersionFromIssue(ctx, *issueIdInt, *componentVersionIdInt) if err != nil { return nil, baseResolver.NewResolverError( "RemoveComponentVersionFromIssueMutationResolver", @@ -1058,6 +1058,7 @@ func (r *mutationResolver) CreateRemediation(ctx context.Context, input model.Re // fetch issue id for given issue name issueResult, err := r.App.ListIssues( + ctx, &entity.IssueFilter{PrimaryName: []*string{input.Vulnerability}}, nil, ) @@ -1071,7 +1072,7 @@ func (r *mutationResolver) CreateRemediation(ctx context.Context, input model.Re remediation.IssueId = issueResult.Elements[0].Issue.Id if input.RemediatedBy != nil { - userUniqueUserIDs, err := r.App.ListUniqueUserIDs(&entity.UserFilter{ + userUniqueUserIDs, err := r.App.ListUniqueUserIDs(ctx, &entity.UserFilter{ UniqueUserID: []*string{input.RemediatedBy}, }, nil) if err != nil { @@ -1162,6 +1163,7 @@ func (r *mutationResolver) UpdateRemediation(ctx context.Context, id string, inp if input.Vulnerability != nil { // fetch issue id for given issue name issueResult, err := r.App.ListIssues( + ctx, &entity.IssueFilter{PrimaryName: []*string{input.Vulnerability}}, nil, ) diff --git a/internal/api/graphql/graph/resolver/mutation_helpers.go b/internal/api/graphql/graph/resolver/mutation_helpers.go index 9ce698597..a1c1a6d0f 100644 --- a/internal/api/graphql/graph/resolver/mutation_helpers.go +++ b/internal/api/graphql/graph/resolver/mutation_helpers.go @@ -176,6 +176,7 @@ func (r *mutationResolver) getOrCreateIssueAndVariant( if input.URL != nil && *input.URL != "" { if input.Name != nil && *input.Name != "" { ivs, err := r.App.ListIssueVariants( + ctx, &entity.IssueVariantFilter{SecondaryName: []*string{input.Name}}, &entity.ListOptions{}, ) @@ -212,7 +213,7 @@ func (r *mutationResolver) getOrCreateIssueAndVariant( f := &entity.IssueFilter{PrimaryName: []*string{input.Name}} lo := entity.IssueListOptions{ListOptions: *entity.NewListOptions()} - issues, ierr := r.App.ListIssues(f, &lo) + issues, ierr := r.App.ListIssues(ctx, f, &lo) if ierr != nil || len(issues.Elements) == 0 { return nil, nil, baseResolver.NewResolverError( "CreateSIEMAlertMutationResolver", @@ -234,7 +235,7 @@ func (r *mutationResolver) getOrCreateIssueAndVariant( Name: []*string{&siemRepoName}, } - repositories, err := r.App.ListIssueRepositories(&repoFilter, &entity.ListOptions{}) + repositories, err := r.App.ListIssueRepositories(ctx, &repoFilter, &entity.ListOptions{}) var issueRepositoryId int64 if err == nil && len(repositories.Elements) > 0 { @@ -299,7 +300,7 @@ func (r *mutationResolver) getOrCreateIssueAndVariant( issueVariant = newIv } else { - iss, err := r.App.GetIssue(issueVariant.IssueId) + iss, err := r.App.GetIssue(ctx, issueVariant.IssueId) if err != nil { return nil, nil, baseResolver.NewResolverError( "CreateSIEMAlertMutationResolver", diff --git a/internal/api/graphql/server.go b/internal/api/graphql/server.go index a3c789467..8ccd74d13 100644 --- a/internal/api/graphql/server.go +++ b/internal/api/graphql/server.go @@ -55,7 +55,7 @@ func (g *GraphQLAPI) graphqlHandler() gin.HandlerFunc { g.Server.AroundOperations(g.batchLimiter.Middleware()) return func(c *gin.Context) { - g.Server.ServeHTTP(c.Writer, c.Request) + g.Server.ServeHTTP(c.Writer, c.Request.WithContext(c.Request.Context())) } } @@ -63,6 +63,6 @@ func (g *GraphQLAPI) playgroundHandler() gin.HandlerFunc { h := playground.Handler("GraphQL", "/query") return func(c *gin.Context) { - h.ServeHTTP(c.Writer, c.Request) + h.ServeHTTP(c.Writer, c.Request.WithContext(c.Request.Context())) } } diff --git a/internal/app/common/user_id.go b/internal/app/common/user_id.go index e7e2d641e..26c661853 100644 --- a/internal/app/common/user_id.go +++ b/internal/app/common/user_id.go @@ -25,9 +25,9 @@ func GetCurrentUserId(ctx context.Context, db database.Database) (int64, error) return 0, fmt.Errorf("could not get user name from context: %w", err) } - return getUserIdFromDb(db, uniqueUserId) + return getUserIdFromDb(ctx, db, uniqueUserId) } else { - return getUserIdFromDb(db, systemUserUniqueUserId) + return getUserIdFromDb(ctx, db, systemUserUniqueUserId) } } @@ -49,13 +49,13 @@ func GetUserIdByUniqueId( db database.Database, uniqueUserId string, ) (int64, error) { - return getUserIdFromDb(db, uniqueUserId) + return getUserIdFromDb(ctx, db, uniqueUserId) } -func getUserIdFromDb(db database.Database, uniqueUserId string) (int64, error) { +func getUserIdFromDb(ctx context.Context, db database.Database, uniqueUserId string) (int64, error) { filter := &entity.UserFilter{UniqueUserID: []*string{&uniqueUserId}} - ids, err := db.GetAllUserIds(filter) + ids, err := db.GetAllUserIds(ctx, filter) if err != nil { return unknownUser, fmt.Errorf("unable to get user ids %w", err) } else if len(ids) < 1 { diff --git a/internal/app/component/component_handler.go b/internal/app/component/component_handler.go index 2bb8f3096..65598285a 100644 --- a/internal/app/component/component_handler.go +++ b/internal/app/component/component_handler.go @@ -100,7 +100,7 @@ func (cs *componentHandler) ListComponents( // Update the filter.Id based on accessibleComponentIds filter.Id = common.CombineFilterWithAccessibleIds(filter.Id, accessibleComponentIds) - res, err := cs.database.GetComponents(filter, options.Order) + res, err := cs.database.GetComponents(ctx, filter, options.Order) if err != nil { wrappedErr := appErrors.InternalError(string(op), "Components", "", err) applog.LogError(cs.logger, wrappedErr, logrus.Fields{ @@ -116,7 +116,7 @@ func (cs *componentHandler) ListComponents( cs.cache, CacheTtlGetAllComponentCursors, "GetAllComponentCursors", - cs.database.GetAllComponentCursors, + cache.WrapContext2(ctx, cs.database.GetAllComponentCursors), filter, options.Order, ) @@ -137,7 +137,7 @@ func (cs *componentHandler) ListComponents( cs.cache, CacheTtlCountComponents, "CountComponents", - cs.database.CountComponents, + cache.WrapContext1(ctx, cs.database.CountComponents), filter, ) if err != nil { @@ -281,6 +281,7 @@ func (cs *componentHandler) DeleteComponent(ctx context.Context, id int64) error } func (cs *componentHandler) ListComponentCcrns( + ctx context.Context, filter *entity.ComponentFilter, options *entity.ListOptions, ) ([]string, error) { @@ -293,7 +294,7 @@ func (cs *componentHandler) ListComponentCcrns( cs.cache, CacheTtlGetComponentCcrns, "GetComponentCcrns", - cs.database.GetComponentCcrns, + cache.WrapContext1(ctx, cs.database.GetComponentCcrns), filter, ) if err != nil { @@ -309,6 +310,7 @@ func (cs *componentHandler) ListComponentCcrns( } func (cs *componentHandler) GetComponentVulnerabilityCounts( + ctx context.Context, filter *entity.ComponentFilter, ) ([]entity.IssueSeverityCounts, error) { l := logrus.WithFields(logrus.Fields{ @@ -316,7 +318,7 @@ func (cs *componentHandler) GetComponentVulnerabilityCounts( "filter": filter, }) - counts, err := cs.database.CountComponentVulnerabilities(filter) + counts, err := cs.database.CountComponentVulnerabilities(ctx, filter) if err != nil { l.Error(err) diff --git a/internal/app/component/component_handler_interface.go b/internal/app/component/component_handler_interface.go index f34ab7ba5..e06db8ef7 100644 --- a/internal/app/component/component_handler_interface.go +++ b/internal/app/component/component_handler_interface.go @@ -18,6 +18,6 @@ type ComponentHandler interface { CreateComponent(context.Context, *entity.Component) (*entity.Component, error) UpdateComponent(context.Context, *entity.Component) (*entity.Component, error) DeleteComponent(context.Context, int64) error - ListComponentCcrns(*entity.ComponentFilter, *entity.ListOptions) ([]string, error) - GetComponentVulnerabilityCounts(*entity.ComponentFilter) ([]entity.IssueSeverityCounts, error) + ListComponentCcrns(context.Context, *entity.ComponentFilter, *entity.ListOptions) ([]string, error) + GetComponentVulnerabilityCounts(context.Context, *entity.ComponentFilter) ([]entity.IssueSeverityCounts, error) } diff --git a/internal/app/component/component_handler_test.go b/internal/app/component/component_handler_test.go index 4db3d7e5d..2ec75f251 100644 --- a/internal/app/component/component_handler_test.go +++ b/internal/app/component/component_handler_test.go @@ -85,9 +85,9 @@ var _ = Describe("When listing Components", Label("app", "ListComponents"), func When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponents", filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) - db.On("CountComponents", filter).Return(int64(1337), nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) + db.On("CountComponents", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -144,9 +144,9 @@ var _ = Describe("When listing Components", Label("app", "ListComponents"), func ) cursors = append(cursors, c) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponents", filter, []entity.Order{}).Return(components, nil) - db.On("GetAllComponentCursors", filter, []entity.Order{}).Return(cursors, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}).Return(components, nil) + db.On("GetAllComponentCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) componentHandler = c.NewComponentHandler(handlerContext) res, err := componentHandler.ListComponents(ctx, filter, options) Expect(err).To(BeNil(), "no error should be thrown") @@ -195,8 +195,8 @@ var _ = Describe("When listing Components", Label("app", "ListComponents"), func BeforeEach(func() { compIds := int64(-1) filter.Id = []*int64{&compIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponents", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentResult{}, nil) }) @@ -216,8 +216,8 @@ var _ = Describe("When listing Components", Label("app", "ListComponents"), func systemUserId := int64(1) component = test.NewFakeComponentEntity() filter.Id = []*int64{&component.Id} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponents", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentResult{{Component: &component}}, nil) relations := []openfga.RelationInput{ @@ -285,9 +285,9 @@ var _ = Describe("When creating Component", Label("app", "CreateComponent"), fun It("creates component", func() { filter.CCRN = []*string{&component.CCRN} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateComponent", &component).Return(&component, nil) - db.On("GetComponents", filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) componentHandler = c.NewComponentHandler(handlerContext) newComponent, err := componentHandler.CreateComponent(common.NewAdminContext(), &component) Expect(err).To(BeNil(), "no error should be thrown") @@ -372,12 +372,12 @@ var _ = Describe("When updating Component", Label("app", "UpdateComponent"), fun }) It("updates component", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateComponent", component.Component).Return(nil) componentHandler = c.NewComponentHandler(handlerContext) component.CCRN = "NewComponent" filter.Id = []*int64{&component.Id} - db.On("GetComponents", filter, []entity.Order{}). + db.On("GetComponents", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentResult{component}, nil) updatedComponent, err := componentHandler.UpdateComponent( common.NewAdminContext(), @@ -421,10 +421,10 @@ var _ = Describe("When deleting Component", Label("app", "DeleteComponent"), fun }) It("deletes component", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteComponent", id, mock.Anything).Return(nil) componentHandler = c.NewComponentHandler(handlerContext) - db.On("GetComponents", filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) + db.On("GetComponents", mock.Anything, filter, []entity.Order{}).Return([]entity.ComponentResult{}, nil) err := componentHandler.DeleteComponent(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") diff --git a/internal/app/component_instance/component_instance_handler.go b/internal/app/component_instance/component_instance_handler.go index d79494cdb..9b6c6d799 100644 --- a/internal/app/component_instance/component_instance_handler.go +++ b/internal/app/component_instance/component_instance_handler.go @@ -103,7 +103,7 @@ func (ci *componentInstanceHandler) ListComponentInstances( ci.cache, CacheTtlGetComponentInstances, "GetComponentInstances", - ci.database.GetComponentInstances, + cache.WrapContext2(ctx, ci.database.GetComponentInstances), filter, options.Order, ) @@ -122,7 +122,7 @@ func (ci *componentInstanceHandler) ListComponentInstances( ci.cache, CacheTtlGetAllComponentInstanceCursors, "GetAllComponentInstanceCursors", - ci.database.GetAllComponentInstanceCursors, + cache.WrapContext2(ctx, ci.database.GetAllComponentInstanceCursors), filter, options.Order, ) @@ -148,7 +148,7 @@ func (ci *componentInstanceHandler) ListComponentInstances( ci.cache, CacheTtlCountComponentInstances, "CountComponentInstances", - ci.database.CountComponentInstances, + cache.WrapContext1(ctx, ci.database.CountComponentInstances), filter, ) if err != nil { @@ -551,12 +551,13 @@ func (ci *componentInstanceHandler) DeleteComponentInstance(ctx context.Context, } func (ci *componentInstanceHandler) ListRegions( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListRegions") - regions, err := ci.database.GetRegion(filter) + regions, err := ci.database.GetRegion(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceRegions", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -575,12 +576,13 @@ func (ci *componentInstanceHandler) ListRegions( } func (ci *componentInstanceHandler) ListCcrns( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListCcrns") - ccrns, err := ci.database.GetCcrn(filter) + ccrns, err := ci.database.GetCcrn(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceCcrns", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -599,12 +601,13 @@ func (ci *componentInstanceHandler) ListCcrns( } func (ci *componentInstanceHandler) ListClusters( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListClusters") - clusters, err := ci.database.GetCluster(filter) + clusters, err := ci.database.GetCluster(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceClusters", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -623,12 +626,13 @@ func (ci *componentInstanceHandler) ListClusters( } func (ci *componentInstanceHandler) ListNamespaces( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListNamespaces") - namespaces, err := ci.database.GetNamespace(filter) + namespaces, err := ci.database.GetNamespace(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceNamespaces", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -647,12 +651,13 @@ func (ci *componentInstanceHandler) ListNamespaces( } func (ci *componentInstanceHandler) ListDomains( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListDomains") - domains, err := ci.database.GetDomain(filter) + domains, err := ci.database.GetDomain(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceDomains", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -671,12 +676,13 @@ func (ci *componentInstanceHandler) ListDomains( } func (ci *componentInstanceHandler) ListProjects( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListProjects") - projects, err := ci.database.GetProject(filter) + projects, err := ci.database.GetProject(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceProjects", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -695,12 +701,13 @@ func (ci *componentInstanceHandler) ListProjects( } func (ci *componentInstanceHandler) ListPods( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListPods") - pods, err := ci.database.GetPod(filter) + pods, err := ci.database.GetPod(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstancePods", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -719,12 +726,13 @@ func (ci *componentInstanceHandler) ListPods( } func (ci *componentInstanceHandler) ListContainers( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListContainers") - containers, err := ci.database.GetContainer(filter) + containers, err := ci.database.GetContainer(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceContainers", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -743,12 +751,13 @@ func (ci *componentInstanceHandler) ListContainers( } func (ci *componentInstanceHandler) ListTypes( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListTypes") - types, err := ci.database.GetType(filter) + types, err := ci.database.GetType(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceTypes", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -767,12 +776,13 @@ func (ci *componentInstanceHandler) ListTypes( } func (ci *componentInstanceHandler) ListParents( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListParents") - parents, err := ci.database.GetComponentInstanceParent(filter) + parents, err := ci.database.GetComponentInstanceParent(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceParents", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ @@ -791,12 +801,13 @@ func (ci *componentInstanceHandler) ListParents( } func (ci *componentInstanceHandler) ListContexts( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) { op := appErrors.Op("componentInstanceHandler.ListContexts") - contexts, err := ci.database.GetContext(filter) + contexts, err := ci.database.GetContext(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "ComponentInstanceContexts", "", err) applog.LogError(ci.logger, wrappedErr, logrus.Fields{ diff --git a/internal/app/component_instance/component_instance_handler_interface.go b/internal/app/component_instance/component_instance_handler_interface.go index 9587ed7f0..49bb1cbfe 100644 --- a/internal/app/component_instance/component_instance_handler_interface.go +++ b/internal/app/component_instance/component_instance_handler_interface.go @@ -26,38 +26,46 @@ type ComponentInstanceHandler interface { *string, ) (*entity.ComponentInstance, error) DeleteComponentInstance(context.Context, int64) error - ListCcrns(filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) + ListCcrns(ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) ListRegions( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) ListClusters( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) ListNamespaces( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) ListDomains( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) ListProjects( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) - ListPods(filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) + ListPods(ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) ListContainers( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) - ListTypes(filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) + ListTypes(ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions) ([]string, error) ListParents( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) ListContexts( + ctx context.Context, filter *entity.ComponentInstanceFilter, options *entity.ListOptions, ) ([]string, error) diff --git a/internal/app/component_instance/component_instance_handler_test.go b/internal/app/component_instance/component_instance_handler_test.go index d5f6129d8..0dbd755f2 100644 --- a/internal/app/component_instance/component_instance_handler_test.go +++ b/internal/app/component_instance/component_instance_handler_test.go @@ -87,10 +87,10 @@ var _ = Describe( When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{}, nil) - db.On("CountComponentInstances", filter).Return(int64(1337), nil) + db.On("CountComponentInstances", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -148,10 +148,10 @@ var _ = Describe( ) cursors = append(cursors, c) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return(componentInstances, nil) - db.On("GetAllComponentInstanceCursors", filter, []entity.Order{}). + db.On("GetAllComponentInstanceCursors", mock.Anything, filter, []entity.Order{}). Return(cursors, nil) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) @@ -197,8 +197,8 @@ var _ = Describe( It("should return Internal error", func() { // Mock database error dbError := errors.New("database connection failed") - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{}, dbError) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) @@ -243,11 +243,11 @@ var _ = Describe( }) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return(componentInstances, nil) cursorsError := errors.New("cursor database error") - db.On("GetAllComponentInstanceCursors", filter, []entity.Order{}). + db.On("GetAllComponentInstanceCursors", mock.Anything, filter, []entity.Order{}). Return([]string{}, cursorsError) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) @@ -289,8 +289,8 @@ var _ = Describe( BeforeEach(func() { serviceIds := int64(-1) filter.ServiceId = []*int64{&serviceIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{}, nil) }) @@ -318,8 +318,8 @@ var _ = Describe( systemUserId := int64(1) filter.ServiceId = []*int64{&serviceId} componentInstance = test.NewFakeComponentInstanceEntity() - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{{ComponentInstance: &componentInstance}}, nil) relations := []openfga.RelationInput{ @@ -415,7 +415,7 @@ var _ = Describe( Context("with valid input", func() { It("creates componentInstance", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("CreateComponentInstance", mock.AnythingOfType("*entity.ComponentInstance")). Return(&componentInstance, nil) @@ -546,7 +546,7 @@ var _ = Describe( }) Context("with valid input", func() { It("updates componentInstance", func() { - db.On("GetAllUserIds", mock.Anything). + db.On("GetAllUserIds", mock.Anything, mock.Anything). Return([]int64{123}, nil) // Changed: return actual user ID db.On("UpdateComponentInstance", componentInstance.ComponentInstance).Return(nil) @@ -567,7 +567,7 @@ var _ = Describe( componentInstance.Namespace, ) filter.Id = []*int64{&componentInstance.Id} - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{componentInstance}, nil) updatedComponentInstance, err := componentInstanceHandler.UpdateComponentInstance( common.NewAdminContext(), @@ -831,7 +831,7 @@ var _ = Describe( Context("with valid input", func() { It("deletes componentInstance", func() { - db.On("GetAllUserIds", mock.Anything). + db.On("GetAllUserIds", mock.Anything, mock.Anything). Return([]int64{123}, nil) // Changed: return actual user ID db.On("DeleteComponentInstance", id, int64(123)). @@ -839,7 +839,7 @@ var _ = Describe( // Changed: specify exact user ID componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) - db.On("GetComponentInstances", filter, []entity.Order{}). + db.On("GetComponentInstances", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentInstanceResult{}, nil) err := componentInstanceHandler.DeleteComponentInstance( common.NewAdminContext(), @@ -1003,13 +1003,13 @@ var _ = Describe("When listing CCRN", Label("app", "ListCcrn"), func() { When("no filters are used", func() { BeforeEach(func() { - db.On("GetCcrn", filter).Return([]string{}, nil) + db.On("GetCcrn", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) - res, err := componentInstanceHandler.ListCcrns(filter, options) + res, err := componentInstanceHandler.ListCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -1021,12 +1021,12 @@ var _ = Describe("When listing CCRN", Label("app", "ListCcrn"), func() { CCRN: []*string{&CCRN}, } - db.On("GetCcrn", filter).Return([]string{CCRN}, nil) + db.On("GetCcrn", mock.Anything, filter).Return([]string{CCRN}, nil) }) It("returns filtered CCRN according to the CCRN type", func() { componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) - res, err := componentInstanceHandler.ListCcrns(filter, options) + res, err := componentInstanceHandler.ListCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(CCRN), "should only consist of CCRN") }) @@ -1037,11 +1037,11 @@ var _ = Describe("When listing CCRN", Label("app", "ListCcrn"), func() { It("should return Internal error", func() { // Mock database error dbError := errors.New("database connection failed") - db.On("GetCcrn", filter).Return([]string{}, dbError) + db.On("GetCcrn", mock.Anything, filter).Return([]string{}, dbError) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) componentInstanceHandler = ci.NewComponentInstanceHandler(handlerContext) - result, err := componentInstanceHandler.ListCcrns(filter, options) + result, err := componentInstanceHandler.ListCcrns(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") diff --git a/internal/app/component_version/component_version_handler.go b/internal/app/component_version/component_version_handler.go index 26df6c17a..fcfcf6e01 100644 --- a/internal/app/component_version/component_version_handler.go +++ b/internal/app/component_version/component_version_handler.go @@ -105,7 +105,7 @@ func (cv *componentVersionHandler) ListComponentVersions( cv.cache, CacheTtlGetComponentVersions, "GetComponentVersions", - cv.database.GetComponentVersions, + cache.WrapContext2(ctx, cv.database.GetComponentVersions), filter, options.Order, ) @@ -124,7 +124,7 @@ func (cv *componentVersionHandler) ListComponentVersions( cv.cache, CacheTtlGetAllComponentVersionCursors, "GetAllComponentVersionCursors", - cv.database.GetAllComponentVersionCursors, + cache.WrapContext2(ctx, cv.database.GetAllComponentVersionCursors), filter, options.Order, ) @@ -145,7 +145,7 @@ func (cv *componentVersionHandler) ListComponentVersions( cv.cache, CacheTtlCountComponentVersions, "CountComponentVersions", - cv.database.CountComponentVersions, + cache.WrapContext1(ctx, cv.database.CountComponentVersions), filter, ) if err != nil { diff --git a/internal/app/component_version/component_version_handler_test.go b/internal/app/component_version/component_version_handler_test.go index f123d49ad..24f4506f2 100644 --- a/internal/app/component_version/component_version_handler_test.go +++ b/internal/app/component_version/component_version_handler_test.go @@ -79,10 +79,10 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentVersionResult{}, nil) - db.On("CountComponentVersions", filter).Return(int64(1337), nil) + db.On("CountComponentVersions", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -146,10 +146,10 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe ) cursors = append(cursors, c) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return(componentVersions, nil) - db.On("GetAllComponentVersionCursors", filter, []entity.Order{}). + db.On("GetAllComponentVersionCursors", mock.Anything, filter, []entity.Order{}). Return(cursors, nil) cvHandler = cv.NewComponentVersionHandler(handlerContext) res, err := cvHandler.ListComponentVersions(ctx, filter, options) @@ -193,11 +193,11 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe tagFilter.Tag = []*string{&testTag} // Mock database calls - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", tagFilter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, tagFilter, []entity.Order{}). Return(componentVersions, nil) if options.ShowTotalCount { - db.On("CountComponentVersions", tagFilter). + db.On("CountComponentVersions", mock.Anything, tagFilter). Return(int64(len(componentVersions)), nil) } @@ -229,11 +229,11 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe repoFilter.Repository = []*string{&testRepo} // Mock database calls - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", repoFilter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, repoFilter, []entity.Order{}). Return(componentVersions, nil) if options.ShowTotalCount { - db.On("CountComponentVersions", repoFilter). + db.On("CountComponentVersions", mock.Anything, repoFilter). Return(int64(len(componentVersions)), nil) } @@ -265,11 +265,11 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe orgFilter.Organization = []*string{&testOrg} // Mock database calls - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", orgFilter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, orgFilter, []entity.Order{}). Return(componentVersions, nil) if options.ShowTotalCount { - db.On("CountComponentVersions", orgFilter). + db.On("CountComponentVersions", mock.Anything, orgFilter). Return(int64(len(componentVersions)), nil) } @@ -307,8 +307,8 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe BeforeEach(func() { compIds := int64(-1) filter.ComponentId = []*int64{&compIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentVersionResult{}, nil) }) @@ -331,8 +331,8 @@ var _ = Describe("When listing ComponentVersions", Label("app", "ListComponentVe systemUserId := int64(1) filter.ComponentId = []*int64{&compId} componentVersion = test.NewFakeComponentVersionEntity() - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentVersionResult{{ComponentVersion: &componentVersion}}, nil) relations := []openfga.RelationInput{ @@ -398,7 +398,7 @@ var _ = Describe("When creating ComponentVersion", Label("app", "CreateComponent }) It("creates componentVersion", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateComponentVersion", &componentVersion).Return(&componentVersion, nil) componenVersionService = cv.NewComponentVersionHandler(handlerContext) newComponentVersion, err := componenVersionService.CreateComponentVersion( @@ -489,13 +489,13 @@ var _ = Describe("When updating ComponentVersion", Label("app", "UpdateComponent }) It("updates componentVersion", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateComponentVersion", componentVersion.ComponentVersion).Return(nil) componenVersionService = cv.NewComponentVersionHandler(handlerContext) componentVersion.Version = "7.3.3.1" componentVersion.Tag = "updated-tag" filter.Id = []*int64{&componentVersion.Id} - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentVersionResult{componentVersion}, nil) updatedComponentVersion, err := componenVersionService.UpdateComponentVersion( common.NewAdminContext(), @@ -605,10 +605,10 @@ var _ = Describe("When deleting ComponentVersion", Label("app", "DeleteComponent }) It("deletes componentVersion", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteComponentVersion", id, mock.Anything).Return(nil) componenVersionService = cv.NewComponentVersionHandler(handlerContext) - db.On("GetComponentVersions", filter, []entity.Order{}). + db.On("GetComponentVersions", mock.Anything, filter, []entity.Order{}). Return([]entity.ComponentVersionResult{}, nil) err := componenVersionService.DeleteComponentVersion(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") diff --git a/internal/app/heureka.go b/internal/app/heureka.go index 38d2b7c06..27dab8cdd 100644 --- a/internal/app/heureka.go +++ b/internal/app/heureka.go @@ -89,7 +89,7 @@ func NewHeurekaApp( remediationHandler := remediation.NewRemediationHandler(common.HandlerContext{ DB: db, EventReg: er, - Cache: nil, + Cache: cache, Authz: authz, }) @@ -163,7 +163,10 @@ func (h *HeurekaApp) SubscribeHandlers() { component_instance.CreateComponentInstanceEventName, event.EventHandlerFunc(issue_match.OnComponentInstanceCreate), }, - {service.CreateServiceEventName, event.EventHandlerFunc(service.OnServiceCreate)}, + { + service.CreateServiceEventName, + event.EventHandlerFunc(service.OnServiceCreate), + }, { issue_repository.CreateIssueRepositoryEventName, event.EventHandlerFunc(issue_repository.OnIssueRepositoryCreate), @@ -208,8 +211,14 @@ func (h *HeurekaApp) SubscribeAuthzHandlers() { event.EventHandlerFunc(issue_match.OnIssueMatchCreateAuthz), }, // Delete events - {user.DeleteUserEventName, event.EventHandlerFunc(user.OnUserDeleteAuthz)}, - {service.DeleteServiceEventName, event.EventHandlerFunc(service.OnServiceDeleteAuthz)}, + { + user.DeleteUserEventName, + event.EventHandlerFunc(user.OnUserDeleteAuthz), + }, + { + service.DeleteServiceEventName, + event.EventHandlerFunc(service.OnServiceDeleteAuthz), + }, { component_instance.DeleteComponentInstanceEventName, event.EventHandlerFunc(component_instance.OnComponentInstanceDeleteAuthz), diff --git a/internal/app/issue/issue_handler.go b/internal/app/issue/issue_handler.go index 2349ffb1f..ec40878cf 100644 --- a/internal/app/issue/issue_handler.go +++ b/internal/app/issue/issue_handler.go @@ -56,7 +56,7 @@ func ensureIssueListOptions(options *entity.IssueListOptions) *entity.IssueListO return options } -func (is *issueHandler) GetIssue(id int64) (*entity.Issue, error) { +func (is *issueHandler) GetIssue(ctx context.Context, id int64) (*entity.Issue, error) { op := appErrors.Op("issueHandler.GetIssue") // Input validation @@ -77,7 +77,7 @@ func (is *issueHandler) GetIssue(id int64) (*entity.Issue, error) { ListOptions: *entity.NewListOptions(), } - issues, err := is.ListIssues(&entity.IssueFilter{Id: []*int64{&id}}, &lo) + issues, err := is.ListIssues(ctx, &entity.IssueFilter{Id: []*int64{&id}}, &lo) if err != nil { // Wrap the error from ListIssues with operation context wrappedErr := appErrors.E(op, "Issue", strconv.FormatInt(id, 10), appErrors.Internal, err) @@ -115,6 +115,7 @@ func (is *issueHandler) GetIssue(id int64) (*entity.Issue, error) { } func (is *issueHandler) ListIssues( + ctx context.Context, filter *entity.IssueFilter, options *entity.IssueListOptions, ) (*entity.IssueList, error) { @@ -139,7 +140,7 @@ func (is *issueHandler) ListIssues( is.cache, CacheTtlGetIssuesWithAggregations, "GetIssuesWithAggregations", - is.database.GetIssuesWithAggregations, + cache.WrapContext2(ctx, is.database.GetIssuesWithAggregations), filter, options.Order, ) @@ -157,7 +158,7 @@ func (is *issueHandler) ListIssues( is.cache, CacheTtlGetIssues, "GetIssues", - is.database.GetIssues, + cache.WrapContext2(ctx, is.database.GetIssues), filter, options.Order, ) @@ -180,7 +181,7 @@ func (is *issueHandler) ListIssues( is.cache, CacheTtlGetAllIssueCursors, "GetAllIssueCursors", - is.database.GetAllIssueCursors, + cache.WrapContext2(ctx, is.database.GetAllIssueCursors), filter, options.Order, ) @@ -203,7 +204,7 @@ func (is *issueHandler) ListIssues( is.cache, CacheTtlCountIssueTypes, "CountIssueTypes", - is.database.CountIssueTypes, + cache.WrapContext1(ctx, is.database.CountIssueTypes), filter, ) if err != nil { @@ -260,7 +261,7 @@ func (is *issueHandler) CreateIssue( ListOptions: *entity.NewListOptions(), } - issues, err := is.ListIssues(f, &lo) + issues, err := is.ListIssues(ctx, f, &lo) if err != nil { wrappedErr := appErrors.InternalError(string(op), "Issue", "", err) applog.LogError(is.logger, wrappedErr, logrus.Fields{ @@ -341,7 +342,7 @@ func (is *issueHandler) UpdateIssue( ListOptions: *entity.NewListOptions(), } - issueResult, err := is.ListIssues(&entity.IssueFilter{Id: []*int64{&issue.Id}}, &lo) + issueResult, err := is.ListIssues(ctx, &entity.IssueFilter{Id: []*int64{&issue.Id}}, &lo) if err != nil { wrappedErr := appErrors.E( op, @@ -411,6 +412,7 @@ func (is *issueHandler) DeleteIssue(ctx context.Context, id int64) error { } func (is *issueHandler) AddComponentVersionToIssue( + ctx context.Context, issueId, componentVersionId int64, ) (*entity.Issue, error) { op := appErrors.Op("issueHandler.AddComponentVersionToIssue") @@ -444,7 +446,7 @@ func (is *issueHandler) AddComponentVersionToIssue( ComponentVersionID: componentVersionId, }) - issue, err := is.GetIssue(issueId) + issue, err := is.GetIssue(ctx, issueId) if err != nil { wrappedErr := appErrors.E( op, @@ -465,6 +467,7 @@ func (is *issueHandler) AddComponentVersionToIssue( } func (is *issueHandler) RemoveComponentVersionFromIssue( + ctx context.Context, issueId, componentVersionId int64, ) (*entity.Issue, error) { op := appErrors.Op("issueHandler.RemoveComponentVersionFromIssue") @@ -486,7 +489,7 @@ func (is *issueHandler) RemoveComponentVersionFromIssue( ComponentVersionID: componentVersionId, }) - issue, err := is.GetIssue(issueId) + issue, err := is.GetIssue(ctx, issueId) if err != nil { wrappedErr := appErrors.E( op, @@ -507,6 +510,7 @@ func (is *issueHandler) RemoveComponentVersionFromIssue( } func (is *issueHandler) ListIssueNames( + ctx context.Context, filter *entity.IssueFilter, options *entity.ListOptions, ) ([]string, error) { @@ -516,7 +520,7 @@ func (is *issueHandler) ListIssueNames( is.cache, CacheTtlGetIssueNames, "GetIssueNames", - is.database.GetIssueNames, + cache.WrapContext1(ctx, is.database.GetIssueNames), filter, ) if err != nil { @@ -538,6 +542,7 @@ func (is *issueHandler) ListIssueNames( } func (is *issueHandler) GetIssueSeverityCounts( + ctx context.Context, filter *entity.IssueFilter, ) (*entity.IssueSeverityCounts, error) { op := appErrors.Op("issueHandler.GetIssueSeverityCounts") @@ -546,7 +551,7 @@ func (is *issueHandler) GetIssueSeverityCounts( is.cache, CacheTtlCountIssueRatings, "CountIssueRatings", - is.database.CountIssueRatings, + cache.WrapContext1(ctx, is.database.CountIssueRatings), filter, ) if err != nil { diff --git a/internal/app/issue/issue_handler_events.go b/internal/app/issue/issue_handler_events.go index dd5f53fc1..86c426da1 100644 --- a/internal/app/issue/issue_handler_events.go +++ b/internal/app/issue/issue_handler_events.go @@ -120,12 +120,14 @@ func OnComponentVersionAttachmentToIssue( "payload": e, }) + ctx := context.Background() + if attachmentEvent, ok := e.(*AddComponentVersionToIssueEvent); ok { // Get ComponentInstances l.WithField("event-step", "GetComponentInstances"). Debug("Get Component Instances by ComponentVersionId") - componentInstances, err := db.GetComponentInstances(&entity.ComponentInstanceFilter{ + componentInstances, err := db.GetComponentInstances(ctx, &entity.ComponentInstanceFilter{ ComponentVersionId: []*int64{&attachmentEvent.ComponentVersionID}, }, []entity.Order{}) if err != nil { @@ -141,6 +143,7 @@ func OnComponentVersionAttachmentToIssue( for _, compInst := range componentInstances { // Get Service Issue Variants issueVariantMap, err := shared.BuildIssueVariantMap( + ctx, db, &entity.ServiceIssueVariantFilter{ ComponentInstanceId: []*int64{&compInst.Id}, @@ -183,7 +186,7 @@ func createIssueMatches( l.WithField("event-step", "GetIssueMatches"). Debug("Fetching issue matches related to assigned Component Instance") - issue_matches, err := db.GetIssueMatches(&entity.IssueMatchFilter{ + issue_matches, err := db.GetIssueMatches(ctx, &entity.IssueMatchFilter{ IssueId: []*int64{&issueId}, ComponentInstanceId: []*int64{&componentInstanceId}, }, []entity.Order{}) diff --git a/internal/app/issue/issue_handler_events_test.go b/internal/app/issue/issue_handler_events_test.go index 9f1922756..d9b2bc699 100644 --- a/internal/app/issue/issue_handler_events_test.go +++ b/internal/app/issue/issue_handler_events_test.go @@ -83,19 +83,19 @@ var _ = Describe( componentInstance.ComponentVersionId = componentVersion.Id // // Setup mock expectations for happy path - db.On("GetComponentInstances", &entity.ComponentInstanceFilter{ + db.On("GetComponentInstances", mock.Anything, &entity.ComponentInstanceFilter{ ComponentVersionId: []*int64{&componentVersion.Id}, }, []entity.Order{}).Return([]entity.ComponentInstanceResult{componentInstance}, nil) }) It("creates an issue match for the component instance", func() { // Setup expectation for existing match check - db.On("GetIssueMatches", &entity.IssueMatchFilter{ + db.On("GetIssueMatches", mock.Anything, &entity.IssueMatchFilter{ ComponentInstanceId: []*int64{&componentInstance.Id}, IssueId: []*int64{&issueEntity.Id}, }, []entity.Order{}).Return([]entity.IssueMatchResult{}, nil) - db.On("GetServiceIssueVariants", &entity.ServiceIssueVariantFilter{ + db.On("GetServiceIssueVariants", mock.Anything, &entity.ServiceIssueVariantFilter{ ComponentInstanceId: []*int64{&componentInstance.Id}, IssueId: []*int64{&issueEntity.Id}, }, mock.Anything).Return([]entity.ServiceIssueVariantResult{{ @@ -109,7 +109,7 @@ var _ = Describe( ComponentInstanceId: componentInstance.Id, IssueId: issueEntity.Id, } - db.On("GetAllUserIds", mock.Anything).Return([]int64{1}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{1}, nil) db.On("CreateIssueMatch", matchIssueMatch(expectedMatch)).Return(expectedMatch, nil) // Emit event @@ -121,7 +121,7 @@ var _ = Describe( It("skips creation if match already exists", func() { existingMatch := test.NewFakeIssueMatchResult() - db.On("GetServiceIssueVariants", &entity.ServiceIssueVariantFilter{ + db.On("GetServiceIssueVariants", mock.Anything, &entity.ServiceIssueVariantFilter{ ComponentInstanceId: []*int64{&componentInstance.Id}, IssueId: []*int64{&issueEntity.Id}, }, mock.Anything).Return([]entity.ServiceIssueVariantResult{{ @@ -129,7 +129,7 @@ var _ = Describe( }}, nil) // Setup expectation to return existing match - db.On("GetIssueMatches", &entity.IssueMatchFilter{ + db.On("GetIssueMatches", mock.Anything, &entity.IssueMatchFilter{ ComponentInstanceId: []*int64{&componentInstance.Id}, IssueId: []*int64{&issueEntity.Id}, }, []entity.Order{}).Return([]entity.IssueMatchResult{existingMatch}, nil) diff --git a/internal/app/issue/issue_handler_interface.go b/internal/app/issue/issue_handler_interface.go index 54ffb39df..a30b107da 100644 --- a/internal/app/issue/issue_handler_interface.go +++ b/internal/app/issue/issue_handler_interface.go @@ -10,13 +10,13 @@ import ( ) type IssueHandler interface { - ListIssues(*entity.IssueFilter, *entity.IssueListOptions) (*entity.IssueList, error) - GetIssue(int64) (*entity.Issue, error) + ListIssues(context.Context, *entity.IssueFilter, *entity.IssueListOptions) (*entity.IssueList, error) + GetIssue(context.Context, int64) (*entity.Issue, error) CreateIssue(context.Context, *entity.Issue) (*entity.Issue, error) UpdateIssue(context.Context, *entity.Issue) (*entity.Issue, error) DeleteIssue(context.Context, int64) error - AddComponentVersionToIssue(int64, int64) (*entity.Issue, error) - RemoveComponentVersionFromIssue(int64, int64) (*entity.Issue, error) - ListIssueNames(*entity.IssueFilter, *entity.ListOptions) ([]string, error) - GetIssueSeverityCounts(*entity.IssueFilter) (*entity.IssueSeverityCounts, error) + AddComponentVersionToIssue(context.Context, int64, int64) (*entity.Issue, error) + RemoveComponentVersionFromIssue(context.Context, int64, int64) (*entity.Issue, error) + ListIssueNames(context.Context, *entity.IssueFilter, *entity.ListOptions) ([]string, error) + GetIssueSeverityCounts(context.Context, *entity.IssueFilter) (*entity.IssueSeverityCounts, error) } diff --git a/internal/app/issue/issue_handler_test.go b/internal/app/issue/issue_handler_test.go index 798648d3f..e9b6b9957 100644 --- a/internal/app/issue/issue_handler_test.go +++ b/internal/app/issue/issue_handler_test.go @@ -3,6 +3,7 @@ package issue_test import ( + "context" "errors" "math" "strconv" @@ -64,11 +65,11 @@ var _ = Describe("When getting a single Issue", Label("app", "GetIssue", "errors expectedResult := []entity.IssueResult{{ Issue: &issueEntity, }} - db.On("GetIssues", mock.MatchedBy(func(filter *entity.IssueFilter) bool { + db.On("GetIssues", mock.Anything, mock.MatchedBy(func(filter *entity.IssueFilter) bool { return len(filter.Id) == 1 && *filter.Id[0] == issueEntity.Id }), []entity.Order{}).Return(expectedResult, nil) - result, err := issueHandler.GetIssue(issueEntity.Id) + result, err := issueHandler.GetIssue(context.Background(), issueEntity.Id) Expect(err).To(BeNil(), "no error should be thrown") Expect(result).ToNot(BeNil(), "issue should be returned") @@ -79,7 +80,7 @@ var _ = Describe("When getting a single Issue", Label("app", "GetIssue", "errors Context("with invalid input", func() { It("should return InvalidArgument error for negative ID", func() { - result, err := issueHandler.GetIssue(-1) + result, err := issueHandler.GetIssue(context.Background(), -1) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -95,7 +96,7 @@ var _ = Describe("When getting a single Issue", Label("app", "GetIssue", "errors }) It("should return InvalidArgument error for zero ID", func() { - result, err := issueHandler.GetIssue(0) + result, err := issueHandler.GetIssue(context.Background(), 0) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -111,9 +112,9 @@ var _ = Describe("When getting a single Issue", Label("app", "GetIssue", "errors Context("when issue does not exist", func() { It("should return NotFound error", func() { // Setup mock to return empty result (no issues found) - db.On("GetIssues", mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) - result, err := issueHandler.GetIssue(999) + result, err := issueHandler.GetIssue(context.Background(), 999) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -130,10 +131,10 @@ var _ = Describe("When getting a single Issue", Label("app", "GetIssue", "errors It("should return Internal error wrapping the database error", func() { // Setup mock to return database error dbError := errors.New("database connection failed") - db.On("GetIssues", mock.Anything, []entity.Order{}). + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}). Return([]entity.IssueResult{}, dbError) - result, err := issueHandler.GetIssue(123) + result, err := issueHandler.GetIssue(context.Background(), 123) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -198,13 +199,13 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetIssues", filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) - db.On("CountIssueTypes", filter).Return(issueTypeCounts, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("CountIssueTypes", mock.Anything, filter).Return(issueTypeCounts, nil) }) It("shows the total count in the results", func() { issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(*res.TotalCount).Should(BeEquivalentTo(int64(1337)), "return correct Totalcount") }) @@ -244,11 +245,11 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { c, _ := mariadb.EncodeCursor(mariadb.WithIssue([]entity.Order{}, issue, 0)) cursors = append(cursors, c) } - db.On("GetIssues", filter, []entity.Order{}).Return(issues, nil) - db.On("GetAllIssueCursors", filter, []entity.Order{}).Return(cursors, nil) - db.On("CountIssueTypes", filter).Return(issueTypeCounts, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return(issues, nil) + db.On("GetAllIssueCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) + db.On("CountIssueTypes", mock.Anything, filter).Return(issueTypeCounts, nil) issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect( *res.PageInfo.HasNextPage, @@ -276,38 +277,38 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { }) Context("and the given filter does not have any matches in the database", func() { BeforeEach(func() { - db.On("GetIssuesWithAggregations", filter, []entity.Order{}). + db.On("GetIssuesWithAggregations", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueResult{}, nil) }) It("should return an empty result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(len(res.Elements)).Should(BeEquivalentTo(0), "return no results") }) }) Context("and the filter does have results in the database", func() { BeforeEach(func() { - db.On("GetIssuesWithAggregations", filter, []entity.Order{}). + db.On("GetIssuesWithAggregations", mock.Anything, filter, []entity.Order{}). Return(test.NNewFakeIssueResultsWithAggregations(10), nil) }) It("should return the expected issues in the result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(len(res.Elements)).Should(BeEquivalentTo(10), "return 10 results") }) }) Context("and the database operations throw an error", func() { BeforeEach(func() { - db.On("GetIssuesWithAggregations", filter, []entity.Order{}). + db.On("GetIssuesWithAggregations", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueResult{}, errors.New("database error")) }) It("should return the expected issues in the result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - _, err := issueHandler.ListIssues(filter, options) + _, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).ToNot(BeNil(), "error should be returned") @@ -328,23 +329,23 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { Context("and the given filter does not have any matches in the database", func() { BeforeEach(func() { - db.On("GetIssues", filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) }) It("should return an empty result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(len(res.Elements)).Should(BeEquivalentTo(0), "return no results") }) }) Context("and the filter does have results in the database", func() { BeforeEach(func() { - db.On("GetIssues", filter, []entity.Order{}). + db.On("GetIssues", mock.Anything, filter, []entity.Order{}). Return(test.NNewFakeIssueResults(15), nil) }) It("should return the expected issues in the result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - res, err := issueHandler.ListIssues(filter, options) + res, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(len(res.Elements)).Should(BeEquivalentTo(15), "return 15 results") }) @@ -352,13 +353,13 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { Context("and the database operations throw an error", func() { BeforeEach(func() { - db.On("GetIssues", filter, []entity.Order{}). + db.On("GetIssues", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueResult{}, errors.New("database error")) }) It("should return the expected issues in the result", func() { issueHandler = issue.NewIssueHandler(handlerContext) - _, err := issueHandler.ListIssues(filter, options) + _, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).ToNot(BeNil(), "error should be returned") @@ -380,12 +381,12 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { It("should return Internal error", func() { // Mock successful GetIssues but failing GetAllIssueCursors - db.On("GetIssues", filter, []entity.Order{}).Return(test.NNewFakeIssueResults(5), nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return(test.NNewFakeIssueResults(5), nil) cursorsError := errors.New("cursor database error") - db.On("GetAllIssueCursors", filter, []entity.Order{}).Return([]string{}, cursorsError) + db.On("GetAllIssueCursors", mock.Anything, filter, []entity.Order{}).Return([]string{}, cursorsError) issueHandler = issue.NewIssueHandler(handlerContext) - _, err := issueHandler.ListIssues(filter, options) + _, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).ToNot(BeNil(), "error should be returned") @@ -405,12 +406,12 @@ var _ = Describe("When listing Issues", Label("app", "ListIssues"), func() { It("should return Internal error", func() { // Mock successful GetIssues but failing CountIssueTypes - db.On("GetIssues", filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) countError := errors.New("count database error") - db.On("CountIssueTypes", filter).Return((*entity.IssueTypeCounts)(nil), countError) + db.On("CountIssueTypes", mock.Anything, filter).Return((*entity.IssueTypeCounts)(nil), countError) issueHandler = issue.NewIssueHandler(handlerContext) - _, err := issueHandler.ListIssues(filter, options) + _, err := issueHandler.ListIssues(context.Background(), filter, options) Expect(err).ToNot(BeNil(), "error should be returned") @@ -449,10 +450,10 @@ var _ = Describe("When listing Issue Names", Label("app", "ListIssueNames"), fun Context("with valid input", func() { It("returns issue names successfully", func() { expectedNames := []string{"CVE-2023-1234", "CVE-2023-5678", "POLICY-001"} - db.On("GetIssueNames", filter).Return(expectedNames, nil) + db.On("GetIssueNames", mock.Anything, filter).Return(expectedNames, nil) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.ListIssueNames(filter, options) + result, err := issueHandler.ListIssueNames(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(result).ToNot(BeNil(), "result should be returned") @@ -462,10 +463,10 @@ var _ = Describe("When listing Issue Names", Label("app", "ListIssueNames"), fun It("returns empty list when no issues found", func() { expectedNames := []string{} - db.On("GetIssueNames", filter).Return(expectedNames, nil) + db.On("GetIssueNames", mock.Anything, filter).Return(expectedNames, nil) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.ListIssueNames(filter, options) + result, err := issueHandler.ListIssueNames(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(result).ToNot(BeNil(), "result should be returned") @@ -477,10 +478,10 @@ var _ = Describe("When listing Issue Names", Label("app", "ListIssueNames"), fun It("should return Internal error", func() { // Mock database error dbError := errors.New("database connection failed") - db.On("GetIssueNames", filter).Return([]string{}, dbError) + db.On("GetIssueNames", mock.Anything, filter).Return([]string{}, dbError) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.ListIssueNames(filter, options) + result, err := issueHandler.ListIssueNames(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -529,9 +530,9 @@ var _ = Describe("When creating Issue", Label("app", "CreateIssue"), func() { It("creates issue successfully", func() { filter.PrimaryName = []*string{&issueEntity.PrimaryName} // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock no existing issues with same primary name - db.On("GetIssues", filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueResult{}, nil) // Mock successful database creation db.On("CreateIssue", mock.AnythingOfType("*entity.Issue")).Return(&issueEntity, nil) @@ -555,7 +556,7 @@ var _ = Describe("When creating Issue", Label("app", "CreateIssue"), func() { It("should return Internal error", func() { // Mock GetCurrentUserId failure dbError := errors.New("user database connection failed") - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, dbError) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, dbError) issueHandler = issue.NewIssueHandler(handlerContext) result, err := issueHandler.CreateIssue(common.NewAdminContext(), &issueEntity) @@ -583,10 +584,10 @@ var _ = Describe("When creating Issue", Label("app", "CreateIssue"), func() { Context("when checking for existing issues fails", func() { It("should return Internal error", func() { // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock ListIssues failure listError := errors.New("database query failed") - db.On("GetIssues", mock.Anything, []entity.Order{}). + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}). Return([]entity.IssueResult{}, listError) issueHandler = issue.NewIssueHandler(handlerContext) @@ -609,12 +610,12 @@ var _ = Describe("When creating Issue", Label("app", "CreateIssue"), func() { Context("when issue with same primary name already exists", func() { It("should return AlreadyExists error", func() { // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock existing issue with same primary name existingIssue := test.NewFakeIssueEntity() existingIssue.Id = 999 existingIssue.PrimaryName = issueEntity.PrimaryName - db.On("GetIssues", mock.Anything, []entity.Order{}).Return([]entity.IssueResult{{ + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}).Return([]entity.IssueResult{{ Issue: &existingIssue, }}, nil) @@ -641,9 +642,9 @@ var _ = Describe("When creating Issue", Label("app", "CreateIssue"), func() { Context("when database creation fails", func() { It("should return Internal error", func() { // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock no existing issues - db.On("GetIssues", mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) // Mock database creation failure dbError := errors.New("constraint violation") db.On("CreateIssue", mock.AnythingOfType("*entity.Issue")). @@ -697,10 +698,10 @@ var _ = Describe("When updating Issue", Label("app", "UpdateIssue"), func() { Context("with valid input", func() { It("updates issueEntity successfully", func() { // Setup mocks for successful path - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("UpdateIssue", issueResult.Issue).Return(nil) filter.Id = []*int64{&issueResult.Issue.Id} - db.On("GetIssues", filter, []entity.Order{}). + db.On("GetIssues", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueResult{issueResult}, nil) issueHandler = issue.NewIssueHandler(handlerContext) @@ -721,7 +722,7 @@ var _ = Describe("When updating Issue", Label("app", "UpdateIssue"), func() { It("should return Internal error", func() { // Mock GetCurrentUserId failure dbError := errors.New("user database connection failed") - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, dbError) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, dbError) issueHandler = issue.NewIssueHandler(handlerContext) result, err := issueHandler.UpdateIssue(common.NewAdminContext(), issueResult.Issue) @@ -744,7 +745,7 @@ var _ = Describe("When updating Issue", Label("app", "UpdateIssue"), func() { Context("when database update fails", func() { It("should return Internal error", func() { // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock database update failure dbError := errors.New("constraint violation") db.On("UpdateIssue", issueResult.Issue).Return(dbError) @@ -770,11 +771,11 @@ var _ = Describe("When updating Issue", Label("app", "UpdateIssue"), func() { Context("when retrieving updated issue fails", func() { It("should return Internal error", func() { // Mock successful user ID and update - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("UpdateIssue", issueResult.Issue).Return(nil) // Mock ListIssues failure listError := errors.New("database query failed") - db.On("GetIssues", mock.Anything, []entity.Order{}). + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}). Return([]entity.IssueResult{}, listError) issueHandler = issue.NewIssueHandler(handlerContext) @@ -795,12 +796,12 @@ var _ = Describe("When updating Issue", Label("app", "UpdateIssue"), func() { }) It("updates issueEntity", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateIssue", issueResult.Issue).Return(nil) issueHandler = issue.NewIssueHandler(handlerContext) issueResult.Issue.Description = "New Description" filter.Id = []*int64{&issueResult.Issue.Id} - db.On("GetIssues", filter, []entity.Order{}).Return([]entity.IssueResult{issueResult}, nil) + db.On("GetIssues", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueResult{issueResult}, nil) updatedIssue, err := issueHandler.UpdateIssue(common.NewAdminContext(), issueResult.Issue) Expect(err).To(BeNil(), "no error should be thrown") By("setting fields", func() { @@ -839,9 +840,9 @@ var _ = Describe("When deleting Issue", Label("app", "DeleteIssue"), func() { }) Context("with valid input", func() { It("deletes issue successfully", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("DeleteIssue", id, int64(123)).Return(nil) - db.On("GetIssues", mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) + db.On("GetIssues", mock.Anything, mock.Anything, []entity.Order{}).Return([]entity.IssueResult{}, nil) issueHandler = issue.NewIssueHandler(handlerContext) err := issueHandler.DeleteIssue(common.NewAdminContext(), id) @@ -852,7 +853,7 @@ var _ = Describe("When deleting Issue", Label("app", "DeleteIssue"), func() { lo := entity.IssueListOptions{ ListOptions: *entity.NewListOptions(), } - issues, err := issueHandler.ListIssues(filter, &lo) + issues, err := issueHandler.ListIssues(context.Background(), filter, &lo) Expect(err).To(BeNil(), "no error should be thrown") Expect(issues.Elements).To(BeEmpty(), "issue should be deleted") }) @@ -862,7 +863,7 @@ var _ = Describe("When deleting Issue", Label("app", "DeleteIssue"), func() { It("should return Internal error", func() { // Mock GetCurrentUserId failure dbError := errors.New("user database connection failed") - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, dbError) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, dbError) issueHandler = issue.NewIssueHandler(handlerContext) err := issueHandler.DeleteIssue(common.NewAdminContext(), id) @@ -882,7 +883,7 @@ var _ = Describe("When deleting Issue", Label("app", "DeleteIssue"), func() { Context("when database delete fails", func() { It("should return Internal error", func() { // Mock successful user ID retrieval - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) // Mock database delete failure dbError := errors.New("foreign key constraint violation") db.On("DeleteIssue", id, int64(123)).Return(dbError) @@ -929,10 +930,11 @@ var _ = Describe( It("adds componentVersion to issueEntity", func() { db.On("AddComponentVersionToIssue", issueResult.Issue.Id, componentVersion.Id). Return(nil) - db.On("GetIssues", mock.Anything, mock.Anything). + db.On("GetIssues", mock.Anything, mock.Anything, mock.Anything). Return([]entity.IssueResult{issueResult}, nil) issueHandler = issue.NewIssueHandler(handlerContext) issue, err := issueHandler.AddComponentVersionToIssue( + context.Background(), issueResult.Issue.Id, componentVersion.Id, ) @@ -943,10 +945,11 @@ var _ = Describe( It("removes componentVersion from issueEntity", func() { db.On("RemoveComponentVersionFromIssue", issueResult.Issue.Id, componentVersion.Id). Return(nil) - db.On("GetIssues", mock.Anything, mock.Anything). + db.On("GetIssues", mock.Anything, mock.Anything, mock.Anything). Return([]entity.IssueResult{issueResult}, nil) issueHandler = issue.NewIssueHandler(handlerContext) issue, err := issueHandler.RemoveComponentVersionFromIssue( + context.Background(), issueResult.Issue.Id, componentVersion.Id, ) @@ -985,10 +988,10 @@ var _ = Describe( Medium: 50, Low: 15, } - db.On("CountIssueRatings", filter).Return(expectedCounts, nil) + db.On("CountIssueRatings", mock.Anything, filter).Return(expectedCounts, nil) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.GetIssueSeverityCounts(filter) + result, err := issueHandler.GetIssueSeverityCounts(context.Background(), filter) Expect(err).To(BeNil(), "no error should be thrown") Expect(result).ToNot(BeNil(), "result should be returned") @@ -1006,10 +1009,10 @@ var _ = Describe( Medium: 0, Low: 0, } - db.On("CountIssueRatings", filter).Return(expectedCounts, nil) + db.On("CountIssueRatings", mock.Anything, filter).Return(expectedCounts, nil) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.GetIssueSeverityCounts(filter) + result, err := issueHandler.GetIssueSeverityCounts(context.Background(), filter) Expect(err).To(BeNil(), "no error should be thrown") Expect(result).ToNot(BeNil(), "result should be returned") @@ -1024,11 +1027,11 @@ var _ = Describe( It("should return Internal error", func() { // Mock database error dbError := errors.New("database aggregation failed") - db.On("CountIssueRatings", filter). + db.On("CountIssueRatings", mock.Anything, filter). Return((*entity.IssueSeverityCounts)(nil), dbError) issueHandler = issue.NewIssueHandler(handlerContext) - result, err := issueHandler.GetIssueSeverityCounts(filter) + result, err := issueHandler.GetIssueSeverityCounts(context.Background(), filter) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") diff --git a/internal/app/issue_match/issue_match_handler.go b/internal/app/issue_match/issue_match_handler.go index d4ebe7079..4f298c89a 100644 --- a/internal/app/issue_match/issue_match_handler.go +++ b/internal/app/issue_match/issue_match_handler.go @@ -208,7 +208,7 @@ func (im *issueMatchHandler) ListIssueMatches( im.cache, CacheTtlGetIssueMatches, "GetIssueMatches", - im.database.GetIssueMatches, + cache.WrapContext2(ctx, im.database.GetIssueMatches), filter, options.Order, ) @@ -227,7 +227,7 @@ func (im *issueMatchHandler) ListIssueMatches( im.cache, CacheTtlGetAllIssueMatchCursors, "GetAllIssueMatchCursors", - im.database.GetAllIssueMatchCursors, + cache.WrapContext2(ctx, im.database.GetAllIssueMatchCursors), filter, options.Order, ) @@ -248,7 +248,7 @@ func (im *issueMatchHandler) ListIssueMatches( im.cache, CacheTtlCountIssueMatches, "CountIssueMatches", - im.database.CountIssueMatches, + cache.WrapContext1(ctx, im.database.CountIssueMatches), filter, ) if err != nil { @@ -303,7 +303,7 @@ func (im *issueMatchHandler) CreateIssueMatch( } //@todo discuss: may be moved to somewhere else? - effectiveSeverity, err := im.severityHandler.GetSeverity(severityFilter) + effectiveSeverity, err := im.severityHandler.GetSeverity(ctx, severityFilter) if err != nil { l.Error(err) return nil, NewIssueMatchHandlerError("Internal error while retrieving effective severity.") diff --git a/internal/app/issue_match/issue_match_handler_events.go b/internal/app/issue_match/issue_match_handler_events.go index d5e6661e7..be44114de 100644 --- a/internal/app/issue_match/issue_match_handler_events.go +++ b/internal/app/issue_match/issue_match_handler_events.go @@ -4,6 +4,7 @@ package issue_match import ( + "context" "strconv" "time" @@ -104,6 +105,7 @@ func BuildIssueVariantMap( // Get Issue Variants issueVariants, err := db.GetServiceIssueVariants( + context.Background(), &entity.ServiceIssueVariantFilter{ComponentInstanceId: []*int64{&componentInstanceID}}, []entity.Order{}, ) @@ -165,7 +167,7 @@ func OnComponentVersionAssignmentToComponentInstance( l.WithField("event-step", "BuildIssueVariants"). Debug("Building map of IssueVariants for issues related to assigned Component Version") - issueVariantMap, err := shared.BuildIssueVariantMap(db, &entity.ServiceIssueVariantFilter{ + issueVariantMap, err := shared.BuildIssueVariantMap(context.Background(), db, &entity.ServiceIssueVariantFilter{ ComponentInstanceId: []*int64{&componentInstanceID}, }, componentVersionID) @@ -187,7 +189,7 @@ func OnComponentVersionAssignmentToComponentInstance( l.WithField("event-step", "FetchIssueMatches"). Debug("Fetching issue matches related to assigned Component Instance") - issue_matches, err := db.GetIssueMatches(&entity.IssueMatchFilter{ + issue_matches, err := db.GetIssueMatches(context.Background(), &entity.IssueMatchFilter{ IssueId: []*int64{&issueId}, ComponentInstanceId: []*int64{&componentInstanceID}, }, nil) diff --git a/internal/app/issue_match/issue_match_handler_test.go b/internal/app/issue_match/issue_match_handler_test.go index 26c24f84d..6d1fc2729 100644 --- a/internal/app/issue_match/issue_match_handler_test.go +++ b/internal/app/issue_match/issue_match_handler_test.go @@ -90,10 +90,10 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{}, nil) - db.On("CountIssueMatches", filter).Return(int64(1337), nil) + db.On("CountIssueMatches", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -138,9 +138,9 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), c, _ := mariadb.EncodeCursor(mariadb.WithIssueMatch([]entity.Order{}, im)) cursors = append(cursors, c) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}).Return(matches, nil) - db.On("GetAllIssueMatchCursors", filter, []entity.Order{}).Return(cursors, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}).Return(matches, nil) + db.On("GetAllIssueMatchCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) issueMatchHandler = im.NewIssueMatchHandler(handlerContext, nil) res, err := issueMatchHandler.ListIssueMatches(ctx, filter, options) Expect(err).To(BeNil(), "no error should be thrown") @@ -171,8 +171,8 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), Context("and the given filter does not have any matches in the database", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{}, nil) }) It("should return an empty result", func() { @@ -184,7 +184,7 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), }) Context("and the filter does have results in the database", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) issueMatches := []entity.IssueMatchResult{} for _, im := range test.NNewFakeIssueMatches(15) { issueMatches = append( @@ -192,7 +192,7 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), entity.IssueMatchResult{IssueMatch: new(im)}, ) } - db.On("GetIssueMatches", filter, []entity.Order{}).Return(issueMatches, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}).Return(issueMatches, nil) }) It("should return the expected matches in the result", func() { issueMatchHandler = im.NewIssueMatchHandler(handlerContext, nil) @@ -204,8 +204,8 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), Context("and the database operations throw an error", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{}, errors.New("some error")) }) @@ -239,8 +239,8 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), BeforeEach(func() { componentInstanceIds := int64(-1) filter.ComponentInstanceId = []*int64{&componentInstanceIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{}, nil) }) @@ -265,8 +265,8 @@ var _ = Describe("When listing IssueMatches", Label("app", "ListIssueMatches"), systemUserId := int64(1) filter.ServiceId = []*int64{&serviceId} issueMatch = test.NewFakeIssueMatch() - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{{IssueMatch: &issueMatch}}, nil) relations := []openfga.RelationInput{ @@ -392,9 +392,9 @@ var _ = Describe("When creating IssueMatch", Label("app", "CreateIssueMatch"), f irFilter.Id = []*int64{&repositories[0].Id} ivFilter.IssueId = []*int64{&issueMatch.IssueId} issueMatch.Severity = issueVariants[0].Severity - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateIssueMatch", &issueMatch).Return(&issueMatch, nil) - db.On("GetIssueVariants", ivFilter, mock.Anything).Return([]entity.IssueVariantResult{ + db.On("GetIssueVariants", mock.Anything, ivFilter, mock.Anything).Return([]entity.IssueVariantResult{ { IssueVariant: &issueVariants[0], }, @@ -408,7 +408,7 @@ var _ = Describe("When creating IssueMatch", Label("app", "CreateIssueMatch"), f }) } - db.On("GetIssueRepositories", irFilter, mock.Anything).Return(irResults, nil) + db.On("GetIssueRepositories", mock.Anything, irFilter, mock.Anything).Return(irResults, nil) issueMatchHandler = im.NewIssueMatchHandler(handlerContext, ss) newIssueMatch, err := issueMatchHandler.CreateIssueMatch( common.NewAdminContext(), @@ -511,7 +511,7 @@ var _ = Describe("When updating IssueMatch", Label("app", "UpdateIssueMatch"), f }) It("updates issueMatch", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateIssueMatch", issueMatch.IssueMatch).Return(nil) issueMatchHandler = im.NewIssueMatchHandler(handlerContext, nil) if issueMatch.Status == entity.NewIssueMatchStatusValue("new") { @@ -520,7 +520,7 @@ var _ = Describe("When updating IssueMatch", Label("app", "UpdateIssueMatch"), f issueMatch.Status = entity.NewIssueMatchStatusValue("new") } filter.Id = []*int64{&issueMatch.Id} - db.On("GetIssueMatches", filter, []entity.Order{}). + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}). Return([]entity.IssueMatchResult{issueMatch}, nil) updatedIssueMatch, err := issueMatchHandler.UpdateIssueMatch( common.NewAdminContext(), @@ -642,10 +642,10 @@ var _ = Describe("When deleting IssueMatch", Label("app", "DeleteIssueMatch"), f }) It("deletes issueMatch", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteIssueMatch", id, mock.Anything).Return(nil) issueMatchHandler = im.NewIssueMatchHandler(handlerContext, nil) - db.On("GetIssueMatches", filter, []entity.Order{}).Return([]entity.IssueMatchResult{}, nil) + db.On("GetIssueMatches", mock.Anything, filter, []entity.Order{}).Return([]entity.IssueMatchResult{}, nil) err := issueMatchHandler.DeleteIssueMatch(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") @@ -781,7 +781,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC service.Id = 1 // Mocks - db.On("GetServiceIssueVariants", &entity.ServiceIssueVariantFilter{ + db.On("GetServiceIssueVariants", mock.Anything, &entity.ServiceIssueVariantFilter{ ComponentInstanceId: []*int64{new(int64(1))}, }, mock.Anything).Return([]entity.ServiceIssueVariantResult{}, nil) }) @@ -807,7 +807,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC } // Mocks - db.On("GetServiceIssueVariants", mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { + db.On("GetServiceIssueVariants", mock.Anything, mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { // Check that IssueId and IssueRepositoryId are not nil, but don't care about // their contents return filter.ComponentInstanceId != nil @@ -839,7 +839,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC } // Mocks - db.On("GetServiceIssueVariants", mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { + db.On("GetServiceIssueVariants", mock.Anything, mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { // Check that IssueId and IssueRepositoryId are not nil, but don't care about // their contents return filter.ComponentInstanceId != nil @@ -869,7 +869,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC } // Mocks - db.On("GetServiceIssueVariants", mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { + db.On("GetServiceIssueVariants", mock.Anything, mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { // Check that IssueId and IssueRepositoryId are not nil, but don't care about // their contents return filter.ComponentInstanceId != nil @@ -907,7 +907,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC } // Mocks - db.On("GetServiceIssueVariants", mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { + db.On("GetServiceIssueVariants", mock.Anything, mock.MatchedBy(func(filter *entity.ServiceIssueVariantFilter) bool { // Check that IssueId and IssueRepositoryId are not nil, but don't care about // their contents return filter.ComponentInstanceId != nil @@ -916,7 +916,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC }) It("should create issue matches for each issue", func() { - db.On("GetIssueMatches", mock.Anything, mock.Anything). + db.On("GetIssueMatches", mock.Anything, mock.Anything, mock.Anything). Return([]entity.IssueMatchResult{}, nil) // Mock CreateIssueMatch db.On("CreateIssueMatch", mock.AnythingOfType("*entity.IssueMatch")). @@ -938,7 +938,7 @@ var _ = Describe("OnComponentInstanceCreate", Label("app", "OnComponentInstanceC issueMatch := test.NewFakeIssueMatchResult() issueMatch.IssueId = 2 // issue2.Id // when issueid is 2 return a fake issue match - db.On("GetIssueMatches", mock.Anything, mock.Anything). + db.On("GetIssueMatches", mock.Anything, mock.Anything, mock.Anything). Return([]entity.IssueMatchResult{issueMatch}, nil). Once() }) diff --git a/internal/app/issue_repository/issue_repository_handler.go b/internal/app/issue_repository/issue_repository_handler.go index e23feb7f3..adcc66fee 100644 --- a/internal/app/issue_repository/issue_repository_handler.go +++ b/internal/app/issue_repository/issue_repository_handler.go @@ -48,6 +48,7 @@ func (e *IssueRepositoryHandlerError) Error() string { } func (ir *issueRepositoryHandler) getIssueRepositoryResults( + ctx context.Context, filter *entity.IssueRepositoryFilter, ) ([]entity.IssueRepositoryResult, error) { var irResults []entity.IssueRepositoryResult @@ -56,7 +57,7 @@ func (ir *issueRepositoryHandler) getIssueRepositoryResults( ir.cache, CacheTtlGetIssueRepository, "GetIssueRepositories", - ir.database.GetIssueRepositories, + cache.WrapContext2(ctx, ir.database.GetIssueRepositories), filter, []entity.Order{}, ) @@ -68,6 +69,7 @@ func (ir *issueRepositoryHandler) getIssueRepositoryResults( } func (ir *issueRepositoryHandler) ListIssueRepositories( + ctx context.Context, filter *entity.IssueRepositoryFilter, options *entity.ListOptions, ) (*entity.List[entity.IssueRepositoryResult], error) { @@ -83,7 +85,7 @@ func (ir *issueRepositoryHandler) ListIssueRepositories( "filter": filter, }) - res, err := ir.getIssueRepositoryResults(filter) + res, err := ir.getIssueRepositoryResults(ctx, filter) if err != nil { return nil, NewIssueRepositoryHandlerError("Error while filtering for IssueRepositories") } @@ -94,7 +96,7 @@ func (ir *issueRepositoryHandler) ListIssueRepositories( ir.cache, CacheTtlGetAllIssueRepositoryCursors, "GetAllIssueRepositoryCursors", - ir.database.GetAllIssueRepositoryCursors, + cache.WrapContext2(ctx, ir.database.GetAllIssueRepositoryCursors), filter, options.Order, ) @@ -110,7 +112,7 @@ func (ir *issueRepositoryHandler) ListIssueRepositories( count = int64(len(cursors)) } } else if options.ShowTotalCount { - count, err = ir.database.CountIssueRepositories(filter) + count, err = ir.database.CountIssueRepositories(ctx, filter) if err != nil { l.Error(err) @@ -160,7 +162,7 @@ func (ir *issueRepositoryHandler) CreateIssueRepository( issueRepository.BaseIssueRepository.UpdatedBy = issueRepository.BaseIssueRepository.CreatedBy - issueRepositories, err := ir.ListIssueRepositories(f, &entity.ListOptions{}) + issueRepositories, err := ir.ListIssueRepositories(ctx, f, &entity.ListOptions{}) if err != nil { l.Error(err) return nil, NewIssueRepositoryHandlerError("Internal error while creating issueRepository.") @@ -212,6 +214,7 @@ func (ir *issueRepositoryHandler) UpdateIssueRepository( } issueRepositoryResult, err := ir.ListIssueRepositories( + ctx, &entity.IssueRepositoryFilter{Id: []*int64{&issueRepository.Id}}, &entity.ListOptions{}, ) diff --git a/internal/app/issue_repository/issue_repository_handler_events.go b/internal/app/issue_repository/issue_repository_handler_events.go index 78527eead..4fdbabda7 100644 --- a/internal/app/issue_repository/issue_repository_handler_events.go +++ b/internal/app/issue_repository/issue_repository_handler_events.go @@ -4,6 +4,8 @@ package issue_repository import ( + "context" + "github.com/cloudoperators/heureka/internal/app/event" "github.com/cloudoperators/heureka/internal/database" "github.com/cloudoperators/heureka/internal/entity" @@ -72,7 +74,7 @@ func OnIssueRepositoryCreate(db database.Database, e event.Event, authz openfga. l.WithField("event-step", "GetIssueRepository").Debug("Fetching Issue Repository by name") // Fetch services - services, err := db.GetServices(&entity.ServiceFilter{}, []entity.Order{}) + services, err := db.GetServices(context.Background(), &entity.ServiceFilter{}, []entity.Order{}) if err != nil { l.WithField("event-step", "GetServices"). WithError(err). diff --git a/internal/app/issue_repository/issue_repository_handler_interface.go b/internal/app/issue_repository/issue_repository_handler_interface.go index 938d4847c..dd8c2466e 100644 --- a/internal/app/issue_repository/issue_repository_handler_interface.go +++ b/internal/app/issue_repository/issue_repository_handler_interface.go @@ -11,6 +11,7 @@ import ( type IssueRepositoryHandler interface { ListIssueRepositories( + context.Context, *entity.IssueRepositoryFilter, *entity.ListOptions, ) (*entity.List[entity.IssueRepositoryResult], error) diff --git a/internal/app/issue_repository/issue_repository_handler_test.go b/internal/app/issue_repository/issue_repository_handler_test.go index c89d48bea..23519a5cf 100644 --- a/internal/app/issue_repository/issue_repository_handler_test.go +++ b/internal/app/issue_repository/issue_repository_handler_test.go @@ -4,6 +4,7 @@ package issue_repository_test import ( + "context" "math" "testing" @@ -74,14 +75,14 @@ var _ = Describe("When listing IssueRepositories", Label("app", "ListIssueReposi When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetIssueRepositories", filter, mock.Anything). + db.On("GetIssueRepositories", mock.Anything, filter, mock.Anything). Return([]entity.IssueRepositoryResult{}, nil) - db.On("CountIssueRepositories", filter).Return(int64(1337), nil) + db.On("CountIssueRepositories", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { issueRepositoryHandler = ir.NewIssueRepositoryHandler(handlerContext) - res, err := issueRepositoryHandler.ListIssueRepositories(filter, options) + res, err := issueRepositoryHandler.ListIssueRepositories(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(*res.TotalCount).Should(BeEquivalentTo(int64(1337)), "return correct Totalcount") }) @@ -124,10 +125,10 @@ var _ = Describe("When listing IssueRepositories", Label("app", "ListIssueReposi cursors = append(cursors, c) } - db.On("GetIssueRepositories", filter, mock.Anything).Return(irResults, nil) - db.On("GetAllIssueRepositoryCursors", filter, mock.Anything).Return(cursors, nil) + db.On("GetIssueRepositories", mock.Anything, filter, mock.Anything).Return(irResults, nil) + db.On("GetAllIssueRepositoryCursors", mock.Anything, filter, mock.Anything).Return(cursors, nil) issueRepositoryHandler = ir.NewIssueRepositoryHandler(handlerContext) - res, err := issueRepositoryHandler.ListIssueRepositories(filter, options) + res, err := issueRepositoryHandler.ListIssueRepositories(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect( *res.PageInfo.HasNextPage, @@ -184,9 +185,9 @@ var _ = Describe("When creating IssueRepository", Label("app", "CreateIssueRepos It("creates issueRepository", func() { filter.Name = []*string{&issueRepository.Name} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateIssueRepository", &issueRepository).Return(&issueRepository, nil) - db.On("GetIssueRepositories", filter, mock.Anything). + db.On("GetIssueRepositories", mock.Anything, filter, mock.Anything). Return([]entity.IssueRepositoryResult{}, nil) issueRepositoryHandler = ir.NewIssueRepositoryHandler(handlerContext) newIssueRepository, err := issueRepositoryHandler.CreateIssueRepository( @@ -210,7 +211,7 @@ var _ = Describe("When creating IssueRepository", Label("app", "CreateIssueRepos issueRepository.Id = int64(1) services := []entity.ServiceResult{service1, service2} - db.On("GetServices", &entity.ServiceFilter{}, []entity.Order{}).Return(services, nil) + db.On("GetServices", mock.Anything, &entity.ServiceFilter{}, []entity.Order{}).Return(services, nil) db.On("AddIssueRepositoryToService", int64(1), int64(1), int64(100)).Return(nil) db.On("AddIssueRepositoryToService", int64(2), int64(1), int64(100)).Return(nil) db.On("GetDefaultIssuePriority").Return(int64(100)) @@ -265,12 +266,12 @@ var _ = Describe("When updating IssueRepository", Label("app", "UpdateIssueRepos }) It("updates issueRepository", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateIssueRepository", &issueRepository).Return(nil) issueRepositoryHandler = ir.NewIssueRepositoryHandler(handlerContext) issueRepository.Name = "SecretRepository" filter.Id = []*int64{&issueRepository.Id} - db.On("GetIssueRepositories", filter, mock.Anything).Return([]entity.IssueRepositoryResult{{ + db.On("GetIssueRepositories", mock.Anything, filter, mock.Anything, mock.Anything).Return([]entity.IssueRepositoryResult{{ IssueRepository: &issueRepository, }}, nil) updatedIssueRepository, err := issueRepositoryHandler.UpdateIssueRepository( @@ -312,16 +313,17 @@ var _ = Describe("When deleting IssueRepository", Label("app", "DeleteIssueRepos }) It("deletes issueRepository", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteIssueRepository", id, mock.Anything).Return(nil) issueRepositoryHandler = ir.NewIssueRepositoryHandler(handlerContext) - db.On("GetIssueRepositories", filter, mock.Anything). + db.On("GetIssueRepositories", mock.Anything, filter, mock.Anything). Return([]entity.IssueRepositoryResult{}, nil) err := issueRepositoryHandler.DeleteIssueRepository(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") filter.Id = []*int64{&id} issueRepositories, err := issueRepositoryHandler.ListIssueRepositories( + context.Background(), filter, &entity.ListOptions{}, ) diff --git a/internal/app/issue_variant/issue_variant_handler.go b/internal/app/issue_variant/issue_variant_handler.go index 88e57e5c6..f55af1a37 100644 --- a/internal/app/issue_variant/issue_variant_handler.go +++ b/internal/app/issue_variant/issue_variant_handler.go @@ -56,6 +56,7 @@ func (e *IssueVariantHandlerError) Error() string { } func (iv *issueVariantHandler) getIssueVariantResults( + ctx context.Context, filter *entity.IssueVariantFilter, ) ([]entity.IssueVariantResult, error) { var ivResults []entity.IssueVariantResult @@ -64,7 +65,7 @@ func (iv *issueVariantHandler) getIssueVariantResults( iv.cache, CacheTtlGetIssueVariants, "GetIssueVariants", - iv.database.GetIssueVariants, + cache.WrapContext2(ctx, iv.database.GetIssueVariants), filter, []entity.Order{}, ) @@ -76,6 +77,7 @@ func (iv *issueVariantHandler) getIssueVariantResults( } func (iv *issueVariantHandler) ListIssueVariants( + ctx context.Context, filter *entity.IssueVariantFilter, options *entity.ListOptions, ) (*entity.List[entity.IssueVariantResult], error) { @@ -91,7 +93,7 @@ func (iv *issueVariantHandler) ListIssueVariants( "filter": filter, }) - res, err := iv.getIssueVariantResults(filter) + res, err := iv.getIssueVariantResults(ctx, filter) if err != nil { l.Error(err) return nil, NewIssueVariantHandlerError("Error while filtering for IssueVariants") @@ -103,7 +105,7 @@ func (iv *issueVariantHandler) ListIssueVariants( iv.cache, CacheTtlGetAllIssueVariantCursors, "GetAllIssueVariantCursors", - iv.database.GetAllIssueVariantCursors, + cache.WrapContext2(ctx, iv.database.GetAllIssueVariantCursors), filter, options.Order, ) @@ -121,7 +123,7 @@ func (iv *issueVariantHandler) ListIssueVariants( iv.cache, CacheTtlCountIssueVariants, "CountIssueVariants", - iv.database.CountIssueVariants, + cache.WrapContext1(ctx, iv.database.CountIssueVariants), filter, ) if err != nil { @@ -146,6 +148,7 @@ func (iv *issueVariantHandler) ListIssueVariants( } func (iv *issueVariantHandler) ListEffectiveIssueVariants( + ctx context.Context, filter *entity.IssueVariantFilter, options *entity.ListOptions, ) (*entity.List[entity.IssueVariantResult], error) { @@ -154,7 +157,7 @@ func (iv *issueVariantHandler) ListEffectiveIssueVariants( "filter": filter, }) - issueVariants, err := iv.ListIssueVariants(filter, options) + issueVariants, err := iv.ListIssueVariants(ctx, filter, options) if err != nil { l.Error(err) return nil, NewIssueVariantHandlerError("Internal error while returning issueVariants.") @@ -173,7 +176,7 @@ func (iv *issueVariantHandler) ListEffectiveIssueVariants( opts := entity.ListOptions{} - repositories, err := iv.repositoryService.ListIssueRepositories(&repositoryFilter, &opts) + repositories, err := iv.repositoryService.ListIssueRepositories(ctx, &repositoryFilter, &opts) if err != nil { l.Error(err) @@ -247,7 +250,7 @@ func (iv *issueVariantHandler) CreateIssueVariant( issueVariant.UpdatedBy = issueVariant.CreatedBy - issueVariants, err := iv.ListIssueVariants(f, &entity.ListOptions{}) + issueVariants, err := iv.ListIssueVariants(ctx, f, &entity.ListOptions{}) if err != nil { l.Error(err) return nil, NewIssueVariantHandlerError("Internal error while creating issueVariant.") @@ -299,6 +302,7 @@ func (iv *issueVariantHandler) UpdateIssueVariant( } ivResult, err := iv.ListIssueVariants( + ctx, &entity.IssueVariantFilter{Id: []*int64{&issueVariant.Id}}, &entity.ListOptions{}, ) diff --git a/internal/app/issue_variant/issue_variant_handler_interface.go b/internal/app/issue_variant/issue_variant_handler_interface.go index 649041d61..a2346f4e5 100644 --- a/internal/app/issue_variant/issue_variant_handler_interface.go +++ b/internal/app/issue_variant/issue_variant_handler_interface.go @@ -11,10 +11,12 @@ import ( type IssueVariantHandler interface { ListIssueVariants( + context.Context, *entity.IssueVariantFilter, *entity.ListOptions, ) (*entity.List[entity.IssueVariantResult], error) ListEffectiveIssueVariants( + context.Context, *entity.IssueVariantFilter, *entity.ListOptions, ) (*entity.List[entity.IssueVariantResult], error) diff --git a/internal/app/issue_variant/issue_variant_handler_test.go b/internal/app/issue_variant/issue_variant_handler_test.go index 9a3e228de..e7b41f93d 100644 --- a/internal/app/issue_variant/issue_variant_handler_test.go +++ b/internal/app/issue_variant/issue_variant_handler_test.go @@ -4,6 +4,7 @@ package issue_variant_test import ( + "context" "math" "testing" @@ -97,14 +98,14 @@ var _ = Describe("When listing IssueVariants", Label("app", "ListIssueVariants") When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetIssueVariants", filter, mock.Anything). + db.On("GetIssueVariants", mock.Anything, filter, mock.Anything). Return([]entity.IssueVariantResult{}, nil) - db.On("CountIssueVariants", filter).Return(int64(1337), nil) + db.On("CountIssueVariants", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) - res, err := issueVariantHandler.ListIssueVariants(filter, options) + res, err := issueVariantHandler.ListIssueVariants(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(*res.TotalCount).Should(BeEquivalentTo(int64(1337)), "return correct Totalcount") }) @@ -146,10 +147,10 @@ var _ = Describe("When listing IssueVariants", Label("app", "ListIssueVariants") cursors = append(cursors, c) } - db.On("GetIssueVariants", filter, mock.Anything).Return(ivResults, nil) - db.On("GetAllIssueVariantCursors", filter, mock.Anything).Return(cursors, nil) + db.On("GetIssueVariants", mock.Anything, filter, mock.Anything).Return(ivResults, nil) + db.On("GetAllIssueVariantCursors", mock.Anything, filter, mock.Anything).Return(cursors, nil) issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) - res, err := issueVariantHandler.ListIssueVariants(filter, options) + res, err := issueVariantHandler.ListIssueVariants(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect( *res.PageInfo.HasNextPage, @@ -237,12 +238,12 @@ var _ = Describe( }) } - db.On("GetIssueVariants", ivFilter, mock.Anything).Return(ivResults, nil) - db.On("GetIssueRepositories", irFilter, mock.Anything).Return(irResults, nil) + db.On("GetIssueVariants", mock.Anything, ivFilter, mock.Anything).Return(ivResults, nil) + db.On("GetIssueRepositories", mock.Anything, irFilter, mock.Anything).Return(irResults, nil) }) It("can list advisories", func() { issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) - res, err := issueVariantHandler.ListEffectiveIssueVariants(ivFilter, options) + res, err := issueVariantHandler.ListEffectiveIssueVariants(context.Background(), ivFilter, options) Expect(err).To(BeNil(), "no error should be thrown") for _, item := range res.Elements { Expect(item.IssueRepositoryId).To(BeEquivalentTo(repositories[1].Id)) @@ -278,12 +279,12 @@ var _ = Describe( }) } - db.On("GetIssueVariants", ivFilter, mock.Anything).Return(ivResults, nil) - db.On("GetIssueRepositories", irFilter, mock.Anything).Return(irResults, nil) + db.On("GetIssueVariants", mock.Anything, ivFilter, mock.Anything).Return(ivResults, nil) + db.On("GetIssueRepositories", mock.Anything, irFilter, mock.Anything).Return(irResults, nil) }) It("can list issueVariants", func() { issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) - res, err := issueVariantHandler.ListEffectiveIssueVariants(ivFilter, options) + res, err := issueVariantHandler.ListEffectiveIssueVariants(context.Background(), ivFilter, options) Expect(err).To(BeNil(), "no error should be thrown") ir_ids := lo.Map(res.Elements, func(item entity.IssueVariantResult, _ int) int64 { return item.IssueRepositoryId @@ -327,9 +328,9 @@ var _ = Describe("When creating IssueVariant", Label("app", "CreateIssueVariant" It("creates issueVariant", func() { filter.SecondaryName = []*string{&issueVariant.SecondaryName} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateIssueVariant", &issueVariant).Return(&issueVariant, nil) - db.On("GetIssueVariants", filter, mock.Anything).Return([]entity.IssueVariantResult{}, nil) + db.On("GetIssueVariants", mock.Anything, filter, mock.Anything).Return([]entity.IssueVariantResult{}, nil) issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) newIssueVariant, err := issueVariantHandler.CreateIssueVariant( common.NewAdminContext(), @@ -384,12 +385,12 @@ var _ = Describe("When updating IssueVariant", Label("app", "UpdateIssueVariant" }) It("updates issueVariant", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateIssueVariant", &issueVariant).Return(nil) issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) issueVariant.SecondaryName = "SecretAdvisory" filter.Id = []*int64{&issueVariant.Id} - db.On("GetIssueVariants", filter, mock.Anything).Return([]entity.IssueVariantResult{ + db.On("GetIssueVariants", mock.Anything, filter, mock.Anything).Return([]entity.IssueVariantResult{ { IssueVariant: &issueVariant, }, @@ -450,15 +451,15 @@ var _ = Describe("When deleting IssueVariant", Label("app", "DeleteIssueVariant" }) It("deletes issueVariant", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteIssueVariant", id, mock.Anything).Return(nil) issueVariantHandler = iv.NewIssueVariantHandler(handlerContext, rs) - db.On("GetIssueVariants", filter, mock.Anything).Return([]entity.IssueVariantResult{}, nil) + db.On("GetIssueVariants", mock.Anything, filter, mock.Anything).Return([]entity.IssueVariantResult{}, nil) err := issueVariantHandler.DeleteIssueVariant(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") filter.Id = []*int64{&id} - issueVariants, err := issueVariantHandler.ListIssueVariants(filter, &entity.ListOptions{}) + issueVariants, err := issueVariantHandler.ListIssueVariants(context.Background(), filter, &entity.ListOptions{}) Expect(err).To(BeNil(), "no error should be thrown") Expect(issueVariants.Elements).To(BeEmpty(), "no error should be thrown") }) diff --git a/internal/app/patch/patch_handler.go b/internal/app/patch/patch_handler.go index 684ff92de..98342ae21 100644 --- a/internal/app/patch/patch_handler.go +++ b/internal/app/patch/patch_handler.go @@ -4,6 +4,7 @@ package patch import ( + "context" "time" "github.com/cloudoperators/heureka/internal/app/common" @@ -39,6 +40,7 @@ func NewPatchHandler(handlerContext common.HandlerContext) PatchHandler { } func (ph *patchHandler) ListPatches( + ctx context.Context, filter *entity.PatchFilter, options *entity.ListOptions, ) (*entity.List[entity.PatchResult], error) { @@ -55,7 +57,7 @@ func (ph *patchHandler) ListPatches( ph.cache, CacheTtlGetPatches, "GetPatches", - ph.database.GetPatches, + cache.WrapContext2(ctx, ph.database.GetPatches), filter, options.Order, ) @@ -74,7 +76,7 @@ func (ph *patchHandler) ListPatches( ph.cache, CacheTtlGetAllPatchCursors, "GetAllPatchCursors", - ph.database.GetAllPatchCursors, + cache.WrapContext2(ctx, ph.database.GetAllPatchCursors), filter, options.Order, ) @@ -95,7 +97,7 @@ func (ph *patchHandler) ListPatches( ph.cache, CacheTtlCountPatches, "CountPatches", - ph.database.CountPatches, + cache.WrapContext1(ctx, ph.database.CountPatches), filter, ) if err != nil { diff --git a/internal/app/patch/patch_handler_interface.go b/internal/app/patch/patch_handler_interface.go index 405207330..7c4520be3 100644 --- a/internal/app/patch/patch_handler_interface.go +++ b/internal/app/patch/patch_handler_interface.go @@ -4,9 +4,11 @@ package patch import ( + "context" + "github.com/cloudoperators/heureka/internal/entity" ) type PatchHandler interface { - ListPatches(*entity.PatchFilter, *entity.ListOptions) (*entity.List[entity.PatchResult], error) + ListPatches(context.Context, *entity.PatchFilter, *entity.ListOptions) (*entity.List[entity.PatchResult], error) } diff --git a/internal/app/patch/patch_handler_test.go b/internal/app/patch/patch_handler_test.go index cbfc7e49c..9f2f9f98d 100644 --- a/internal/app/patch/patch_handler_test.go +++ b/internal/app/patch/patch_handler_test.go @@ -4,6 +4,7 @@ package patch_test import ( + "context" "errors" "math" "testing" @@ -20,6 +21,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/samber/lo" + "github.com/stretchr/testify/mock" ) var ( @@ -64,13 +66,13 @@ var _ = Describe("When listing Patches", Label("app", "ListPatches"), func() { When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetPatches", filter, []entity.Order{}).Return([]entity.PatchResult{}, nil) - db.On("CountPatches", filter).Return(int64(1337), nil) + db.On("GetPatches", mock.Anything, filter, []entity.Order{}).Return([]entity.PatchResult{}, nil) + db.On("CountPatches", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { patchHandler = ph.NewPatchHandler(handlerContext) - res, err := patchHandler.ListPatches(filter, options) + res, err := patchHandler.ListPatches(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(*res.TotalCount).Should(BeEquivalentTo(int64(1337)), "return correct Totalcount") }) @@ -107,10 +109,10 @@ var _ = Describe("When listing Patches", Label("app", "ListPatches"), func() { c, _ := mariadb.EncodeCursor(mariadb.WithPatch([]entity.Order{}, patch)) cursors = append(cursors, c) } - db.On("GetPatches", filter, []entity.Order{}).Return(patches, nil) - db.On("GetAllPatchCursors", filter, []entity.Order{}).Return(cursors, nil) + db.On("GetPatches", mock.Anything, filter, []entity.Order{}).Return(patches, nil) + db.On("GetAllPatchCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) patchHandler = ph.NewPatchHandler(handlerContext) - res, err := patchHandler.ListPatches(filter, options) + res, err := patchHandler.ListPatches(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect( *res.PageInfo.HasNextPage, @@ -141,10 +143,10 @@ var _ = Describe("When listing Patches", Label("app", "ListPatches"), func() { It("should return Internal error", func() { // Mock database error dbError := errors.New("database connection failed") - db.On("GetPatches", filter, []entity.Order{}).Return([]entity.PatchResult{}, dbError) + db.On("GetPatches", mock.Anything, filter, []entity.Order{}).Return([]entity.PatchResult{}, dbError) patchHandler = ph.NewPatchHandler(handlerContext) - result, err := patchHandler.ListPatches(filter, options) + result, err := patchHandler.ListPatches(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -177,12 +179,12 @@ var _ = Describe("When listing Patches", Label("app", "ListPatches"), func() { }) } - db.On("GetPatches", filter, []entity.Order{}).Return(patches, nil) + db.On("GetPatches", mock.Anything, filter, []entity.Order{}).Return(patches, nil) cursorsError := errors.New("cursor database error") - db.On("GetAllPatchCursors", filter, []entity.Order{}).Return([]string{}, cursorsError) + db.On("GetAllPatchCursors", mock.Anything, filter, []entity.Order{}).Return([]string{}, cursorsError) patchHandler = ph.NewPatchHandler(handlerContext) - result, err := patchHandler.ListPatches(filter, options) + result, err := patchHandler.ListPatches(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") diff --git a/internal/app/remediation/remediation_handler.go b/internal/app/remediation/remediation_handler.go index 98c79d4b5..63609c7b9 100644 --- a/internal/app/remediation/remediation_handler.go +++ b/internal/app/remediation/remediation_handler.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" "strconv" + "strings" "time" "github.com/cloudoperators/heureka/internal/app/common" @@ -43,6 +44,7 @@ func NewRemediationHandler(handlerContext common.HandlerContext) RemediationHand } func (rh *remediationHandler) ListRemediations( + ctx context.Context, filter *entity.RemediationFilter, options *entity.ListOptions, ) (*entity.List[entity.RemediationResult], error) { @@ -59,7 +61,7 @@ func (rh *remediationHandler) ListRemediations( rh.cache, CacheTtlGetRemediations, "GetRemediations", - rh.database.GetRemediations, + cache.WrapContext2(ctx, rh.database.GetRemediations), filter, options.Order, ) @@ -78,7 +80,7 @@ func (rh *remediationHandler) ListRemediations( rh.cache, CacheTtlGetAllRemediationCursors, "GetAllRemediationCursors", - rh.database.GetAllRemediationCursors, + cache.WrapContext2(ctx, rh.database.GetAllRemediationCursors), filter, options.Order, ) @@ -99,7 +101,7 @@ func (rh *remediationHandler) ListRemediations( rh.cache, CacheTtlCountRemediations, "CountRemediations", - rh.database.CountRemediations, + cache.WrapContext1(ctx, rh.database.CountRemediations), filter, ) if err != nil { @@ -223,7 +225,7 @@ func (rh *remediationHandler) CreateRemediation( State: []entity.StateFilterType{entity.Active}, } - existingRemediations, err := rh.database.GetRemediations(filter, nil) + existingRemediations, err := rh.database.GetRemediations(ctx, filter, nil) if err != nil { wrappedErr := appErrors.InternalError(string(op), "Remediation", "", err) applog.LogError(rh.logger, wrappedErr, logrus.Fields{ @@ -266,6 +268,21 @@ func (rh *remediationHandler) CreateRemediation( return nil, wrappedErr } + if rh.cache != nil { + if err := rh.cache.InvalidateByMatch(func(decodedKey string) bool { + return (strings.Contains(decodedKey, fmt.Sprintf("\"issue_id\":[%d]", newRemediation.IssueId)) || + strings.Contains(decodedKey, fmt.Sprintf("\"id\":[%d]", newRemediation.IssueId))) && + (strings.Contains(decodedKey, "GetIssuesWithAggregations") || strings.Contains(decodedKey, "GetIssues") || + strings.Contains(decodedKey, "GetAllIssueCursors") || strings.Contains(decodedKey, "GetIssueVariants") || + strings.Contains(decodedKey, "GetIssueMatches")) + }); err != nil { + wrappedErr := appErrors.InternalError(string(op), "Remediation", "", err) + applog.LogError(rh.logger, wrappedErr, logrus.Fields{ + "remediation": remediation, + }) + } + } + rh.eventRegistry.PushEvent(&CreateRemediationEvent{ Remediation: newRemediation, }) @@ -325,6 +342,7 @@ func (rh *remediationHandler) UpdateRemediation( lo := entity.NewListOptions() existingRemediations, err := rh.ListRemediations( + ctx, &entity.RemediationFilter{Id: []*int64{&remediation.Id}}, lo, ) @@ -414,6 +432,7 @@ func (rh *remediationHandler) UpdateRemediation( } remediationResult, err := rh.ListRemediations( + ctx, &entity.RemediationFilter{Id: []*int64{&remediation.Id}}, lo, ) diff --git a/internal/app/remediation/remediation_handler_interface.go b/internal/app/remediation/remediation_handler_interface.go index d9893bc7d..75b657f83 100644 --- a/internal/app/remediation/remediation_handler_interface.go +++ b/internal/app/remediation/remediation_handler_interface.go @@ -11,6 +11,7 @@ import ( type RemediationHandler interface { ListRemediations( + context.Context, *entity.RemediationFilter, *entity.ListOptions, ) (*entity.List[entity.RemediationResult], error) diff --git a/internal/app/remediation/remediation_handler_test.go b/internal/app/remediation/remediation_handler_test.go index f276805bc..35f7c562c 100644 --- a/internal/app/remediation/remediation_handler_test.go +++ b/internal/app/remediation/remediation_handler_test.go @@ -4,12 +4,15 @@ package remediation_test import ( + "context" "errors" "math" + "sync" "testing" "time" "github.com/cloudoperators/heureka/internal/app/common" + "github.com/cloudoperators/heureka/internal/cache" "github.com/stretchr/testify/mock" "github.com/cloudoperators/heureka/internal/app/event" @@ -68,14 +71,14 @@ var _ = Describe("When listing Remediations", Label("app", "ListRemediations"), When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetRemediations", filter, []entity.Order{}). + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}). Return([]entity.RemediationResult{}, nil) - db.On("CountRemediations", filter).Return(int64(1337), nil) + db.On("CountRemediations", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { remediationHandler = rh.NewRemediationHandler(handlerContext) - res, err := remediationHandler.ListRemediations(filter, options) + res, err := remediationHandler.ListRemediations(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(*res.TotalCount).Should(BeEquivalentTo(int64(1337)), "return correct Totalcount") }) @@ -119,10 +122,10 @@ var _ = Describe("When listing Remediations", Label("app", "ListRemediations"), ) cursors = append(cursors, c) } - db.On("GetRemediations", filter, []entity.Order{}).Return(remediations, nil) - db.On("GetAllRemediationCursors", filter, []entity.Order{}).Return(cursors, nil) + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}).Return(remediations, nil) + db.On("GetAllRemediationCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) remediationHandler = rh.NewRemediationHandler(handlerContext) - res, err := remediationHandler.ListRemediations(filter, options) + res, err := remediationHandler.ListRemediations(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect( *res.PageInfo.HasNextPage, @@ -154,11 +157,11 @@ var _ = Describe("When listing Remediations", Label("app", "ListRemediations"), It("should return Internal error", func() { // Mock database error dbError := errors.New("database connection failed") - db.On("GetRemediations", filter, []entity.Order{}). + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}). Return([]entity.RemediationResult{}, dbError) remediationHandler = rh.NewRemediationHandler(handlerContext) - result, err := remediationHandler.ListRemediations(filter, options) + result, err := remediationHandler.ListRemediations(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -196,13 +199,13 @@ var _ = Describe("When listing Remediations", Label("app", "ListRemediations"), }) } - db.On("GetRemediations", filter, []entity.Order{}).Return(remediations, nil) + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}).Return(remediations, nil) cursorsError := errors.New("cursor database error") - db.On("GetAllRemediationCursors", filter, []entity.Order{}). + db.On("GetAllRemediationCursors", mock.Anything, filter, []entity.Order{}). Return([]string{}, cursorsError) remediationHandler = rh.NewRemediationHandler(handlerContext) - result, err := remediationHandler.ListRemediations(filter, options) + result, err := remediationHandler.ListRemediations(context.Background(), filter, options) Expect(result).To(BeNil(), "no result should be returned") Expect(err).ToNot(BeNil(), "error should be returned") @@ -236,13 +239,14 @@ var _ = Describe("When creating Remediation", Label("app", "CreateRemediation"), DB: db, EventReg: er, Authz: authz, + Cache: cache.NewInMemoryCache(context.Background(), &sync.WaitGroup{}, cache.InMemoryCacheConfig{}), } }) Context("with valid input", func() { It("creates remediation", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) - db.On("GetRemediations", mock.Anything, mock.Anything). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) + db.On("GetRemediations", mock.Anything, mock.Anything, mock.Anything). Return([]entity.RemediationResult{}, nil) db.On("CreateRemediation", mock.AnythingOfType("*entity.Remediation")). Return(&remediation, nil) @@ -275,10 +279,10 @@ var _ = Describe("When creating Remediation", Label("app", "CreateRemediation"), Context("when a duplicate remediation exists", func() { It("returns AlreadyExists error if the existing one is not expired", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) existing := remediation existing.ExpirationDate = time.Now().Add(time.Hour) - db.On("GetRemediations", mock.Anything, mock.Anything). + db.On("GetRemediations", mock.Anything, mock.Anything, mock.Anything). Return([]entity.RemediationResult{{Remediation: &existing}}, nil) remediationHandler = rh.NewRemediationHandler(handlerContext) @@ -295,10 +299,10 @@ var _ = Describe("When creating Remediation", Label("app", "CreateRemediation"), }) It("creates remediation if the existing one is expired", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) existing := remediation existing.ExpirationDate = time.Now().Add(-time.Hour) - db.On("GetRemediations", mock.Anything, mock.Anything). + db.On("GetRemediations", mock.Anything, mock.Anything, mock.Anything). Return([]entity.RemediationResult{{Remediation: &existing}}, nil) db.On("CreateRemediation", mock.AnythingOfType("*entity.Remediation")). Return(&remediation, nil) @@ -343,7 +347,7 @@ var _ = Describe("When updating Remediation", Label("app", "UpdateRemediation"), }) Context("with valid input", func() { It("updates remediation", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("UpdateRemediation", remediation.Remediation).Return(nil) remediationHandler = rh.NewRemediationHandler(handlerContext) remediation.Description = "Updated description" @@ -351,7 +355,7 @@ var _ = Describe("When updating Remediation", Label("app", "UpdateRemediation"), remediation.Component = "Updated Component" remediation.Issue = "Updated Issue" filter.Id = []*int64{&remediation.Id} - db.On("GetRemediations", filter, []entity.Order{}). + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}). Return([]entity.RemediationResult{remediation}, nil) updatedRemediation, err := remediationHandler.UpdateRemediation( common.NewAdminContext(), @@ -397,17 +401,17 @@ var _ = Describe("When deleting Remediation", Label("app", "DeleteRemediation"), Context("with valid input", func() { It("deletes remediation", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{123}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{123}, nil) db.On("DeleteRemediation", id, int64(123)).Return(nil) remediationHandler = rh.NewRemediationHandler(handlerContext) - db.On("GetRemediations", filter, []entity.Order{}). + db.On("GetRemediations", mock.Anything, filter, []entity.Order{}). Return([]entity.RemediationResult{}, nil) err := remediationHandler.DeleteRemediation(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") filter.Id = []*int64{&id} lo := entity.NewListOptions() - remediations, err := remediationHandler.ListRemediations(filter, lo) + remediations, err := remediationHandler.ListRemediations(context.Background(), filter, lo) Expect(err).To(BeNil(), "no error should be thrown") Expect(remediations.Elements).To(BeEmpty(), "remediation should be deleted") }) diff --git a/internal/app/scanner_run/scanner_run_handler.go b/internal/app/scanner_run/scanner_run_handler.go index a4e17e29f..8b8602447 100644 --- a/internal/app/scanner_run/scanner_run_handler.go +++ b/internal/app/scanner_run/scanner_run_handler.go @@ -4,6 +4,7 @@ package scanner_run import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/app/common" @@ -46,7 +47,7 @@ func (srh *scannerRunHandler) CompleteScannerRun(uuid string) (bool, error) { } // Trigger autopatch whenever a scanner run has completed successfully - if _, err := srh.database.Autopatch(); err != nil { + if _, err := srh.database.Autopatch(context.Background()); err != nil { return false, &ScannerRunHandlerError{ msg: "Error executing autopatch in CompleteScannerRun", } diff --git a/internal/app/scanner_run/scanner_run_test.go b/internal/app/scanner_run/scanner_run_test.go index 6da1358d5..8e330e062 100644 --- a/internal/app/scanner_run/scanner_run_test.go +++ b/internal/app/scanner_run/scanner_run_test.go @@ -10,6 +10,7 @@ import ( "github.com/cloudoperators/heureka/internal/app/event" "github.com/cloudoperators/heureka/internal/entity/test" "github.com/cloudoperators/heureka/internal/openfga" + "github.com/stretchr/testify/mock" "github.com/cloudoperators/heureka/internal/entity" "github.com/cloudoperators/heureka/internal/mocks" @@ -61,7 +62,7 @@ var _ = Describe("ScannerRun", Label("app", "CreateScannerRun"), func() { It("creates a scannerrun and completes it", func() { db.On("CreateScannerRun", sre).Return(true, nil) db.On("CompleteScannerRun", sre.UUID).Return(true, nil) - db.On("Autopatch").Return(true, nil) + db.On("Autopatch", mock.Anything).Return(true, nil) scannerRunHandler = NewScannerRunHandler(handlerContext) scannerRunHandler.CreateScannerRun(sre) diff --git a/internal/app/service/service_handler.go b/internal/app/service/service_handler.go index cc6cd8521..495650fc8 100644 --- a/internal/app/service/service_handler.go +++ b/internal/app/service/service_handler.go @@ -179,7 +179,7 @@ func (s *serviceHandler) ListServices(ctx context.Context, s.cache, CacheTtlGetServicesWithAggregations, "GetServicesWithAggregations", - s.database.GetServicesWithAggregations, + cache.WrapContext2(ctx, s.database.GetServicesWithAggregations), filter, options.Order, ) @@ -196,7 +196,7 @@ func (s *serviceHandler) ListServices(ctx context.Context, s.cache, CacheTtlGetServices, "GetServices", - s.database.GetServices, + cache.WrapContext2(ctx, s.database.GetServices), filter, options.Order, ) @@ -216,7 +216,7 @@ func (s *serviceHandler) ListServices(ctx context.Context, s.cache, CacheTtlGetAllSericeCursors, "GetAllServiceCursors", - s.database.GetAllServiceCursors, + cache.WrapContext2(ctx, s.database.GetAllServiceCursors), filter, options.Order, ) @@ -237,7 +237,7 @@ func (s *serviceHandler) ListServices(ctx context.Context, s.cache, CacheTtlCountServices, "CountServices", - s.database.CountServices, + cache.WrapContext1(ctx, s.database.CountServices), filter, ) if err != nil { @@ -457,6 +457,7 @@ func (s *serviceHandler) RemoveIssueRepositoryFromService( } func (s *serviceHandler) ListServiceCcrns( + ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions, ) ([]string, error) { @@ -469,7 +470,7 @@ func (s *serviceHandler) ListServiceCcrns( s.cache, CacheTtlGetServiceAttrs, "GetServiceCcrns", - s.database.GetServiceCcrns, + cache.WrapContext1(ctx, s.database.GetServiceCcrns), filter, ) if err != nil { @@ -485,6 +486,7 @@ func (s *serviceHandler) ListServiceCcrns( } func (s *serviceHandler) ListServiceDomains( + ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions, ) ([]string, error) { @@ -497,7 +499,7 @@ func (s *serviceHandler) ListServiceDomains( s.cache, CacheTtlGetServiceAttrs, "GetServiceDomains", - s.database.GetServiceDomains, + cache.WrapContext1(ctx, s.database.GetServiceDomains), filter, ) if err != nil { @@ -513,6 +515,7 @@ func (s *serviceHandler) ListServiceDomains( } func (s *serviceHandler) ListServiceRegions( + ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions, ) ([]string, error) { @@ -521,8 +524,13 @@ func (s *serviceHandler) ListServiceRegions( "filter": filter, }) - serviceRegions, err := cache.CallCached[[]string](s.cache, CacheTtlGetServiceAttrs, - "GetServiceRegions", s.database.GetServiceRegions, filter) + serviceRegions, err := cache.CallCached[[]string]( + s.cache, + CacheTtlGetServiceAttrs, + "GetServiceRegions", + cache.WrapContext1(ctx, s.database.GetServiceRegions), + filter, + ) if err != nil { l.Error(err) return nil, NewServiceHandlerError("Internal error while retrieving serviceRegions.") diff --git a/internal/app/service/service_handler_events.go b/internal/app/service/service_handler_events.go index 48bf927cd..28aae0079 100644 --- a/internal/app/service/service_handler_events.go +++ b/internal/app/service/service_handler_events.go @@ -4,6 +4,8 @@ package service import ( + "context" + "github.com/cloudoperators/heureka/internal/app/event" "github.com/cloudoperators/heureka/internal/database" "github.com/cloudoperators/heureka/internal/entity" @@ -157,7 +159,7 @@ func OnServiceCreate(db database.Database, e event.Event, authz openfga.Authoriz serviceId := createEvent.Service.Id // Fetch IssueRepositories - issueRepositories, err := db.GetIssueRepositories(&entity.IssueRepositoryFilter{ + issueRepositories, err := db.GetIssueRepositories(context.Background(), &entity.IssueRepositoryFilter{ Name: []*string{&defaultRepoName}, }, []entity.Order{}) if err != nil { diff --git a/internal/app/service/service_handler_interface.go b/internal/app/service/service_handler_interface.go index 0249aca64..141cd723f 100644 --- a/internal/app/service/service_handler_interface.go +++ b/internal/app/service/service_handler_interface.go @@ -21,9 +21,9 @@ type ServiceHandler interface { DeleteService(ctx context.Context, id int64) error AddOwnerToService(ctx context.Context, serviceId, ownerId int64) (*entity.Service, error) RemoveOwnerFromService(ctx context.Context, serviceId, ownerId int64) (*entity.Service, error) - ListServiceCcrns(filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) - ListServiceDomains(filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) - ListServiceRegions(filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) + ListServiceCcrns(ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) + ListServiceDomains(ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) + ListServiceRegions(ctx context.Context, filter *entity.ServiceFilter, options *entity.ListOptions) ([]string, error) AddIssueRepositoryToService(context.Context, int64, int64, int64) (*entity.Service, error) RemoveIssueRepositoryFromService(context.Context, int64, int64) (*entity.Service, error) } diff --git a/internal/app/service/service_handler_test.go b/internal/app/service/service_handler_test.go index d9fad36a7..fa7f50672 100644 --- a/internal/app/service/service_handler_test.go +++ b/internal/app/service/service_handler_test.go @@ -88,9 +88,9 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) - db.On("CountServices", filter).Return(int64(1337), nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) + db.On("CountServices", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -147,9 +147,9 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { ) cursors = append(cursors, c) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return(services, nil) - db.On("GetAllServiceCursors", filter, []entity.Order{}).Return(cursors, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return(services, nil) + db.On("GetAllServiceCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) serviceHandler = s.NewServiceHandler(handlerContext) res, err := serviceHandler.ListServices(ctx, filter, options) Expect(err).To(BeNil(), "no error should be thrown") @@ -184,8 +184,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { }) Context("and the given filter does not have any matches in the database", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServicesWithAggregations", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServicesWithAggregations", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{}, nil) }) @@ -202,8 +202,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { for _, s := range test.NNewFakeServiceEntitiesWithAggregations(10) { services = append(services, entity.ServiceResult{Service: &s.Service}) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServicesWithAggregations", filter, []entity.Order{}).Return(services, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServicesWithAggregations", mock.Anything, filter, []entity.Order{}).Return(services, nil) }) It("should return the expected services in the result", func() { serviceHandler = s.NewServiceHandler(handlerContext) @@ -214,8 +214,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { }) Context("and the database operations throw an error", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServicesWithAggregations", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServicesWithAggregations", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{}, errors.New("some error")) }) @@ -236,8 +236,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { Context("and the given filter does not have any matches in the database", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) }) It("should return an empty result", func() { serviceHandler = s.NewServiceHandler(handlerContext) @@ -252,8 +252,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { for _, s := range test.NNewFakeServiceEntitiesWithAggregations(15) { services = append(services, entity.ServiceResult{Service: &s.Service}) } - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return(services, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return(services, nil) }) It("should return the expected services in the result", func() { serviceHandler = s.NewServiceHandler(handlerContext) @@ -265,8 +265,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { Context("and the database operations throw an error", func() { BeforeEach(func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{}, errors.New("some error")) }) @@ -300,8 +300,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { BeforeEach(func() { sgIds := int64(-1) filter.SupportGroupId = []*int64{&sgIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) }) It("should return no services", func() { @@ -323,8 +323,8 @@ var _ = Describe("When listing Services", Label("app", "ListServices"), func() { systemUserId := int64(1) filter.SupportGroupId = []*int64{&sgId} service = test.NewFakeServiceEntity() - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{{Service: &service}}, nil) relations := []openfga.RelationInput{ @@ -407,9 +407,9 @@ var _ = Describe("When creating Service", Label("app", "CreateService"), func() It("creates service", func() { filter.CCRN = []*string{&service.CCRN} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateService", &service).Return(&service, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) serviceHandler = s.NewServiceHandler(handlerContext) newService, err := serviceHandler.CreateService(common.NewAdminContext(), &service) @@ -442,7 +442,7 @@ var _ = Describe("When creating Service", Label("app", "CreateService"), func() repo.Id = 456 repo.Name = defaultRepoName - db.On("GetIssueRepositories", &entity.IssueRepositoryFilter{ + db.On("GetIssueRepositories", mock.Anything, &entity.IssueRepositoryFilter{ Name: []*string{&defaultRepoName}, }, mock.Anything).Return([]entity.IssueRepositoryResult{ { @@ -492,7 +492,7 @@ var _ = Describe("When creating Service", Label("app", "CreateService"), func() var event event.Event = createEvent defaultRepoName := "nvd" - db.On("GetIssueRepositories", &entity.IssueRepositoryFilter{ + db.On("GetIssueRepositories", mock.Anything, &entity.IssueRepositoryFilter{ Name: []*string{&defaultRepoName}, }, mock.Anything).Return([]entity.IssueRepositoryResult{}, nil) @@ -576,12 +576,12 @@ var _ = Describe("When updating Service", Label("app", "UpdateService"), func() }) It("updates service", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateService", service.Service).Return(nil) serviceHandler = s.NewServiceHandler(handlerContext) service.CCRN = "SecretService" filter.Id = []*int64{&service.Id} - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) updatedService, err := serviceHandler.UpdateService( common.NewAdminContext(), service.Service, @@ -621,10 +621,10 @@ var _ = Describe("When deleting Service", Label("app", "DeleteService"), func() }) It("deletes service", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteService", id, mock.Anything).Return(nil) serviceHandler = s.NewServiceHandler(handlerContext) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{}, nil) err := serviceHandler.DeleteService(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") @@ -775,8 +775,8 @@ var _ = Describe("When modifying owner and Service", Label("app", "OwnerService" It("adds owner to service", func() { db.On("AddOwnerToService", service.Id, owner.Id).Return(nil) - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) serviceHandler = s.NewServiceHandler(handlerContext) service, err := serviceHandler.AddOwnerToService(ctx, service.Id, owner.Id) Expect(err).To(BeNil(), "no error should be thrown") @@ -812,8 +812,8 @@ var _ = Describe("When modifying owner and Service", Label("app", "OwnerService" It("removes owner from service", func() { db.On("RemoveOwnerFromService", service.Id, owner.Id).Return(nil) - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}).Return([]entity.ServiceResult{service}, nil) serviceHandler = s.NewServiceHandler(handlerContext) service, err := serviceHandler.RemoveOwnerFromService(ctx, service.Id, owner.Id) Expect(err).To(BeNil(), "no error should be thrown") @@ -933,8 +933,8 @@ var _ = Describe( It("adds issueRepository to service", func() { db.On("AddIssueRepositoryToService", service.Id, issueRepository.Id, priority). Return(nil) - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{service}, nil) serviceHandler = s.NewServiceHandler(handlerContext) service, err := serviceHandler.AddIssueRepositoryToService( @@ -949,8 +949,8 @@ var _ = Describe( It("removes issueRepository from service", func() { db.On("RemoveIssueRepositoryFromService", service.Id, issueRepository.Id).Return(nil) - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetServices", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetServices", mock.Anything, filter, []entity.Order{}). Return([]entity.ServiceResult{service}, nil) serviceHandler = s.NewServiceHandler(handlerContext) service, err := serviceHandler.RemoveIssueRepositoryFromService( @@ -986,12 +986,12 @@ var _ = Describe("When listing serviceCcrns", Label("app", "ListServicesCcrns"), When("no filters are used", func() { BeforeEach(func() { - db.On("GetServiceCcrns", filter).Return([]string{}, nil) + db.On("GetServiceCcrns", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceCcrns(filter, options) + res, err := serviceHandler.ListServiceCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -1002,11 +1002,11 @@ var _ = Describe("When listing serviceCcrns", Label("app", "ListServicesCcrns"), CCRN: []*string{&name}, } - db.On("GetServiceCcrns", filter).Return([]string{name}, nil) + db.On("GetServiceCcrns", mock.Anything, filter).Return([]string{name}, nil) }) It("returns filtered services according to the service type", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceCcrns(filter, options) + res, err := serviceHandler.ListServiceCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(name), "should only consist of serviceCcrn") }) @@ -1035,12 +1035,12 @@ var _ = Describe("When listing serviceDomains", Label("app", "ListServicesDomain When("no filters are used", func() { BeforeEach(func() { - db.On("GetServiceDomains", filter).Return([]string{}, nil) + db.On("GetServiceDomains", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceDomains(filter, options) + res, err := serviceHandler.ListServiceDomains(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -1051,11 +1051,11 @@ var _ = Describe("When listing serviceDomains", Label("app", "ListServicesDomain Domain: []*string{&domain}, } - db.On("GetServiceDomains", filter).Return([]string{domain}, nil) + db.On("GetServiceDomains", mock.Anything, filter).Return([]string{domain}, nil) }) It("returns filtered services according to the service type", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceDomains(filter, options) + res, err := serviceHandler.ListServiceDomains(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(domain), "should only consist of domain") }) @@ -1084,12 +1084,12 @@ var _ = Describe("When listing serviceRegions", Label("app", "ListServiceRegions When("no filters are used", func() { BeforeEach(func() { - db.On("GetServiceRegions", filter).Return([]string{}, nil) + db.On("GetServiceRegions", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceRegions(filter, options) + res, err := serviceHandler.ListServiceRegions(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -1100,11 +1100,11 @@ var _ = Describe("When listing serviceRegions", Label("app", "ListServiceRegions Region: []*string{®ion}, } - db.On("GetServiceRegions", filter).Return([]string{region}, nil) + db.On("GetServiceRegions", mock.Anything, filter).Return([]string{region}, nil) }) It("returns filtered services according to the service type", func() { serviceHandler = s.NewServiceHandler(handlerContext) - res, err := serviceHandler.ListServiceRegions(filter, options) + res, err := serviceHandler.ListServiceRegions(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(region), "should only consist of region") }) diff --git a/internal/app/severity/severity_handler.go b/internal/app/severity/severity_handler.go index 78bca1270..5d7ceb7a1 100644 --- a/internal/app/severity/severity_handler.go +++ b/internal/app/severity/severity_handler.go @@ -4,6 +4,8 @@ package severity import ( + "context" + "github.com/cloudoperators/heureka/internal/app/common" "github.com/cloudoperators/heureka/internal/app/event" "github.com/cloudoperators/heureka/internal/app/issue_variant" @@ -42,7 +44,7 @@ func (e *SeverityHandlerError) Error() string { return e.message } -func (s *severityHandler) GetSeverity(filter *entity.SeverityFilter) (*entity.Severity, error) { +func (s *severityHandler) GetSeverity(ctx context.Context, filter *entity.SeverityFilter) (*entity.Severity, error) { l := logrus.WithFields(logrus.Fields{ "event": GetSeverityEventName, "filter": filter, @@ -56,6 +58,7 @@ func (s *severityHandler) GetSeverity(filter *entity.SeverityFilter) (*entity.Se opts := entity.ListOptions{} issueVariants, err := s.issueVariantHandler.ListEffectiveIssueVariants( + ctx, &issueVariantFilter, &opts, ) diff --git a/internal/app/severity/severity_handler_interface.go b/internal/app/severity/severity_handler_interface.go index fc1d27819..a919c7aad 100644 --- a/internal/app/severity/severity_handler_interface.go +++ b/internal/app/severity/severity_handler_interface.go @@ -3,8 +3,12 @@ package severity -import "github.com/cloudoperators/heureka/internal/entity" +import ( + "context" + + "github.com/cloudoperators/heureka/internal/entity" +) type SeverityHandler interface { - GetSeverity(*entity.SeverityFilter) (*entity.Severity, error) + GetSeverity(context.Context, *entity.SeverityFilter) (*entity.Severity, error) } diff --git a/internal/app/severity/severity_handler_test.go b/internal/app/severity/severity_handler_test.go index 9654fb82a..9059e92d4 100644 --- a/internal/app/severity/severity_handler_test.go +++ b/internal/app/severity/severity_handler_test.go @@ -6,6 +6,7 @@ package severity_test import ( + "context" "testing" "github.com/cloudoperators/heureka/internal/app/common" @@ -110,8 +111,8 @@ var _ = Describe("When get Severity", Label("app", "GetSeverity"), func() { }) } - db.On("GetIssueVariants", ivFilter, mock.Anything).Return(ivResults, nil) - db.On("GetIssueRepositories", irFilter, mock.Anything).Return(irResults, nil) + db.On("GetIssueVariants", mock.Anything, ivFilter, mock.Anything).Return(ivResults, nil) + db.On("GetIssueRepositories", mock.Anything, irFilter, mock.Anything).Return(irResults, nil) }) When("higher priority issue variant has highest severity score", func() { BeforeEach(func() { @@ -120,7 +121,7 @@ var _ = Describe("When get Severity", Label("app", "GetSeverity"), func() { }) It("returns severity value", func() { severityHandler = ss.NewSeverityHandler(handlerContext, ivs) - severity, err := severityHandler.GetSeverity(sFilter) + severity, err := severityHandler.GetSeverity(context.Background(), sFilter) Expect(err).To(BeNil(), "no error should be thrown") Expect(severity).ToNot((BeNil()), "severity should exist.") Expect( @@ -136,7 +137,7 @@ var _ = Describe("When get Severity", Label("app", "GetSeverity"), func() { }) It("returns severity value", func() { severityHandler = ss.NewSeverityHandler(handlerContext, ivs) - severity, err := severityHandler.GetSeverity(sFilter) + severity, err := severityHandler.GetSeverity(context.Background(), sFilter) Expect(err).To(BeNil(), "no error should be thrown") Expect(severity).ToNot((BeNil()), "severity should exist.") Expect( @@ -174,8 +175,8 @@ var _ = Describe("When get Severity", Label("app", "GetSeverity"), func() { }) } - db.On("GetIssueVariants", ivFilter, mock.Anything).Return(ivResults, nil) - db.On("GetIssueRepositories", irFilter, mock.Anything).Return(irResults, nil) + db.On("GetIssueVariants", mock.Anything, ivFilter, mock.Anything).Return(ivResults, nil) + db.On("GetIssueRepositories", mock.Anything, irFilter, mock.Anything).Return(irResults, nil) }) When("issueVariants have different severity score", func() { BeforeEach(func() { @@ -185,7 +186,7 @@ var _ = Describe("When get Severity", Label("app", "GetSeverity"), func() { }) It("return severity value", func() { severityHandler = ss.NewSeverityHandler(handlerContext, ivs) - severity, err := severityHandler.GetSeverity(sFilter) + severity, err := severityHandler.GetSeverity(context.Background(), sFilter) Expect(err).To(BeNil(), "no error should be thrown") Expect(severity).ToNot((BeNil()), "severity should exist.") Expect( diff --git a/internal/app/shared/issueVariant.go b/internal/app/shared/issueVariant.go index 4cd1e3a7d..fc55c81a5 100644 --- a/internal/app/shared/issueVariant.go +++ b/internal/app/shared/issueVariant.go @@ -4,6 +4,7 @@ package shared import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/database" @@ -20,6 +21,7 @@ import ( // // Returns a map of issue id to issue variant func BuildIssueVariantMap( + ctx context.Context, db database.Database, filter *entity.ServiceIssueVariantFilter, componentVersionId int64, @@ -32,7 +34,7 @@ func BuildIssueVariantMap( }) // Get Issue Variants based on filter - issueVariants, err := db.GetServiceIssueVariants(filter, []entity.Order{}) + issueVariants, err := db.GetServiceIssueVariants(ctx, filter, []entity.Order{}) if err != nil { l.WithField("event-step", "FetchIssueVariants"). WithError(err). diff --git a/internal/app/support_group/support_group_handler.go b/internal/app/support_group/support_group_handler.go index 4cdac0e6d..131916e4a 100644 --- a/internal/app/support_group/support_group_handler.go +++ b/internal/app/support_group/support_group_handler.go @@ -185,7 +185,7 @@ func (sg *supportGroupHandler) ListSupportGroups( // Update the filter.Id based on accessibleSupportGroupIds filter.Id = common.CombineFilterWithAccessibleIds(filter.Id, accessibleSupportGroupIds) - res, err := sg.database.GetSupportGroups(filter, options.Order) + res, err := sg.database.GetSupportGroups(ctx, filter, options.Order) if err != nil { wrappedErr := appErrors.InternalError(string(op), "SupportGroups", "", err) applog.LogError(sg.logger, wrappedErr, logrus.Fields{ @@ -197,7 +197,7 @@ func (sg *supportGroupHandler) ListSupportGroups( if options.ShowPageInfo { if len(res) > 0 { - cursors, err := sg.database.GetAllSupportGroupCursors(filter, options.Order) + cursors, err := sg.database.GetAllSupportGroupCursors(ctx, filter, options.Order) if err != nil { wrappedErr := appErrors.InternalError(string(op), "SupportGroups", "", err) applog.LogError(sg.logger, wrappedErr, logrus.Fields{ @@ -211,7 +211,7 @@ func (sg *supportGroupHandler) ListSupportGroups( count = int64(len(cursors)) } } else if options.ShowTotalCount { - count, err = sg.database.CountSupportGroups(filter) + count, err = sg.database.CountSupportGroups(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "SupportGroups", "", err) applog.LogError(sg.logger, wrappedErr, logrus.Fields{ @@ -449,6 +449,7 @@ func (sg *supportGroupHandler) RemoveUserFromSupportGroup(ctx context.Context, } func (sg *supportGroupHandler) ListSupportGroupCcrns( + ctx context.Context, filter *entity.SupportGroupFilter, options *entity.ListOptions, ) ([]string, error) { @@ -457,7 +458,7 @@ func (sg *supportGroupHandler) ListSupportGroupCcrns( "filter": filter, }) - supportGroupCcrns, err := sg.database.GetSupportGroupCcrns(filter) + supportGroupCcrns, err := sg.database.GetSupportGroupCcrns(ctx, filter) if err != nil { l.Error(err) diff --git a/internal/app/support_group/support_group_handler_interface.go b/internal/app/support_group/support_group_handler_interface.go index 51f0141d2..dce5a1a21 100644 --- a/internal/app/support_group/support_group_handler_interface.go +++ b/internal/app/support_group/support_group_handler_interface.go @@ -23,5 +23,5 @@ type SupportGroupHandler interface { RemoveServiceFromSupportGroup(context.Context, int64, int64) (*entity.SupportGroup, error) AddUserToSupportGroup(context.Context, int64, int64) (*entity.SupportGroup, error) RemoveUserFromSupportGroup(context.Context, int64, int64) (*entity.SupportGroup, error) - ListSupportGroupCcrns(*entity.SupportGroupFilter, *entity.ListOptions) ([]string, error) + ListSupportGroupCcrns(context.Context, *entity.SupportGroupFilter, *entity.ListOptions) ([]string, error) } diff --git a/internal/app/support_group/support_group_handler_test.go b/internal/app/support_group/support_group_handler_test.go index 9d2120f62..89da55634 100644 --- a/internal/app/support_group/support_group_handler_test.go +++ b/internal/app/support_group/support_group_handler_test.go @@ -79,9 +79,9 @@ var _ = Describe("When listing SupportGroups", Label("app", "ListSupportGroups") When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetSupportGroups", filter, order).Return([]entity.SupportGroupResult{}, nil) - db.On("CountSupportGroups", filter).Return(int64(1337), nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetSupportGroups", mock.Anything, filter, order).Return([]entity.SupportGroupResult{}, nil) + db.On("CountSupportGroups", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -113,7 +113,7 @@ var _ = Describe("When listing SupportGroups", Label("app", "ListSupportGroups") return s.Value }) - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) var i int64 = 0 for len(cursors) < dbElements { i++ @@ -123,8 +123,8 @@ var _ = Describe("When listing SupportGroups", Label("app", "ListSupportGroups") ) cursors = append(cursors, c) } - db.On("GetSupportGroups", filter, order).Return(supportGroups, nil) - db.On("GetAllSupportGroupCursors", filter, order).Return(cursors, nil) + db.On("GetSupportGroups", mock.Anything, filter, order).Return(supportGroups, nil) + db.On("GetAllSupportGroupCursors", mock.Anything, filter, order).Return(cursors, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) res, err := supportGroupHandler.ListSupportGroups(ctx, filter, options) Expect(err).To(BeNil(), "no error should be thrown") @@ -173,8 +173,8 @@ var _ = Describe("When listing SupportGroups", Label("app", "ListSupportGroups") BeforeEach(func() { sgIds := int64(-1) filter.Id = []*int64{&sgIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetSupportGroups", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetSupportGroups", mock.Anything, filter, []entity.Order{}). Return([]entity.SupportGroupResult{}, nil) }) @@ -194,8 +194,8 @@ var _ = Describe("When listing SupportGroups", Label("app", "ListSupportGroups") systemUserId := int64(1) supportGroup = test.NewFakeSupportGroupEntity() filter.Id = []*int64{&supportGroup.Id} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetSupportGroups", filter, []entity.Order{}). + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetSupportGroups", mock.Anything, filter, []entity.Order{}). Return([]entity.SupportGroupResult{{SupportGroup: &supportGroup}}, nil) relations := []openfga.RelationInput{ @@ -265,9 +265,9 @@ var _ = Describe("When creating SupportGroup", Label("app", "CreateSupportGroup" It("creates supportGroup", func() { filter.CCRN = []*string{&supportGroup.CCRN} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateSupportGroup", &supportGroup).Return(&supportGroup, nil) - db.On("GetSupportGroups", filter, order).Return([]entity.SupportGroupResult{}, nil) + db.On("GetSupportGroups", mock.Anything, filter, order).Return([]entity.SupportGroupResult{}, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) newSupportGroup, err := supportGroupHandler.CreateSupportGroup( common.NewAdminContext(), @@ -356,12 +356,12 @@ var _ = Describe("When updating SupportGroup", Label("app", "UpdateSupportGroup" }) It("updates supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateSupportGroup", supportGroup.SupportGroup).Return(nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) supportGroup.CCRN = "Team Alone" filter.Id = []*int64{&supportGroup.Id} - db.On("GetSupportGroups", filter, order). + db.On("GetSupportGroups", mock.Anything, filter, order). Return([]entity.SupportGroupResult{supportGroup}, nil) updatedSupportGroup, err := supportGroupHandler.UpdateSupportGroup( common.NewAdminContext(), @@ -408,10 +408,10 @@ var _ = Describe("When deleting SupportGroup", Label("app", "DeleteSupportGroup" }) It("deletes supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteSupportGroup", id, mock.Anything).Return(nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) - db.On("GetSupportGroups", filter, order).Return([]entity.SupportGroupResult{}, nil) + db.On("GetSupportGroups", mock.Anything, filter, order).Return([]entity.SupportGroupResult{}, nil) err := supportGroupHandler.DeleteSupportGroup(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") @@ -565,9 +565,9 @@ var _ = Describe( }) It("adds service to supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("AddServiceToSupportGroup", supportGroup.Id, service.Id).Return(nil) - db.On("GetSupportGroups", filter, order). + db.On("GetSupportGroups", mock.Anything, filter, order). Return([]entity.SupportGroupResult{supportGroup}, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) supportGroup, err := supportGroupHandler.AddServiceToSupportGroup( @@ -580,9 +580,9 @@ var _ = Describe( }) It("removes service from supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("RemoveServiceFromSupportGroup", supportGroup.Id, service.Id).Return(nil) - db.On("GetSupportGroups", filter, order). + db.On("GetSupportGroups", mock.Anything, filter, order). Return([]entity.SupportGroupResult{supportGroup}, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) supportGroup, err := supportGroupHandler.RemoveServiceFromSupportGroup( @@ -707,9 +707,9 @@ var _ = Describe( }) It("adds user to supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("AddUserToSupportGroup", supportGroup.Id, user.Id).Return(nil) - db.On("GetSupportGroups", filter, order). + db.On("GetSupportGroups", mock.Anything, filter, order). Return([]entity.SupportGroupResult{supportGroup}, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) supportGroup, err := supportGroupHandler.AddUserToSupportGroup( @@ -722,9 +722,9 @@ var _ = Describe( }) It("removes user from supportGroup", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("RemoveUserFromSupportGroup", supportGroup.Id, user.Id).Return(nil) - db.On("GetSupportGroups", filter, order). + db.On("GetSupportGroups", mock.Anything, filter, order). Return([]entity.SupportGroupResult{supportGroup}, nil) supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) supportGroup, err := supportGroupHandler.RemoveUserFromSupportGroup( @@ -835,12 +835,12 @@ var _ = Describe("When listing supportGroupCcrns", Label("app", "ListSupportGrou When("no filters are used", func() { BeforeEach(func() { - db.On("GetSupportGroupCcrns", filter).Return([]string{}, nil) + db.On("GetSupportGroupCcrns", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) - res, err := supportGroupHandler.ListSupportGroupCcrns(filter, options) + res, err := supportGroupHandler.ListSupportGroupCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -851,11 +851,11 @@ var _ = Describe("When listing supportGroupCcrns", Label("app", "ListSupportGrou CCRN: []*string{&ccrn}, } - db.On("GetSupportGroupCcrns", filter).Return([]string{ccrn}, nil) + db.On("GetSupportGroupCcrns", mock.Anything, filter).Return([]string{ccrn}, nil) }) It("returns filtered userGroups according to the service type", func() { supportGroupHandler = sg.NewSupportGroupHandler(handlerContext) - res, err := supportGroupHandler.ListSupportGroupCcrns(filter, options) + res, err := supportGroupHandler.ListSupportGroupCcrns(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(ccrn), "should only consist of supportGroup") }) diff --git a/internal/app/user/user_handler.go b/internal/app/user/user_handler.go index e0d06badf..e97419f79 100644 --- a/internal/app/user/user_handler.go +++ b/internal/app/user/user_handler.go @@ -104,7 +104,7 @@ func (u *userHandler) ListUsers( u.cache, CacheTtlGetUsers, "GetUsers", - u.database.GetUsers, + cache.WrapContext1(ctx, u.database.GetUsers), filter, ) if err != nil { @@ -122,7 +122,7 @@ func (u *userHandler) ListUsers( u.cache, CacheTtlGetAllUserCursors, "GetAllUserCursors", - u.database.GetAllUserCursors, + cache.WrapContext2(ctx, u.database.GetAllUserCursors), filter, options.Order, ) @@ -139,7 +139,7 @@ func (u *userHandler) ListUsers( count = int64(len(cursors)) } } else if options.ShowTotalCount { - count, err = u.database.CountUsers(filter) + count, err = u.database.CountUsers(ctx, filter) if err != nil { wrappedErr := appErrors.InternalError(string(op), "Users", "", err) applog.LogError(u.logger, wrappedErr, logrus.Fields{ @@ -268,6 +268,7 @@ func (u *userHandler) DeleteUser(ctx context.Context, id int64) error { } func (u *userHandler) ListUserNames( + ctx context.Context, filter *entity.UserFilter, options *entity.ListOptions, ) ([]string, error) { @@ -276,7 +277,7 @@ func (u *userHandler) ListUserNames( "filter": filter, }) - userNames, err := u.database.GetUserNames(filter) + userNames, err := u.database.GetUserNames(ctx, filter) if err != nil { l.Error(err) return nil, NewUserHandlerError("Internal error while retrieving userNames.") @@ -290,6 +291,7 @@ func (u *userHandler) ListUserNames( } func (u *userHandler) ListUniqueUserIDs( + ctx context.Context, filter *entity.UserFilter, options *entity.ListOptions, ) ([]string, error) { @@ -298,7 +300,7 @@ func (u *userHandler) ListUniqueUserIDs( "filter": filter, }) - uniqueUserID, err := u.database.GetUniqueUserIDs(filter) + uniqueUserID, err := u.database.GetUniqueUserIDs(ctx, filter) if err != nil { l.Error(err) return nil, NewUserHandlerError("Internal error while retrieving uniqueUserID.") @@ -312,6 +314,7 @@ func (u *userHandler) ListUniqueUserIDs( } func (u *userHandler) ListUserNamesAndIds( + ctx context.Context, filter *entity.UserFilter, options *entity.ListOptions, ) ([]string, []string, error) { @@ -320,7 +323,7 @@ func (u *userHandler) ListUserNamesAndIds( "filter": filter, }) - users, err := u.database.GetUsers(filter) + users, err := u.database.GetUsers(ctx, filter) if err != nil { l.Error(err) return nil, nil, NewUserHandlerError("Internal error while retrieving user.") diff --git a/internal/app/user/user_handler_interface.go b/internal/app/user/user_handler_interface.go index 5665a24bd..194fae05e 100644 --- a/internal/app/user/user_handler_interface.go +++ b/internal/app/user/user_handler_interface.go @@ -18,7 +18,7 @@ type UserHandler interface { CreateUser(context.Context, *entity.User) (*entity.User, error) UpdateUser(context.Context, *entity.User) (*entity.User, error) DeleteUser(context.Context, int64) error - ListUserNames(*entity.UserFilter, *entity.ListOptions) ([]string, error) - ListUniqueUserIDs(*entity.UserFilter, *entity.ListOptions) ([]string, error) - ListUserNamesAndIds(*entity.UserFilter, *entity.ListOptions) ([]string, []string, error) + ListUserNames(context.Context, *entity.UserFilter, *entity.ListOptions) ([]string, error) + ListUniqueUserIDs(context.Context, *entity.UserFilter, *entity.ListOptions) ([]string, error) + ListUserNamesAndIds(context.Context, *entity.UserFilter, *entity.ListOptions) ([]string, []string, error) } diff --git a/internal/app/user/user_handler_test.go b/internal/app/user/user_handler_test.go index a518bf8d6..f5c7dfe02 100644 --- a/internal/app/user/user_handler_test.go +++ b/internal/app/user/user_handler_test.go @@ -92,9 +92,9 @@ var _ = Describe("When listing Users", Label("app", "ListUsers"), func() { When("the list option does include the totalCount", func() { BeforeEach(func() { options.ShowTotalCount = true - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetUsers", filter).Return([]entity.UserResult{}, nil) - db.On("CountUsers", filter).Return(int64(1337), nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{}, nil) + db.On("CountUsers", mock.Anything, filter).Return(int64(1337), nil) }) It("shows the total count in the results", func() { @@ -140,9 +140,9 @@ var _ = Describe("When listing Users", Label("app", "ListUsers"), func() { cursors = append(cursors, c) } - db.On("GetUsers", filter).Return(users, nil) - db.On("GetAllUserCursors", filter, []entity.Order{}).Return(cursors, nil) - db.On("GetAllUserIds", authFilter).Return([]int64{}, nil) + db.On("GetUsers", mock.Anything, filter).Return(users, nil) + db.On("GetAllUserCursors", mock.Anything, filter, []entity.Order{}).Return(cursors, nil) + db.On("GetAllUserIds", mock.Anything, authFilter).Return([]int64{}, nil) // db.On("GetAllUserIds", filter).Return(lo.Map(users, func(m entity.UserResult, _ // int) int64 { return m.User.Id }), nil) @@ -194,8 +194,8 @@ var _ = Describe("When listing Users", Label("app", "ListUsers"), func() { BeforeEach(func() { sgIds := int64(-1) filter.SupportGroupId = []*int64{&sgIds} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetUsers", filter).Return([]entity.UserResult{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{}, nil) }) It("should return no users", func() { @@ -214,8 +214,8 @@ var _ = Describe("When listing Users", Label("app", "ListUsers"), func() { systemUserId := int64(1) filter.SupportGroupId = []*int64{&sgId} user = test.NewFakeUserEntity() - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) - db.On("GetUsers", filter).Return([]entity.UserResult{{User: &user}}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{{User: &user}}, nil) relations := []openfga.RelationInput{ { // create support group @@ -281,9 +281,9 @@ var _ = Describe("When creating User", Label("app", "CreateUser"), func() { It("creates user", func() { filter.UniqueUserID = []*string{&user.UniqueUserID} - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("CreateUser", &user).Return(&user, nil) - db.On("GetUsers", filter).Return([]entity.UserResult{}, nil) + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{}, nil) userHandler = u.NewUserHandler(handlerContext) newUser, err := userHandler.CreateUser(common.NewAdminContext(), &user) Expect(err).To(BeNil(), "no error should be thrown") @@ -323,12 +323,12 @@ var _ = Describe("When updating User", Label("app", "UpdateUser"), func() { }) It("updates user", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("UpdateUser", &user).Return(nil) userHandler = u.NewUserHandler(handlerContext) user.Name = "Sauron" filter.Id = []*int64{&user.Id} - db.On("GetUsers", filter).Return([]entity.UserResult{ + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{ { User: &user, }, @@ -373,10 +373,10 @@ var _ = Describe("When deleting User", Label("app", "DeleteUser"), func() { }) It("deletes user", func() { - db.On("GetAllUserIds", mock.Anything).Return([]int64{}, nil) + db.On("GetAllUserIds", mock.Anything, mock.Anything).Return([]int64{}, nil) db.On("DeleteUser", id, mock.Anything).Return(nil) userHandler = u.NewUserHandler(handlerContext) - db.On("GetUsers", filter).Return([]entity.UserResult{}, nil) + db.On("GetUsers", mock.Anything, filter).Return([]entity.UserResult{}, nil) err := userHandler.DeleteUser(common.NewAdminContext(), id) Expect(err).To(BeNil(), "no error should be thrown") @@ -539,12 +539,12 @@ var _ = Describe("When listing User", Label("app", "ListUserNames"), func() { When("no filters are used", func() { BeforeEach(func() { - db.On("GetUserNames", filter).Return([]string{}, nil) + db.On("GetUserNames", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { userHandler = u.NewUserHandler(handlerContext) - res, err := userHandler.ListUserNames(filter, options) + res, err := userHandler.ListUserNames(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -555,11 +555,11 @@ var _ = Describe("When listing User", Label("app", "ListUserNames"), func() { Name: []*string{&name}, } - db.On("GetUserNames", filter).Return([]string{name}, nil) + db.On("GetUserNames", mock.Anything, filter).Return([]string{name}, nil) }) It("returns filtered users according to the service type", func() { userHandler = u.NewUserHandler(handlerContext) - res, err := userHandler.ListUserNames(filter, options) + res, err := userHandler.ListUserNames(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(name), "should only consist of name") }) @@ -590,12 +590,12 @@ var _ = Describe("When listing UniqueUserID", Label("app", "ListUniqueUserIDs"), When("no filters are used", func() { BeforeEach(func() { - db.On("GetUniqueUserIDs", filter).Return([]string{}, nil) + db.On("GetUniqueUserIDs", mock.Anything, filter).Return([]string{}, nil) }) It("it return the results", func() { userHandler = u.NewUserHandler(handlerContext) - res, err := userHandler.ListUniqueUserIDs(filter, options) + res, err := userHandler.ListUniqueUserIDs(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(BeEmpty(), "return correct result") }) @@ -606,11 +606,11 @@ var _ = Describe("When listing UniqueUserID", Label("app", "ListUniqueUserIDs"), UniqueUserID: []*string{&uuid}, } - db.On("GetUniqueUserIDs", filter).Return([]string{uuid}, nil) + db.On("GetUniqueUserIDs", mock.Anything, filter).Return([]string{uuid}, nil) }) It("returns filtered users according to the service type", func() { userHandler = u.NewUserHandler(handlerContext) - res, err := userHandler.ListUniqueUserIDs(filter, options) + res, err := userHandler.ListUniqueUserIDs(context.Background(), filter, options) Expect(err).To(BeNil(), "no error should be thrown") Expect(res).Should(ConsistOf(uuid), "should only consist of UniqueUserID") }) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index e98cf1129..8ab80615c 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company and Greenhouse contributors +// SPDX-FileCopyrightText: 2026 SAP SE or an SAP affiliate company and Greenhouse contributors // SPDX-License-Identifier: Apache-2.0 package cache @@ -16,14 +16,17 @@ import ( type Cache interface { CacheKey(fnname string, fn any, args ...any) (string, error) Get(key string) (string, bool, error) + GetAllKeys() ([]string, error) Set(key string, value string, ttl time.Duration) error Invalidate(key string) error + InvalidateByMatch(keyMatcher func(decodedKey string) bool) error IncHit() IncMiss() IncShared() GetStat() Stat LaunchRefresh(fn func()) GetSingleflightWrapper() SingleflightWrapper + GetKeyHashType() KeyHashType } type SingleflightWrapper interface { diff --git a/internal/cache/context_wrapper_gen.go b/internal/cache/context_wrapper_gen.go new file mode 100644 index 000000000..69cb65580 --- /dev/null +++ b/internal/cache/context_wrapper_gen.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: 2026 SAP SE or an SAP affiliate company and Greenhouse contributors +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by go generate; DO NOT EDIT. +package cache + +import "context" + +func WrapContext1[A1 any, R any]( + ctx context.Context, + f func(context.Context, A1) (R, error), +) func(A1) (R, error) { + return func(a1 A1) (R, error) { + return f(ctx, a1) + } +} + +func WrapContext2[A1 any, A2 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2) (R, error), +) func(A1, A2) (R, error) { + return func(a1 A1, a2 A2) (R, error) { + return f(ctx, a1, a2) + } +} + +func WrapContext3[A1 any, A2 any, A3 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3) (R, error), +) func(A1, A2, A3) (R, error) { + return func(a1 A1, a2 A2, a3 A3) (R, error) { + return f(ctx, a1, a2, a3) + } +} + +func WrapContext4[A1 any, A2 any, A3 any, A4 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4) (R, error), +) func(A1, A2, A3, A4) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4) (R, error) { + return f(ctx, a1, a2, a3, a4) + } +} + +func WrapContext5[A1 any, A2 any, A3 any, A4 any, A5 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5) (R, error), +) func(A1, A2, A3, A4, A5) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5) (R, error) { + return f(ctx, a1, a2, a3, a4, a5) + } +} + +func WrapContext6[A1 any, A2 any, A3 any, A4 any, A5 any, A6 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5, A6) (R, error), +) func(A1, A2, A3, A4, A5, A6) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5, a6 A6) (R, error) { + return f(ctx, a1, a2, a3, a4, a5, a6) + } +} + +func WrapContext7[A1 any, A2 any, A3 any, A4 any, A5 any, A6 any, A7 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5, A6, A7) (R, error), +) func(A1, A2, A3, A4, A5, A6, A7) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5, a6 A6, a7 A7) (R, error) { + return f(ctx, a1, a2, a3, a4, a5, a6, a7) + } +} + +func WrapContext8[A1 any, A2 any, A3 any, A4 any, A5 any, A6 any, A7 any, A8 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5, A6, A7, A8) (R, error), +) func(A1, A2, A3, A4, A5, A6, A7, A8) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5, a6 A6, a7 A7, a8 A8) (R, error) { + return f(ctx, a1, a2, a3, a4, a5, a6, a7, a8) + } +} + +func WrapContext9[A1 any, A2 any, A3 any, A4 any, A5 any, A6 any, A7 any, A8 any, A9 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5, A6, A7, A8, A9) (R, error), +) func(A1, A2, A3, A4, A5, A6, A7, A8, A9) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5, a6 A6, a7 A7, a8 A8, a9 A9) (R, error) { + return f(ctx, a1, a2, a3, a4, a5, a6, a7, a8, a9) + } +} + +func WrapContext10[A1 any, A2 any, A3 any, A4 any, A5 any, A6 any, A7 any, A8 any, A9 any, A10 any, R any]( + ctx context.Context, + f func(context.Context, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) (R, error), +) func(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) (R, error) { + return func(a1 A1, a2 A2, a3 A3, a4 A4, a5 A5, a6 A6, a7 A7, a8 A8, a9 A9, a10 A10) (R, error) { + return f(ctx, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) + } +} diff --git a/internal/cache/generate.go b/internal/cache/generate.go index 80db87329..c312a4280 100644 --- a/internal/cache/generate.go +++ b/internal/cache/generate.go @@ -4,3 +4,4 @@ package cache //go:generate go run ../../cmd/call_cached_check -dir=../ +//go:generate go run ../../cmd/context_wrapper -dir=../ diff --git a/internal/cache/in_memory_cache.go b/internal/cache/in_memory_cache.go index f96cf89e3..45bf56c61 100644 --- a/internal/cache/in_memory_cache.go +++ b/internal/cache/in_memory_cache.go @@ -61,6 +61,16 @@ func (imc *InMemoryCache) Get(key string) (string, bool, error) { return valStr, true, nil } +func (imc *InMemoryCache) GetAllKeys() ([]string, error) { + keys := make([]string, 0, imc.gc.ItemCount()) + + for key := range imc.gc.Items() { + keys = append(keys, key) + } + + return keys, nil +} + func (imc *InMemoryCache) Set(key string, value string, ttl time.Duration) error { if ttl <= 0 { ttl = gocache.NoExpiration @@ -75,3 +85,29 @@ func (imc *InMemoryCache) Invalidate(key string) error { imc.gc.Delete(key) return nil } + +func (imc *InMemoryCache) GetKeyHashType() KeyHashType { + return imc.keyHash +} + +func (imc *InMemoryCache) InvalidateByMatch(keyMatcher func(decodedKey string) bool) error { + keys, err := imc.GetAllKeys() + if err != nil { + return fmt.Errorf("cache: failed to get cache keys: %w", err) + } + + for _, key := range keys { + decodedKey, err := DecodeKey(key, imc.GetKeyHashType()) + if err != nil { + return fmt.Errorf("cache: failed to decode cached key: %w", err) + } + + if keyMatcher(decodedKey) { + if err := imc.Invalidate(key); err != nil { + return fmt.Errorf("cache: failed to invalidate key: [key: %s, error: %w]", decodedKey, err) + } + } + } + + return nil +} diff --git a/internal/cache/valkey_cache.go b/internal/cache/valkey_cache.go index 7ef45d7f7..c4e4afb41 100644 --- a/internal/cache/valkey_cache.go +++ b/internal/cache/valkey_cache.go @@ -5,6 +5,7 @@ package cache import ( "context" + "fmt" "sync" "time" @@ -74,6 +75,16 @@ func (vc *ValkeyCache) Get(key string) (string, bool, error) { return val, true, nil } +func (vc *ValkeyCache) GetAllKeys() ([]string, error) { + res := vc.client.Do(vc.ctx, vc.client.B().Keys().Pattern("*").Build()) + + if err := res.Error(); err != nil { + return nil, err + } + + return res.AsStrSlice() +} + // ttl = 0 <- infinite func (vc *ValkeyCache) Set(key string, value string, ttl time.Duration) error { return vc.client.Do(vc.ctx, vc.client.B().Set().Key(key).Value(value).Px(ttl).Build()).Error() @@ -86,3 +97,29 @@ func (vc *ValkeyCache) Invalidate(key string) error { func (vc *ValkeyCache) invalidateAll() error { return vc.client.Do(vc.ctx, vc.client.B().Flushall().Build()).Error() } + +func (vc *ValkeyCache) GetKeyHashType() KeyHashType { + return vc.keyHash +} + +func (vc *ValkeyCache) InvalidateByMatch(keyMatcher func(decodedKey string) bool) error { + keys, err := vc.GetAllKeys() + if err != nil { + return fmt.Errorf("cache: failed to get cache keys: %w", err) + } + + for _, key := range keys { + decodedKey, err := DecodeKey(key, vc.GetKeyHashType()) + if err != nil { + return fmt.Errorf("cache: failed to decode cached key: %w", err) + } + + if keyMatcher(decodedKey) { + if err := vc.Invalidate(key); err != nil { + return fmt.Errorf("cache: failed to invalidate key: [key: %s, error: %w]", decodedKey, err) + } + } + } + + return nil +} diff --git a/internal/database/interface.go b/internal/database/interface.go index c10f9d957..45d0262f8 100644 --- a/internal/database/interface.go +++ b/internal/database/interface.go @@ -3,63 +3,71 @@ package database -import "github.com/cloudoperators/heureka/internal/entity" +import ( + "context" + + "github.com/cloudoperators/heureka/internal/entity" +) type Database interface { - GetIssues(*entity.IssueFilter, []entity.Order) ([]entity.IssueResult, error) - GetIssuesWithAggregations(*entity.IssueFilter, []entity.Order) ([]entity.IssueResult, error) - CountIssues(*entity.IssueFilter) (int64, error) - CountIssueTypes(*entity.IssueFilter) (*entity.IssueTypeCounts, error) - CountIssueRatings(*entity.IssueFilter) (*entity.IssueSeverityCounts, error) - GetAllIssueCursors(*entity.IssueFilter, []entity.Order) ([]string, error) + GetIssues(context.Context, *entity.IssueFilter, []entity.Order) ([]entity.IssueResult, error) + GetIssuesWithAggregations(context.Context, *entity.IssueFilter, []entity.Order) ([]entity.IssueResult, error) + CountIssues(context.Context, *entity.IssueFilter) (int64, error) + CountIssueTypes(context.Context, *entity.IssueFilter) (*entity.IssueTypeCounts, error) + CountIssueRatings(context.Context, *entity.IssueFilter) (*entity.IssueSeverityCounts, error) + GetAllIssueCursors(context.Context, *entity.IssueFilter, []entity.Order) ([]string, error) CreateIssue(*entity.Issue) (*entity.Issue, error) UpdateIssue(*entity.Issue) error DeleteIssue(int64, int64) error AddComponentVersionToIssue(int64, int64) error RemoveComponentVersionFromIssue(int64, int64) error - GetIssueNames(*entity.IssueFilter) ([]string, error) + GetIssueNames(context.Context, *entity.IssueFilter) ([]string, error) GetServiceIssueVariants( + context.Context, *entity.ServiceIssueVariantFilter, []entity.Order, ) ([]entity.ServiceIssueVariantResult, error) GetIssueVariants( + context.Context, *entity.IssueVariantFilter, []entity.Order, ) ([]entity.IssueVariantResult, error) - GetAllIssueVariantCursors(*entity.IssueVariantFilter, []entity.Order) ([]string, error) - CountIssueVariants(*entity.IssueVariantFilter) (int64, error) + GetAllIssueVariantCursors(context.Context, *entity.IssueVariantFilter, []entity.Order) ([]string, error) + CountIssueVariants(context.Context, *entity.IssueVariantFilter) (int64, error) CreateIssueVariant(*entity.IssueVariant) (*entity.IssueVariant, error) UpdateIssueVariant(*entity.IssueVariant) error DeleteIssueVariant(int64, int64) error GetIssueRepositories( + context.Context, *entity.IssueRepositoryFilter, []entity.Order, ) ([]entity.IssueRepositoryResult, error) - GetAllIssueRepositoryCursors(*entity.IssueRepositoryFilter, []entity.Order) ([]string, error) - CountIssueRepositories(*entity.IssueRepositoryFilter) (int64, error) + GetAllIssueRepositoryCursors(context.Context, *entity.IssueRepositoryFilter, []entity.Order) ([]string, error) + CountIssueRepositories(context.Context, *entity.IssueRepositoryFilter) (int64, error) CreateIssueRepository(*entity.IssueRepository) (*entity.IssueRepository, error) UpdateIssueRepository(*entity.IssueRepository) error DeleteIssueRepository(int64, int64) error GetDefaultIssuePriority() int64 GetDefaultRepositoryName() string - GetIssueMatches(*entity.IssueMatchFilter, []entity.Order) ([]entity.IssueMatchResult, error) - GetAllIssueMatchIds(*entity.IssueMatchFilter) ([]int64, error) - GetAllIssueMatchCursors(*entity.IssueMatchFilter, []entity.Order) ([]string, error) - CountIssueMatches(filter *entity.IssueMatchFilter) (int64, error) + GetIssueMatches(context.Context, *entity.IssueMatchFilter, []entity.Order) ([]entity.IssueMatchResult, error) + GetAllIssueMatchIds(context.Context, *entity.IssueMatchFilter) ([]int64, error) + GetAllIssueMatchCursors(context.Context, *entity.IssueMatchFilter, []entity.Order) ([]string, error) + CountIssueMatches(ctx context.Context, filter *entity.IssueMatchFilter) (int64, error) CreateIssueMatch(*entity.IssueMatch) (*entity.IssueMatch, error) UpdateIssueMatch(*entity.IssueMatch) error DeleteIssueMatch(int64, int64) error - GetServices(*entity.ServiceFilter, []entity.Order) ([]entity.ServiceResult, error) + GetServices(context.Context, *entity.ServiceFilter, []entity.Order) ([]entity.ServiceResult, error) GetServicesWithAggregations( + context.Context, *entity.ServiceFilter, []entity.Order, ) ([]entity.ServiceResult, error) - GetAllServiceCursors(*entity.ServiceFilter, []entity.Order) ([]string, error) - CountServices(*entity.ServiceFilter) (int64, error) + GetAllServiceCursors(context.Context, *entity.ServiceFilter, []entity.Order) ([]string, error) + CountServices(context.Context, *entity.ServiceFilter) (int64, error) CreateService(*entity.Service) (*entity.Service, error) UpdateService(*entity.Service) error DeleteService(int64, int64) error @@ -67,26 +75,23 @@ type Database interface { RemoveOwnerFromService(int64, int64) error AddIssueRepositoryToService(int64, int64, int64) error RemoveIssueRepositoryFromService(int64, int64) error - GetServiceCcrns(*entity.ServiceFilter) ([]string, error) - GetServiceDomains(*entity.ServiceFilter) ([]string, error) - GetServiceRegions(*entity.ServiceFilter) ([]string, error) - - GetUsers(*entity.UserFilter) ([]entity.UserResult, error) - GetAllUserIds(*entity.UserFilter) ([]int64, error) - GetAllUserCursors(*entity.UserFilter, []entity.Order) ([]string, error) - CountUsers(*entity.UserFilter) (int64, error) + GetServiceCcrns(context.Context, *entity.ServiceFilter) ([]string, error) + GetServiceDomains(context.Context, *entity.ServiceFilter) ([]string, error) + GetServiceRegions(context.Context, *entity.ServiceFilter) ([]string, error) + + GetUsers(context.Context, *entity.UserFilter) ([]entity.UserResult, error) + GetAllUserIds(context.Context, *entity.UserFilter) ([]int64, error) + GetAllUserCursors(context.Context, *entity.UserFilter, []entity.Order) ([]string, error) + CountUsers(context.Context, *entity.UserFilter) (int64, error) CreateUser(*entity.User) (*entity.User, error) UpdateUser(*entity.User) error DeleteUser(int64, int64) error - GetUserNames(*entity.UserFilter) ([]string, error) - GetUniqueUserIDs(*entity.UserFilter) ([]string, error) + GetUserNames(context.Context, *entity.UserFilter) ([]string, error) + GetUniqueUserIDs(context.Context, *entity.UserFilter) ([]string, error) - GetSupportGroups( - *entity.SupportGroupFilter, - []entity.Order, - ) ([]entity.SupportGroupResult, error) - GetAllSupportGroupCursors(*entity.SupportGroupFilter, []entity.Order) ([]string, error) - CountSupportGroups(*entity.SupportGroupFilter) (int64, error) + GetSupportGroups(context.Context, *entity.SupportGroupFilter, []entity.Order) ([]entity.SupportGroupResult, error) + GetAllSupportGroupCursors(context.Context, *entity.SupportGroupFilter, []entity.Order) ([]string, error) + CountSupportGroups(context.Context, *entity.SupportGroupFilter) (int64, error) CreateSupportGroup(*entity.SupportGroup) (*entity.SupportGroup, error) UpdateSupportGroup(*entity.SupportGroup) error DeleteSupportGroup(int64, int64) error @@ -94,47 +99,45 @@ type Database interface { RemoveServiceFromSupportGroup(int64, int64) error AddUserToSupportGroup(int64, int64) error RemoveUserFromSupportGroup(int64, int64) error - GetSupportGroupCcrns(*entity.SupportGroupFilter) ([]string, error) + GetSupportGroupCcrns(context.Context, *entity.SupportGroupFilter) ([]string, error) - GetComponentInstances( - *entity.ComponentInstanceFilter, - []entity.Order, - ) ([]entity.ComponentInstanceResult, error) + GetComponentInstances(context.Context, *entity.ComponentInstanceFilter, []entity.Order) ([]entity.ComponentInstanceResult, error) GetAllComponentInstanceCursors( + context.Context, *entity.ComponentInstanceFilter, []entity.Order, ) ([]string, error) - CountComponentInstances(*entity.ComponentInstanceFilter) (int64, error) + CountComponentInstances(context.Context, *entity.ComponentInstanceFilter) (int64, error) CreateComponentInstance(*entity.ComponentInstance) (*entity.ComponentInstance, error) UpdateComponentInstance(*entity.ComponentInstance) error DeleteComponentInstance(int64, int64) error - GetComponentCcrns(filter *entity.ComponentFilter) ([]string, error) - GetCcrn(filter *entity.ComponentInstanceFilter) ([]string, error) - GetRegion(filter *entity.ComponentInstanceFilter) ([]string, error) - GetCluster(filter *entity.ComponentInstanceFilter) ([]string, error) - GetNamespace(filter *entity.ComponentInstanceFilter) ([]string, error) - GetDomain(filter *entity.ComponentInstanceFilter) ([]string, error) - GetProject(filter *entity.ComponentInstanceFilter) ([]string, error) - GetPod(filter *entity.ComponentInstanceFilter) ([]string, error) - GetContainer(filter *entity.ComponentInstanceFilter) ([]string, error) - GetType(filter *entity.ComponentInstanceFilter) ([]string, error) - GetContext(filter *entity.ComponentInstanceFilter) ([]string, error) - GetComponentInstanceParent(filter *entity.ComponentInstanceFilter) ([]string, error) - - GetComponents(*entity.ComponentFilter, []entity.Order) ([]entity.ComponentResult, error) - GetAllComponentCursors(*entity.ComponentFilter, []entity.Order) ([]string, error) - CountComponents(*entity.ComponentFilter) (int64, error) - CountComponentVulnerabilities(*entity.ComponentFilter) ([]entity.IssueSeverityCounts, error) + GetComponentCcrns(ctx context.Context, filter *entity.ComponentFilter) ([]string, error) + GetCcrn(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetRegion(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetCluster(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetNamespace(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetDomain(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetProject(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetPod(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetContainer(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetType(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetContext(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + GetComponentInstanceParent(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) + + GetComponents(context.Context, *entity.ComponentFilter, []entity.Order) ([]entity.ComponentResult, error) + GetAllComponentCursors(context.Context, *entity.ComponentFilter, []entity.Order) ([]string, error) + CountComponents(context.Context, *entity.ComponentFilter) (int64, error) + CountComponentVulnerabilities(context.Context, *entity.ComponentFilter) ([]entity.IssueSeverityCounts, error) CreateComponent(*entity.Component) (*entity.Component, error) UpdateComponent(*entity.Component) error DeleteComponent(int64, int64) error - GetComponentVersions( + GetComponentVersions(context.Context, *entity.ComponentVersionFilter, []entity.Order, ) ([]entity.ComponentVersionResult, error) - GetAllComponentVersionCursors(*entity.ComponentVersionFilter, []entity.Order) ([]string, error) - CountComponentVersions(*entity.ComponentVersionFilter) (int64, error) + GetAllComponentVersionCursors(context.Context, *entity.ComponentVersionFilter, []entity.Order) ([]string, error) + CountComponentVersions(context.Context, *entity.ComponentVersionFilter) (int64, error) CreateComponentVersion(*entity.ComponentVersion) (*entity.ComponentVersion, error) UpdateComponentVersion(*entity.ComponentVersion) error DeleteComponentVersion(int64, int64) error @@ -147,22 +150,22 @@ type Database interface { GetScannerRunTags() ([]string, error) CountScannerRuns(*entity.ScannerRunFilter) (int, error) - GetRemediations(*entity.RemediationFilter, []entity.Order) ([]entity.RemediationResult, error) - GetAllRemediationCursors(*entity.RemediationFilter, []entity.Order) ([]string, error) - CountRemediations(*entity.RemediationFilter) (int64, error) + GetRemediations(context.Context, *entity.RemediationFilter, []entity.Order) ([]entity.RemediationResult, error) + GetAllRemediationCursors(context.Context, *entity.RemediationFilter, []entity.Order) ([]string, error) + CountRemediations(context.Context, *entity.RemediationFilter) (int64, error) CreateRemediation(*entity.Remediation) (*entity.Remediation, error) UpdateRemediation(*entity.Remediation) error DeleteRemediation(int64, int64) error - GetPatches(*entity.PatchFilter, []entity.Order) ([]entity.PatchResult, error) - GetAllPatchCursors(*entity.PatchFilter, []entity.Order) ([]string, error) - CountPatches(*entity.PatchFilter) (int64, error) + GetPatches(context.Context, *entity.PatchFilter, []entity.Order) ([]entity.PatchResult, error) + GetAllPatchCursors(context.Context, *entity.PatchFilter, []entity.Order) ([]string, error) + CountPatches(context.Context, *entity.PatchFilter) (int64, error) CloseConnection() error CreateScannerRunComponentInstanceTracker(componentInstanceId int64, scannerRunUUID string) error - Autopatch() (bool, error) + Autopatch(context.Context) (bool, error) WaitPostMigrations() error } diff --git a/internal/database/mariadb/autopatch.go b/internal/database/mariadb/autopatch.go index bafc1c95d..208b44743 100644 --- a/internal/database/mariadb/autopatch.go +++ b/internal/database/mariadb/autopatch.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "errors" @@ -13,17 +14,17 @@ import ( "github.com/sirupsen/logrus" ) -func (s *SqlDatabase) Autopatch() (bool, error) { - runs, err := s.fetchCompletedRunsWithNewestFirst() +func (s *SqlDatabase) Autopatch(ctx context.Context) (bool, error) { + runs, err := s.fetchCompletedRunsWithNewestFirst(ctx) if err != nil { return false, err } - return s.processAutopatchOnCompletedRuns(runs) + return s.processAutopatchOnCompletedRuns(ctx, runs) } -func (s *SqlDatabase) fetchCompletedRunsWithNewestFirst() (map[string][]int, error) { - rows, err := s.db.Query(` +func (s *SqlDatabase) fetchCompletedRunsWithNewestFirst(ctx context.Context) (map[string][]int, error) { + rows, err := s.db.QueryContext(ctx, ` SELECT scannerrun_tag, scannerrun_run_id FROM ScannerRun WHERE scannerrun_is_completed = TRUE AND scannerrun_deleted_at IS NULL @@ -62,7 +63,7 @@ func (s *SqlDatabase) fetchCompletedRunsWithNewestFirst() (map[string][]int, err return runs, nil } -func (s *SqlDatabase) processAutopatchOnCompletedRuns(runs map[string][]int) (bool, error) { +func (s *SqlDatabase) processAutopatchOnCompletedRuns(ctx context.Context, runs map[string][]int) (bool, error) { autopatched := false for _, tagRuns := range runs { @@ -71,7 +72,7 @@ func (s *SqlDatabase) processAutopatchOnCompletedRuns(runs map[string][]int) (bo continue } - patchedForTag, err := s.processAutopatchForSingleTag(tagRuns) + patchedForTag, err := s.processAutopatchForSingleTag(ctx, tagRuns) if err != nil { return false, err } @@ -84,7 +85,7 @@ func (s *SqlDatabase) processAutopatchOnCompletedRuns(runs map[string][]int) (bo return autopatched, nil } -func (s *SqlDatabase) processAutopatchForSingleTag(tagRuns []int) (bool, error) { +func (s *SqlDatabase) processAutopatchForSingleTag(ctx context.Context, tagRuns []int) (bool, error) { latest := tagRuns[0] secondLatest := tagRuns[1] @@ -109,12 +110,13 @@ func (s *SqlDatabase) processAutopatchForSingleTag(tagRuns []int) (bool, error) return false, err } - err = s.deleteIssueMatchesOfDisappearedInstances(disappearedInstances) + err = s.deleteIssueMatchesOfDisappearedInstances(ctx, disappearedInstances) if err != nil { return false, err } versionsOfDisappearedInstances, err := s.getVersionIdsOfDisappearedInstances( + ctx, disappearedInstances, ) if err != nil { @@ -132,18 +134,19 @@ func (s *SqlDatabase) processAutopatchForSingleTag(tagRuns []int) (bool, error) } componentsOfDisappearedInstances, err := s.getComponentIdsOfDisappearedInstances( + ctx, versionsOfDisappearedInstances, ) if err != nil { return false, err } - err = s.deleteVersionsOfDisappearedInstances(versionsOfDisappearedInstances) + err = s.deleteVersionsOfDisappearedInstances(ctx, versionsOfDisappearedInstances) if err != nil { return false, err } - err = s.deleteComponentsOfDisappearedInstances(componentsOfDisappearedInstances) + err = s.deleteComponentsOfDisappearedInstances(ctx, componentsOfDisappearedInstances) if err != nil { return false, err } @@ -204,11 +207,11 @@ func (s *SqlDatabase) insertPatches(patches map[patchInfo]struct{}) error { return nil } -func (s *SqlDatabase) deleteIssueMatchesOfDisappearedInstances(disappearedInstances []int) error { +func (s *SqlDatabase) deleteIssueMatchesOfDisappearedInstances(ctx context.Context, disappearedInstances []int) error { for _, di := range disappearedInstances { issueMatchFilter := entity.IssueMatchFilter{ComponentInstanceId: []*int64{new(int64(di))}} - issueMatchIds, err := s.GetAllIssueMatchIds(&issueMatchFilter) + issueMatchIds, err := s.GetAllIssueMatchIds(ctx, &issueMatchFilter) if err != nil { return err } @@ -224,6 +227,7 @@ func (s *SqlDatabase) deleteIssueMatchesOfDisappearedInstances(disappearedInstan } func (s *SqlDatabase) getVersionIdsOfDisappearedInstances( + ctx context.Context, disappearedInstances []int, ) (map[int64]struct{}, error) { idsDisappeared := lo.Map(disappearedInstances, func(v int, _ int) *int64 { @@ -233,7 +237,7 @@ func (s *SqlDatabase) getVersionIdsOfDisappearedInstances( cif := entity.ComponentInstanceFilter{Id: idsDisappeared} - res, err := s.GetComponentInstances(&cif, nil) + res, err := s.GetComponentInstances(ctx, &cif, nil) if err != nil { return nil, err } @@ -247,6 +251,7 @@ func (s *SqlDatabase) getVersionIdsOfDisappearedInstances( } func (s *SqlDatabase) getComponentIdsOfDisappearedInstances( + ctx context.Context, versions map[int64]struct{}, ) (map[int64]struct{}, error) { var versionsToFilter []*int64 @@ -256,7 +261,7 @@ func (s *SqlDatabase) getComponentIdsOfDisappearedInstances( cvf := entity.ComponentVersionFilter{Id: versionsToFilter} - res, err := s.GetComponentVersions(&cvf, nil) + res, err := s.GetComponentVersions(ctx, &cvf, nil) if err != nil { return nil, err } @@ -282,12 +287,13 @@ func (s *SqlDatabase) deleteVersionIssuesOfDisappearedInstances( } func (s *SqlDatabase) deleteVersionsOfDisappearedInstances( + ctx context.Context, versionIdsOfDisappearedInstances map[int64]struct{}, ) error { for vIdDi := range versionIdsOfDisappearedInstances { cif := entity.ComponentInstanceFilter{ComponentVersionId: []*int64{&vIdDi}} - res, err := s.GetComponentInstances(&cif, nil) + res, err := s.GetComponentInstances(ctx, &cif, nil) if err != nil { return err } @@ -303,12 +309,13 @@ func (s *SqlDatabase) deleteVersionsOfDisappearedInstances( } func (s *SqlDatabase) deleteComponentsOfDisappearedInstances( + ctx context.Context, componentIdsOfDisappearedInstances map[int64]struct{}, ) error { for cIdDi := range componentIdsOfDisappearedInstances { cvf := entity.ComponentVersionFilter{ComponentId: []*int64{&cIdDi}} - res, err := s.GetComponentVersions(&cvf, nil) + res, err := s.GetComponentVersions(ctx, &cvf, nil) if err != nil { return err } diff --git a/internal/database/mariadb/autopatch_test.go b/internal/database/mariadb/autopatch_test.go index 752397cfc..d1f75e497 100644 --- a/internal/database/mariadb/autopatch_test.go +++ b/internal/database/mariadb/autopatch_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "fmt" "strings" "time" @@ -484,7 +485,7 @@ var _ = Describe("Autopatch", Label("database", "Autopatch"), func() { apt.Run( tag, t, - func(db *mariadb.SqlDatabase) (bool, error) { return db.Autopatch() }, + func(db *mariadb.SqlDatabase) (bool, error) { return db.Autopatch(context.Background()) }, ) } }) diff --git a/internal/database/mariadb/component.go b/internal/database/mariadb/component.go index 3d21347e0..99b7918ca 100644 --- a/internal/database/mariadb/component.go +++ b/internal/database/mariadb/component.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -179,6 +180,7 @@ func (s *SqlDatabase) getComponentColumns(order []entity.Order) string { } func (s *SqlDatabase) buildComponentStatement( + ctx context.Context, baseQuery string, filter *entity.ComponentFilter, withCursor bool, @@ -206,7 +208,7 @@ func (s *SqlDatabase) buildComponentStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -225,6 +227,7 @@ func (s *SqlDatabase) buildComponentStatement( } func (s *SqlDatabase) GetAllComponentCursors( + ctx context.Context, filter *entity.ComponentFilter, order []entity.Order, ) ([]string, error) { @@ -243,7 +246,7 @@ func (s *SqlDatabase) GetAllComponentCursors( columns := s.getComponentColumns(order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s") - stmt, filterParameters, err := s.buildComponentStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildComponentStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err } @@ -255,6 +258,7 @@ func (s *SqlDatabase) GetAllComponentCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -281,6 +285,7 @@ func (s *SqlDatabase) GetAllComponentCursors( } func (s *SqlDatabase) GetComponents( + ctx context.Context, filter *entity.ComponentFilter, order []entity.Order, ) ([]entity.ComponentResult, error) { @@ -299,7 +304,7 @@ func (s *SqlDatabase) GetComponents( columns := s.getComponentColumns(order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s", "%s") - stmt, filterParameters, err := s.buildComponentStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildComponentStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -311,6 +316,7 @@ func (s *SqlDatabase) GetComponents( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -334,7 +340,7 @@ func (s *SqlDatabase) GetComponents( ) } -func (s *SqlDatabase) CountComponents(filter *entity.ComponentFilter) (int64, error) { +func (s *SqlDatabase) CountComponents(ctx context.Context, filter *entity.ComponentFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountComponents", }) @@ -347,6 +353,7 @@ func (s *SqlDatabase) CountComponents(filter *entity.ComponentFilter) (int64, er ` stmt, filterParameters, err := s.buildComponentStatement( + ctx, baseQuery, filter, false, @@ -363,10 +370,11 @@ func (s *SqlDatabase) CountComponents(filter *entity.ComponentFilter) (int64, er } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CountComponentVulnerabilities( + ctx context.Context, filter *entity.ComponentFilter, ) ([]entity.IssueSeverityCounts, error) { l := logrus.WithFields(logrus.Fields{ @@ -424,7 +432,7 @@ func (s *SqlDatabase) CountComponentVulnerabilities( query = fmt.Sprintf("%s %s", query, groupBy) - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -444,6 +452,7 @@ func (s *SqlDatabase) CountComponentVulnerabilities( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -465,7 +474,7 @@ func (s *SqlDatabase) DeleteComponent(id int64, userId int64) error { return componentObject.Delete(s.db, id, userId) } -func (s *SqlDatabase) GetComponentCcrns(filter *entity.ComponentFilter) ([]string, error) { +func (s *SqlDatabase) GetComponentCcrns(ctx context.Context, filter *entity.ComponentFilter) ([]string, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetComponentCcrns", @@ -488,7 +497,7 @@ func (s *SqlDatabase) GetComponentCcrns(filter *entity.ComponentFilter) ([]strin } // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildComponentStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildComponentStatement(ctx, baseQuery, filter, false, order, l) if err != nil { l.Error("Error preparing statement: ", err) return nil, err @@ -501,7 +510,7 @@ func (s *SqlDatabase) GetComponentCcrns(filter *entity.ComponentFilter) ([]strin }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err diff --git a/internal/database/mariadb/component_instance.go b/internal/database/mariadb/component_instance.go index f9ee9f982..ccfb6f40c 100644 --- a/internal/database/mariadb/component_instance.go +++ b/internal/database/mariadb/component_instance.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "fmt" @@ -277,6 +278,7 @@ func ensureComponentInstanceFilter( } func (s *SqlDatabase) buildComponentInstanceStatement( + ctx context.Context, baseQuery string, filter *entity.ComponentInstanceFilter, withCursor bool, @@ -304,7 +306,7 @@ func (s *SqlDatabase) buildComponentInstanceStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -327,6 +329,7 @@ func (s *SqlDatabase) buildComponentInstanceStatement( } func (s *SqlDatabase) GetComponentInstances( + ctx context.Context, filter *entity.ComponentInstanceFilter, order []entity.Order, ) ([]entity.ComponentInstanceResult, error) { @@ -341,6 +344,7 @@ func (s *SqlDatabase) GetComponentInstances( ` stmt, filterParameters, err := s.buildComponentInstanceStatement( + ctx, baseQuery, filter, true, @@ -358,6 +362,7 @@ func (s *SqlDatabase) GetComponentInstances( }() results, err := performListScan( + ctx, stmt, filterParameters, l, @@ -384,6 +389,7 @@ func (s *SqlDatabase) GetComponentInstances( } func (s *SqlDatabase) GetAllComponentInstanceCursors( + ctx context.Context, filter *entity.ComponentInstanceFilter, order []entity.Order, ) ([]string, error) { @@ -399,6 +405,7 @@ func (s *SqlDatabase) GetAllComponentInstanceCursors( ` stmt, filterParameters, err := s.buildComponentInstanceStatement( + ctx, baseQuery, filter, false, @@ -410,6 +417,7 @@ func (s *SqlDatabase) GetAllComponentInstanceCursors( } rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -432,6 +440,7 @@ func (s *SqlDatabase) GetAllComponentInstanceCursors( } func (s *SqlDatabase) CountComponentInstances( + ctx context.Context, filter *entity.ComponentInstanceFilter, ) (int64, error) { l := logrus.WithFields(logrus.Fields{ @@ -447,6 +456,7 @@ func (s *SqlDatabase) CountComponentInstances( ` stmt, filterParameters, err := s.buildComponentInstanceStatement( + ctx, baseQuery, filter, false, @@ -463,7 +473,7 @@ func (s *SqlDatabase) CountComponentInstances( } }() - count, err := performCountScan(stmt, filterParameters, l) + count, err := performCountScan(ctx, stmt, filterParameters, l) if err != nil { return -1, fmt.Errorf("failed to count ComponentInstances: %w", err) } @@ -486,6 +496,7 @@ func (s *SqlDatabase) DeleteComponentInstance(id int64, userId int64) error { } func (s *SqlDatabase) getComponentInstanceAttr( + ctx context.Context, attrName string, filter *entity.ComponentInstanceFilter, ) ([]string, error) { @@ -512,6 +523,7 @@ func (s *SqlDatabase) getComponentInstanceAttr( // Builds full statement with possible joins and filters stmt, filterParameters, err := s.buildComponentInstanceStatement( + ctx, baseQuery, filter, false, @@ -533,7 +545,7 @@ func (s *SqlDatabase) getComponentInstanceAttr( }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { return nil, fmt.Errorf( "failed to execute ComponentInstance attribute query for %s: %w", @@ -573,8 +585,8 @@ func (s *SqlDatabase) getComponentInstanceAttr( return attrVal, nil } -func (s *SqlDatabase) GetCcrn(filter *entity.ComponentInstanceFilter) ([]string, error) { - ccrns, err := s.getComponentInstanceAttr("ccrn", filter) +func (s *SqlDatabase) GetCcrn(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + ccrns, err := s.getComponentInstanceAttr(ctx, "ccrn", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance CCRNs: %w", err) } @@ -582,8 +594,8 @@ func (s *SqlDatabase) GetCcrn(filter *entity.ComponentInstanceFilter) ([]string, return ccrns, nil } -func (s *SqlDatabase) GetRegion(filter *entity.ComponentInstanceFilter) ([]string, error) { - regions, err := s.getComponentInstanceAttr("region", filter) +func (s *SqlDatabase) GetRegion(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + regions, err := s.getComponentInstanceAttr(ctx, "region", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance regions: %w", err) } @@ -591,8 +603,8 @@ func (s *SqlDatabase) GetRegion(filter *entity.ComponentInstanceFilter) ([]strin return regions, nil } -func (s *SqlDatabase) GetCluster(filter *entity.ComponentInstanceFilter) ([]string, error) { - clusters, err := s.getComponentInstanceAttr("cluster", filter) +func (s *SqlDatabase) GetCluster(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + clusters, err := s.getComponentInstanceAttr(ctx, "cluster", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance clusters: %w", err) } @@ -600,8 +612,8 @@ func (s *SqlDatabase) GetCluster(filter *entity.ComponentInstanceFilter) ([]stri return clusters, nil } -func (s *SqlDatabase) GetNamespace(filter *entity.ComponentInstanceFilter) ([]string, error) { - namespaces, err := s.getComponentInstanceAttr("namespace", filter) +func (s *SqlDatabase) GetNamespace(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + namespaces, err := s.getComponentInstanceAttr(ctx, "namespace", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance namespaces: %w", err) } @@ -609,8 +621,8 @@ func (s *SqlDatabase) GetNamespace(filter *entity.ComponentInstanceFilter) ([]st return namespaces, nil } -func (s *SqlDatabase) GetDomain(filter *entity.ComponentInstanceFilter) ([]string, error) { - domains, err := s.getComponentInstanceAttr("domain", filter) +func (s *SqlDatabase) GetDomain(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + domains, err := s.getComponentInstanceAttr(ctx, "domain", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance domains: %w", err) } @@ -618,8 +630,8 @@ func (s *SqlDatabase) GetDomain(filter *entity.ComponentInstanceFilter) ([]strin return domains, nil } -func (s *SqlDatabase) GetProject(filter *entity.ComponentInstanceFilter) ([]string, error) { - projects, err := s.getComponentInstanceAttr("project", filter) +func (s *SqlDatabase) GetProject(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + projects, err := s.getComponentInstanceAttr(ctx, "project", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance projects: %w", err) } @@ -627,8 +639,8 @@ func (s *SqlDatabase) GetProject(filter *entity.ComponentInstanceFilter) ([]stri return projects, nil } -func (s *SqlDatabase) GetPod(filter *entity.ComponentInstanceFilter) ([]string, error) { - pods, err := s.getComponentInstanceAttr("pod", filter) +func (s *SqlDatabase) GetPod(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + pods, err := s.getComponentInstanceAttr(ctx, "pod", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance pods: %w", err) } @@ -636,8 +648,8 @@ func (s *SqlDatabase) GetPod(filter *entity.ComponentInstanceFilter) ([]string, return pods, nil } -func (s *SqlDatabase) GetContainer(filter *entity.ComponentInstanceFilter) ([]string, error) { - containers, err := s.getComponentInstanceAttr("container", filter) +func (s *SqlDatabase) GetContainer(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + containers, err := s.getComponentInstanceAttr(ctx, "container", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance containers: %w", err) } @@ -645,8 +657,8 @@ func (s *SqlDatabase) GetContainer(filter *entity.ComponentInstanceFilter) ([]st return containers, nil } -func (s *SqlDatabase) GetType(filter *entity.ComponentInstanceFilter) ([]string, error) { - types, err := s.getComponentInstanceAttr("type", filter) +func (s *SqlDatabase) GetType(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + types, err := s.getComponentInstanceAttr(ctx, "type", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance types: %w", err) } @@ -655,9 +667,10 @@ func (s *SqlDatabase) GetType(filter *entity.ComponentInstanceFilter) ([]string, } func (s *SqlDatabase) GetComponentInstanceParent( + ctx context.Context, filter *entity.ComponentInstanceFilter, ) ([]string, error) { - parents, err := s.getComponentInstanceAttr("parent_id", filter) + parents, err := s.getComponentInstanceAttr(ctx, "parent_id", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance parents: %w", err) } @@ -665,8 +678,8 @@ func (s *SqlDatabase) GetComponentInstanceParent( return parents, nil } -func (s *SqlDatabase) GetContext(filter *entity.ComponentInstanceFilter) ([]string, error) { - contexts, err := s.getComponentInstanceAttr("context", filter) +func (s *SqlDatabase) GetContext(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error) { + contexts, err := s.getComponentInstanceAttr(ctx, "context", filter) if err != nil { return nil, fmt.Errorf("failed to get ComponentInstance contexts: %w", err) } diff --git a/internal/database/mariadb/component_instance_test.go b/internal/database/mariadb/component_instance_test.go index b6b9515c4..adabfde0f 100644 --- a/internal/database/mariadb/component_instance_test.go +++ b/internal/database/mariadb/component_instance_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "math" "sort" @@ -35,7 +36,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), When("Getting ComponentInstances", Label("GetComponentInstance"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetComponentInstances(nil, nil) + res, err := db.GetComponentInstances(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -51,7 +52,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetComponentInstances(nil, nil) + res, err := db.GetComponentInstances(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -129,7 +130,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), Id: []*int64{&ci.Id.Int64}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -152,7 +153,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), IssueMatchId: []*int64{&rnd.Id.Int64}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -173,7 +174,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), ServiceId: []*int64{&cir.ServiceId.Int64}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -200,7 +201,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), ServiceCcrn: []*string{lo.ToPtr(service.CCRN.String)}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -223,7 +224,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), ComponentVersionId: []*int64{&cir.ComponentVersionId.Int64}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -255,7 +256,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), ComponentVersionVersion: []*string{&cvr.Version.String}, } - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -279,7 +280,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), ) filter := &entity.ComponentInstanceFilter{IssueMatchId: ids} - entries, err := db.GetComponentInstances(filter, nil) + entries, err := db.GetComponentInstances(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -325,7 +326,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), When("Counting ComponentInstances", Label("CountComponentInstance"), func() { Context("and the database is empty", func() { It("returns a correct totalCount without an error", func() { - c, err := db.CountComponentInstances(nil) + c, err := db.CountComponentInstances(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -346,7 +347,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountComponentInstances(nil) + c, err := db.CountComponentInstances(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -367,7 +368,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), }, IssueMatchId: nil, } - c, err := db.CountComponentInstances(filter) + c, err := db.CountComponentInstances(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -393,7 +394,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), }, IssueMatchId: ids, } - entries, err := db.CountComponentInstances(filter) + entries, err := db.CountComponentInstances(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -535,7 +536,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), Id: []*int64{&componentInstance.Id}, } - ci, err := db.GetComponentInstances(componentInstanceFilter, nil) + ci, err := db.GetComponentInstances(context.Background(), componentInstanceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -579,7 +580,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), Id: []*int64{&componentInstance.Id}, } - ci, err := db.GetComponentInstances(componentInstanceFilter, nil) + ci, err := db.GetComponentInstances(context.Background(), componentInstanceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -630,7 +631,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), Id: []*int64{&componentInstance.Id}, } - ci, err := db.GetComponentInstances(componentInstanceFilter, nil) + ci, err := db.GetComponentInstances(context.Background(), componentInstanceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -655,7 +656,7 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), order []entity.Order, verifyFunc func(res []entity.ComponentInstanceResult), ) { - res, err := db.GetComponentInstances(nil, order) + res, err := db.GetComponentInstances(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1819,9 +1820,9 @@ var _ = Describe("ComponentInstance - ", Label("database", "ComponentInstance"), }) func canPerformComponentInstanceQuery[T any]( - getFunc func(filter *entity.ComponentInstanceFilter) ([]T, error), + getFunc func(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]T, error), ) { - res, err := getFunc(nil) + res, err := getFunc(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -1833,10 +1834,10 @@ func canPerformComponentInstanceQuery[T any]( } func canFetchComponentInstanceQueryItems( - getFunc func(filter *entity.ComponentInstanceFilter) ([]string, error), + getFunc func(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error), expectedItems []string, ) { - res, err := getFunc(nil) + res, err := getFunc(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1854,11 +1855,11 @@ func canFetchComponentInstanceQueryItems( } func issueComponentInstanceAttrFilterWithExpect( - getAttrFunc func(filter *entity.ComponentInstanceFilter) ([]string, error), + getAttrFunc func(ctx context.Context, filter *entity.ComponentInstanceFilter) ([]string, error), cifilter *entity.ComponentInstanceFilter, expectedAttrVal []string, ) { - res, err := getAttrFunc(cifilter) + res, err := getAttrFunc(context.Background(), cifilter) By("throwing no error", func() { Expect(err).Should(BeNil()) diff --git a/internal/database/mariadb/component_test.go b/internal/database/mariadb/component_test.go index 377aa6bc8..5c94915f8 100644 --- a/internal/database/mariadb/component_test.go +++ b/internal/database/mariadb/component_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "sort" "github.com/cloudoperators/heureka/internal/database/mariadb" @@ -36,7 +37,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { // This is tricky because expectedComponents is not check against content, just length, so it should be expectedComponentsLen insted of expectedComponents // TODO: decide to use Len or at least check ids of expected components testGetComponents := func(filter *entity.ComponentFilter, order []entity.Order, expectedComponents []mariadb.ComponentRow, check func(entries []entity.ComponentResult)) { - res, err := db.GetComponents(filter, order) + res, err := db.GetComponents(context.Background(), filter, order) Expect(err).To(BeNil(), "GetComponents should not error") Expect( len(res), @@ -283,7 +284,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { }) When("Counting Components", Label("CountComponents"), func() { testCountComponents := func(filter *entity.ComponentFilter, expectedCount int) { - c, err := db.CountComponents(filter) + c, err := db.CountComponents(context.Background(), filter) Expect(err).To(BeNil(), "CountComponents should not error") Expect( c, @@ -373,7 +374,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { Id: []*int64{&component.Id}, } - c, err := db.GetComponents(componentFilter, []entity.Order{}) + c, err := db.GetComponents(context.Background(), componentFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -423,7 +424,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { Id: []*int64{&component.Id}, } - c, err := db.GetComponents(componentFilter, []entity.Order{}) + c, err := db.GetComponents(context.Background(), componentFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -453,7 +454,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { Id: []*int64{&component.Id}, } - c, err := db.GetComponents(componentFilter, []entity.Order{}) + c, err := db.GetComponents(context.Background(), componentFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -490,7 +491,7 @@ var _ = Describe("Component", Label("database", "Component"), func() { Id: []*int64{&component.Id}, } - c, err := db.GetComponents(componentFilter, []entity.Order{}) + c, err := db.GetComponents(context.Background(), componentFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -525,7 +526,7 @@ var _ = Describe("Ordering Components", Label("ComponentOrdering"), func() { order []entity.Order, verifyFunc func(res []entity.ComponentResult), ) { - res, err := db.GetComponents(nil, order) + res, err := db.GetComponents(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -625,7 +626,7 @@ var _ = Describe("Ordering Components", Label("ComponentOrdering"), func() { {By: entity.LowCount, Direction: entity.OrderDirectionDesc}, {By: entity.NoneCount, Direction: entity.OrderDirectionDesc}, } - components, err := db.GetComponents(componentFilter, order) + components, err := db.GetComponents(context.Background(), componentFilter, order) Expect(err).To(BeNil()) Expect(components[0].Id).To(BeEquivalentTo(1)) Expect(components[1].Id).To(BeEquivalentTo(2)) @@ -641,7 +642,7 @@ var _ = Describe("Ordering Components", Label("ComponentOrdering"), func() { {By: entity.LowCount, Direction: entity.OrderDirectionAsc}, {By: entity.NoneCount, Direction: entity.OrderDirectionAsc}, } - components, err := db.GetComponents(componentFilter, order) + components, err := db.GetComponents(context.Background(), componentFilter, order) Expect(err).To(BeNil()) Expect(components[0].Id).To(BeEquivalentTo(5)) Expect(components[1].Id).To(BeEquivalentTo(4)) diff --git a/internal/database/mariadb/component_version.go b/internal/database/mariadb/component_version.go index c52b19212..4f6b77958 100644 --- a/internal/database/mariadb/component_version.go +++ b/internal/database/mariadb/component_version.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/samber/lo" @@ -241,6 +242,7 @@ func (s *SqlDatabase) getComponentVersionColumns(order []entity.Order) string { } func (s *SqlDatabase) buildComponentVersionStatement( + ctx context.Context, baseQuery string, filter *entity.ComponentVersionFilter, withCursor bool, @@ -270,7 +272,7 @@ func (s *SqlDatabase) buildComponentVersionStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -289,6 +291,7 @@ func (s *SqlDatabase) buildComponentVersionStatement( } func (s *SqlDatabase) GetAllComponentVersionCursors( + ctx context.Context, filter *entity.ComponentVersionFilter, order []entity.Order, ) ([]string, error) { @@ -304,6 +307,7 @@ func (s *SqlDatabase) GetAllComponentVersionCursors( ` stmt, filterParameters, err := s.buildComponentVersionStatement( + ctx, baseQuery, filter, false, @@ -315,6 +319,7 @@ func (s *SqlDatabase) GetAllComponentVersionCursors( } rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -341,6 +346,7 @@ func (s *SqlDatabase) GetAllComponentVersionCursors( } func (s *SqlDatabase) GetComponentVersions( + ctx context.Context, filter *entity.ComponentVersionFilter, order []entity.Order, ) ([]entity.ComponentVersionResult, error) { @@ -358,6 +364,7 @@ func (s *SqlDatabase) GetComponentVersions( filter = ensureComponentVersionFilter(filter) stmt, filterParameters, err := s.buildComponentVersionStatement( + ctx, baseQuery, filter, true, @@ -375,6 +382,7 @@ func (s *SqlDatabase) GetComponentVersions( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -400,7 +408,7 @@ func (s *SqlDatabase) GetComponentVersions( ) } -func (s *SqlDatabase) CountComponentVersions(filter *entity.ComponentVersionFilter) (int64, error) { +func (s *SqlDatabase) CountComponentVersions(ctx context.Context, filter *entity.ComponentVersionFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountComponentVersions", }) @@ -413,6 +421,7 @@ func (s *SqlDatabase) CountComponentVersions(filter *entity.ComponentVersionFilt ` stmt, filterParameters, err := s.buildComponentVersionStatement( + ctx, baseQuery, filter, false, @@ -429,7 +438,7 @@ func (s *SqlDatabase) CountComponentVersions(filter *entity.ComponentVersionFilt } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CreateComponentVersion( diff --git a/internal/database/mariadb/component_version_test.go b/internal/database/mariadb/component_version_test.go index a35c0f9f4..45490846e 100644 --- a/internal/database/mariadb/component_version_test.go +++ b/internal/database/mariadb/component_version_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "fmt" "math/rand" "sort" @@ -32,7 +33,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func When("Getting ComponentVersions", Label("GetComponentVersions"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetComponentVersions(nil, nil) + res, err := db.GetComponentVersions(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -49,7 +50,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetComponentVersions(nil, nil) + res, err := db.GetComponentVersions(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -100,7 +101,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func Id: []*int64{&cv.Id.Int64}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -130,7 +131,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{IssueId: []*int64{&issueRow.Id.Int64}} - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -158,7 +159,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func ComponentId: []*int64{&componentRow.Id.Int64}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -175,7 +176,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{Version: []*string{&cv.Version.String}} - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -193,7 +194,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{IssueRepositoryId: []*int64{&iv.IssueRepositoryId.Int64}} - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -219,7 +220,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func ComponentCCRN: []*string{&componentCCRN}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -245,7 +246,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{ServiceId: []*int64{&s.Id.Int64}} - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -269,7 +270,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{ServiceCCRN: []*string{&s.CCRN.String}} - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -293,7 +294,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{Tag: []*string{&tagToFilterBy}} // Execute the query - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -326,7 +327,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{ Repository: []*string{&cv.Repository.String}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -344,7 +345,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func filter := &entity.ComponentVersionFilter{ Organization: []*string{&cv.Organization.String}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -364,7 +365,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func EndOfLife: []*bool{&endOfLifeAsFalse}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -384,7 +385,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func EndOfLife: []*bool{&endOfLifeAsTrue}, } - entries, err := db.GetComponentVersions(filter, nil) + entries, err := db.GetComponentVersions(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -434,7 +435,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func When("Counting ComponentVersions", Label("CountComponentVersions"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountComponentVersions(nil) + c, err := db.CountComponentVersions(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -455,7 +456,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountComponentVersions(nil) + c, err := db.CountComponentVersions(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -474,7 +475,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func After: nil, }, } - c, err := db.CountComponentVersions(filter) + c, err := db.CountComponentVersions(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -520,7 +521,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func Id: []*int64{&componentVersion.Id}, } - cv, err := db.GetComponentVersions(componentVersionFilter, nil) + cv, err := db.GetComponentVersions(context.Background(), componentVersionFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -555,7 +556,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func Id: []*int64{&componentVersion.Id}, } - cv, err := db.GetComponentVersions(componentVersionFilter, nil) + cv, err := db.GetComponentVersions(context.Background(), componentVersionFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -582,7 +583,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func Id: []*int64{&componentVersion.Id}, } - cv, err := db.GetComponentVersions(componentVersionFilter, nil) + cv, err := db.GetComponentVersions(context.Background(), componentVersionFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -621,7 +622,7 @@ var _ = Describe("ComponentVersion", Label("database", "ComponentVersion"), func // Retrieve all component versions and find our updated one manually // This avoids relying on the filter functionality - allVersions, err := db.GetComponentVersions(nil, nil) + allVersions, err := db.GetComponentVersions(context.Background(), nil, nil) By("throwing no error during retrieval", func() { Expect(err).To(BeNil()) @@ -684,7 +685,7 @@ var _ = Describe("Ordering ComponentVersions", func() { order []entity.Order, verifyFunc func(res []entity.ComponentVersionResult), ) { - res, err := db.GetComponentVersions(nil, order) + res, err := db.GetComponentVersions(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -741,7 +742,7 @@ var _ = Describe("Ordering ComponentVersions", func() { {By: entity.LowCount, Direction: entity.OrderDirectionDesc}, {By: entity.NoneCount, Direction: entity.OrderDirectionDesc}, } - cvs, err := db.GetComponentVersions(nil, order) + cvs, err := db.GetComponentVersions(context.Background(), nil, order) Expect(err).To(BeNil()) Expect(cvs[0].Id).To(BeEquivalentTo(3)) Expect(cvs[1].Id).To(BeEquivalentTo(8)) @@ -762,7 +763,7 @@ var _ = Describe("Ordering ComponentVersions", func() { {By: entity.LowCount, Direction: entity.OrderDirectionAsc}, {By: entity.NoneCount, Direction: entity.OrderDirectionAsc}, } - cvs, err := db.GetComponentVersions(nil, order) + cvs, err := db.GetComponentVersions(context.Background(), nil, order) Expect(err).To(BeNil()) Expect(cvs[0].Id).To(BeEquivalentTo(9)) Expect(cvs[1].Id).To(BeEquivalentTo(10)) diff --git a/internal/database/mariadb/database.go b/internal/database/mariadb/database.go index ed8fcaccb..6cbfd508a 100644 --- a/internal/database/mariadb/database.go +++ b/internal/database/mariadb/database.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "fmt" @@ -204,12 +205,13 @@ func performExec[T any](s *SqlDatabase, query string, item T, l *logrus.Entry) ( } func performListScan[T DatabaseRow, E entity.HeurekaEntity | DatabaseRow]( + ctx context.Context, stmt Stmt, filterParameters []any, l *logrus.Entry, listBuilder func([]E, T) []E, ) ([]E, error) { - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { msg := "Error while performing Query from prepared Statement" l.WithFields( @@ -248,6 +250,17 @@ func performListScan[T DatabaseRow, E entity.HeurekaEntity | DatabaseRow]( listEntries = listBuilder(listEntries, row) } + if err := rows.Err(); err != nil { + msg := "Error while iterating over result rows" + l.WithFields( + logrus.Fields{ + "error": err.Error(), + "parameters": filterParameters, + }).Error(msg) + + return nil, fmt.Errorf("%s", msg) + } + l.WithFields( logrus.Fields{ "count": len(listEntries), @@ -256,8 +269,8 @@ func performListScan[T DatabaseRow, E entity.HeurekaEntity | DatabaseRow]( return listEntries, nil } -func performIdScan(stmt Stmt, filterParameters []any, l *logrus.Entry) ([]int64, error) { - rows, err := stmt.Queryx(filterParameters...) +func performIdScan(ctx context.Context, stmt Stmt, filterParameters []any, l *logrus.Entry) ([]int64, error) { + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { msg := "Error while performing query with prepared Statement" l.WithFields( @@ -305,8 +318,8 @@ func performIdScan(stmt Stmt, filterParameters []any, l *logrus.Entry) ([]int64, return listEntries, nil } -func performCountScan(stmt Stmt, filterParameters []any, l *logrus.Entry) (int64, error) { - rows, err := stmt.Queryx(filterParameters...) +func performCountScan(ctx context.Context, stmt Stmt, filterParameters []any, l *logrus.Entry) (int64, error) { + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { msg := "Error while performing query with prepared Statement" l.WithFields( diff --git a/internal/database/mariadb/db.go b/internal/database/mariadb/db.go index 52350986d..03603b1b9 100644 --- a/internal/database/mariadb/db.go +++ b/internal/database/mariadb/db.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "github.com/cloudoperators/heureka/internal/util" @@ -13,6 +14,7 @@ import ( type Stmt interface { Close() error Queryx(args ...any) (*sqlx.Rows, error) + QueryxContext(ctx context.Context, args ...any) (*sqlx.Rows, error) } type NamedStmt interface { @@ -33,9 +35,11 @@ type Db interface { Get(dest any, query string, args ...any) error GetDbInstance() *sql.DB Preparex(query string) (Stmt, error) + PreparexContext(ctx context.Context, query string) (Stmt, error) PrepareNamed(query string) (NamedStmt, error) Select(dest any, query string, args ...any) error Query(query string, args ...any) (SqlRows, error) + QueryContext(ctx context.Context, query string, args ...any) (SqlRows, error) QueryRow(query string, args ...any) *sql.Row } diff --git a/internal/database/mariadb/issue.go b/internal/database/mariadb/issue.go index 667978654..ba26b4ab6 100644 --- a/internal/database/mariadb/issue.go +++ b/internal/database/mariadb/issue.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "errors" "fmt" @@ -364,6 +365,7 @@ func getIssueQuery(baseQuery string, order []entity.Order, filter *entity.IssueF } func (s *SqlDatabase) buildIssueStatementWithCursor( + ctx context.Context, baseQuery string, filter *entity.IssueFilter, order []entity.Order, @@ -380,7 +382,7 @@ func (s *SqlDatabase) buildIssueStatementWithCursor( query := getIssueQueryWithCursor(baseQuery, order, ifilter, cursorFields) // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -400,6 +402,7 @@ func (s *SqlDatabase) buildIssueStatementWithCursor( } func (s *SqlDatabase) buildIssueStatement( + ctx context.Context, baseQuery string, filter *entity.IssueFilter, order []entity.Order, @@ -416,7 +419,7 @@ func (s *SqlDatabase) buildIssueStatement( query := getIssueQuery(baseQuery, order, ifilter) // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -436,6 +439,7 @@ func (s *SqlDatabase) buildIssueStatement( } func (s *SqlDatabase) GetIssuesWithAggregations( + ctx context.Context, filter *entity.IssueFilter, order []entity.Order, ) ([]entity.IssueResult, error) { @@ -514,7 +518,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( aggQuery := fmt.Sprintf(baseAggQuery, columns, joins, whereClause, cursorQuery, ord) query := fmt.Sprintf(baseQuery, ciQuery, aggQuery) - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -541,6 +545,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -571,7 +576,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( ) } -func (s *SqlDatabase) CountIssues(filter *entity.IssueFilter) (int64, error) { +func (s *SqlDatabase) CountIssues(ctx context.Context, filter *entity.IssueFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountIssues", }) @@ -583,7 +588,7 @@ func (s *SqlDatabase) CountIssues(filter *entity.IssueFilter) (int64, error) { ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(baseQuery, filter, []entity.Order{}, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, []entity.Order{}, l) if err != nil { return -1, err } @@ -594,10 +599,10 @@ func (s *SqlDatabase) CountIssues(filter *entity.IssueFilter) (int64, error) { } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } -func (s *SqlDatabase) CountIssueTypes(filter *entity.IssueFilter) (*entity.IssueTypeCounts, error) { +func (s *SqlDatabase) CountIssueTypes(ctx context.Context, filter *entity.IssueFilter) (*entity.IssueTypeCounts, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountIssueTypes", }) @@ -609,7 +614,7 @@ func (s *SqlDatabase) CountIssueTypes(filter *entity.IssueFilter) (*entity.Issue GROUP BY I.issue_type ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(baseQuery, filter, []entity.Order{}, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, []entity.Order{}, l) if err != nil { return nil, err } @@ -621,6 +626,7 @@ func (s *SqlDatabase) CountIssueTypes(filter *entity.IssueFilter) (*entity.Issue }() counts, err := performListScan( + ctx, stmt, filterParameters, l, @@ -649,6 +655,7 @@ func (s *SqlDatabase) CountIssueTypes(filter *entity.IssueFilter) (*entity.Issue } func (s *SqlDatabase) GetAllIssueCursors( + ctx context.Context, filter *entity.IssueFilter, order []entity.Order, ) ([]string, error) { @@ -663,7 +670,7 @@ func (s *SqlDatabase) GetAllIssueCursors( %s GROUP BY I.issue_id ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(baseQuery, filter, order, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, order, l) if err != nil { return nil, err } @@ -675,6 +682,7 @@ func (s *SqlDatabase) GetAllIssueCursors( }() rows, err := performListScan( + context.Background(), stmt, filterParameters, l, @@ -701,6 +709,7 @@ func (s *SqlDatabase) GetAllIssueCursors( } func (s *SqlDatabase) GetIssues( + ctx context.Context, filter *entity.IssueFilter, order []entity.Order, ) ([]entity.IssueResult, error) { @@ -717,7 +726,7 @@ func (s *SqlDatabase) GetIssues( filter = ensureIssueFilter(filter) - stmt, filterParameters, err := s.buildIssueStatementWithCursor(baseQuery, filter, order, l) + stmt, filterParameters, err := s.buildIssueStatementWithCursor(ctx, baseQuery, filter, order, l) if err != nil { return nil, err } @@ -729,6 +738,7 @@ func (s *SqlDatabase) GetIssues( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -852,7 +862,7 @@ func (s *SqlDatabase) RemoveAllIssuesFromComponentVersion(componentVersionId int return err } -func (s *SqlDatabase) GetIssueNames(filter *entity.IssueFilter) ([]string, error) { +func (s *SqlDatabase) GetIssueNames(ctx context.Context, filter *entity.IssueFilter) ([]string, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetIssueNames", @@ -873,7 +883,7 @@ func (s *SqlDatabase) GetIssueNames(filter *entity.IssueFilter) ([]string, error filter = ensureIssueFilter(filter) // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildIssueStatement(baseQuery, filter, order, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, order, l) if err != nil { l.Error("Error preparing statement: ", err) return nil, err @@ -886,7 +896,7 @@ func (s *SqlDatabase) GetIssueNames(filter *entity.IssueFilter) ([]string, error }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err diff --git a/internal/database/mariadb/issue_match.go b/internal/database/mariadb/issue_match.go index 257b09010..8485aa203 100644 --- a/internal/database/mariadb/issue_match.go +++ b/internal/database/mariadb/issue_match.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "time" @@ -277,6 +278,7 @@ func (s *SqlDatabase) getIssueMatchColumns(order []entity.Order) string { } func (s *SqlDatabase) buildIssueMatchStatement( + ctx context.Context, baseQuery string, filter *entity.IssueMatchFilter, withCursor bool, @@ -305,7 +307,7 @@ func (s *SqlDatabase) buildIssueMatchStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -323,7 +325,7 @@ func (s *SqlDatabase) buildIssueMatchStatement( return stmt, filterParameters, nil } -func (s *SqlDatabase) GetAllIssueMatchIds(filter *entity.IssueMatchFilter) ([]int64, error) { +func (s *SqlDatabase) GetAllIssueMatchIds(ctx context.Context, filter *entity.IssueMatchFilter) ([]int64, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetAllIssueMatchIds", @@ -336,6 +338,7 @@ func (s *SqlDatabase) GetAllIssueMatchIds(filter *entity.IssueMatchFilter) ([]in ` stmt, filterParameters, err := s.buildIssueMatchStatement( + ctx, baseQuery, filter, false, @@ -346,10 +349,11 @@ func (s *SqlDatabase) GetAllIssueMatchIds(filter *entity.IssueMatchFilter) ([]in return nil, err } - return performIdScan(stmt, filterParameters, l) + return performIdScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) GetAllIssueMatchCursors( + ctx context.Context, filter *entity.IssueMatchFilter, order []entity.Order, ) ([]string, error) { @@ -364,12 +368,13 @@ func (s *SqlDatabase) GetAllIssueMatchCursors( %s GROUP BY IM.issuematch_id ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueMatchStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildIssueMatchStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err } rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -398,6 +403,7 @@ func (s *SqlDatabase) GetAllIssueMatchCursors( } func (s *SqlDatabase) GetIssueMatches( + ctx context.Context, filter *entity.IssueMatchFilter, order []entity.Order, ) ([]entity.IssueMatchResult, error) { @@ -412,7 +418,7 @@ func (s *SqlDatabase) GetIssueMatches( %s %s GROUP BY IM.issuematch_id ORDER BY %s LIMIT ? ` - stmt, filterParameters, err := s.buildIssueMatchStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildIssueMatchStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -424,6 +430,7 @@ func (s *SqlDatabase) GetIssueMatches( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -451,7 +458,7 @@ func (s *SqlDatabase) GetIssueMatches( ) } -func (s *SqlDatabase) CountIssueMatches(filter *entity.IssueMatchFilter) (int64, error) { +func (s *SqlDatabase) CountIssueMatches(ctx context.Context, filter *entity.IssueMatchFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.CountIssueMatches", @@ -465,6 +472,7 @@ func (s *SqlDatabase) CountIssueMatches(filter *entity.IssueMatchFilter) (int64, ` stmt, filterParameters, err := s.buildIssueMatchStatement( + ctx, baseQuery, filter, false, @@ -475,7 +483,7 @@ func (s *SqlDatabase) CountIssueMatches(filter *entity.IssueMatchFilter) (int64, return -1, err } - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CreateIssueMatch(issueMatch *entity.IssueMatch) (*entity.IssueMatch, error) { diff --git a/internal/database/mariadb/issue_match_test.go b/internal/database/mariadb/issue_match_test.go index f3c3dd25b..10f4ba3cf 100644 --- a/internal/database/mariadb/issue_match_test.go +++ b/internal/database/mariadb/issue_match_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "database/sql" "sort" "time" @@ -33,7 +34,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { When("Getting All IssueMatch IDs", Label("GetAllIssueMatchIds"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetAllIssueMatchIds(nil) + res, err := db.GetAllIssueMatchIds(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -56,7 +57,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetAllIssueMatchIds(nil) + res, err := db.GetAllIssueMatchIds(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -90,7 +91,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Id: []*int64{&vmId}, } - entries, err := db.GetAllIssueMatchIds(filter) + entries, err := db.GetAllIssueMatchIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -118,7 +119,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetAllIssueMatchIds(filter) + entries, err := db.GetAllIssueMatchIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -147,7 +148,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { ComponentCCRN: []*string{&cRow.CCRN.String}, } - entries, err := db.GetAllIssueMatchIds(filter) + entries, err := db.GetAllIssueMatchIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -176,7 +177,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetAllIssueMatchIds(filter) + entries, err := db.GetAllIssueMatchIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -200,7 +201,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { searchStr := test.CutString(issueRow.PrimaryName.String, 2, 2, 5) filter := &entity.IssueMatchFilter{Search: []*string{&searchStr}} - entries, err := db.GetAllIssueMatchIds(filter) + entries, err := db.GetAllIssueMatchIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -221,7 +222,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { When("Getting IssueMatches", Label("GetIssueMatches"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetIssueMatches(nil, nil) + res, err := db.GetIssueMatches(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -242,7 +243,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetIssueMatches(nil, nil) + res, err := db.GetIssueMatches(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -288,7 +289,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Id: []*int64{&im.Id.Int64}, } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -315,7 +316,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -344,7 +345,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -381,7 +382,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -417,7 +418,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { } } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -466,7 +467,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { // fixture creation does not guarantee that a support group is always present if sgFound { - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -501,7 +502,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { ServiceOwnerUsername: []*string{&user.Name.String}, } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -530,7 +531,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { ServiceOwnerUniqueUserId: []*string{&user.UniqueUserID.String}, } - entries, err := db.GetIssueMatches(filter, nil) + entries, err := db.GetIssueMatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -583,7 +584,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Context("and using no filter", func() { DescribeTable("it returns correct count", func(x int) { _ = seeder.SeedDbWithNFakeData(x) - res, err := db.CountIssueMatches(nil) + res, err := db.CountIssueMatches(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -613,7 +614,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { After: &after, }, } - res, err := db.CountIssueMatches(filter) + res, err := db.CountIssueMatches(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -636,7 +637,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { imIds = append(imIds, e.Id.Int64) } } - count, err := db.CountIssueMatches(filter) + count, err := db.CountIssueMatches(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -683,7 +684,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Id: []*int64{&issueMatch.Id}, } - im, err := db.GetIssueMatches(issueMatchFilter, nil) + im, err := db.GetIssueMatches(context.Background(), issueMatchFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -736,7 +737,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Id: []*int64{&issueMatch.Id}, } - im, err := db.GetIssueMatches(issueMatchFilter, nil) + im, err := db.GetIssueMatches(context.Background(), issueMatchFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -783,7 +784,7 @@ var _ = Describe("IssueMatch", Label("database", "IssueMatch"), func() { Id: []*int64{&issueMatch.Id}, } - im, err := db.GetIssueMatches(issueMatchFilter, nil) + im, err := db.GetIssueMatches(context.Background(), issueMatchFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -814,7 +815,7 @@ var _ = Describe("Ordering IssueMatches", func() { order []entity.Order, verifyFunc func(res []entity.IssueMatchResult), ) { - res, err := db.GetIssueMatches(nil, order) + res, err := db.GetIssueMatches(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1226,7 +1227,7 @@ var _ = Describe("Using the Cursor on IssueMatches", func() { {By: entity.IssuePrimaryName, Direction: entity.OrderDirectionAsc}, {By: entity.IssueMatchTargetRemediationDate, Direction: entity.OrderDirectionAsc}, } - im, err := db.GetIssueMatches(&filter, order) + im, err := db.GetIssueMatches(context.Background(), &filter, order) Expect(err).To(BeNil()) Expect(im).To(HaveLen(1)) filterWithCursor := entity.IssueMatchFilter{ @@ -1234,7 +1235,7 @@ var _ = Describe("Using the Cursor on IssueMatches", func() { After: im[0].Cursor(), }, } - res, err := db.GetIssueMatches(&filterWithCursor, order) + res, err := db.GetIssueMatches(context.Background(), &filterWithCursor, order) Expect(err).To(BeNil()) Expect(res[0].Id).To(BeEquivalentTo(13)) Expect(res[1].Id).To(BeEquivalentTo(20)) diff --git a/internal/database/mariadb/issue_repository.go b/internal/database/mariadb/issue_repository.go index ee3e6d14a..33c6b3dcd 100644 --- a/internal/database/mariadb/issue_repository.go +++ b/internal/database/mariadb/issue_repository.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -104,6 +105,7 @@ func ensureIssueRepositoryFilter( } func (s *SqlDatabase) buildIssueRepositoryStatement( + ctx context.Context, baseQuery string, filter *entity.IssueRepositoryFilter, withCursor bool, @@ -130,7 +132,7 @@ func (s *SqlDatabase) buildIssueRepositoryStatement( query = fmt.Sprintf(baseQuery, joins, whereClause, ord) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -149,6 +151,7 @@ func (s *SqlDatabase) buildIssueRepositoryStatement( } func (s *SqlDatabase) GetAllIssueRepositoryCursors( + ctx context.Context, filter *entity.IssueRepositoryFilter, order []entity.Order, ) ([]string, error) { @@ -166,6 +169,7 @@ func (s *SqlDatabase) GetAllIssueRepositoryCursors( filter = ensureIssueRepositoryFilter(filter) stmt, filterParameters, err := s.buildIssueRepositoryStatement( + ctx, baseQuery, filter, false, @@ -183,6 +187,7 @@ func (s *SqlDatabase) GetAllIssueRepositoryCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -204,6 +209,7 @@ func (s *SqlDatabase) GetAllIssueRepositoryCursors( } func (s *SqlDatabase) GetIssueRepositories( + ctx context.Context, filter *entity.IssueRepositoryFilter, order []entity.Order, ) ([]entity.IssueRepositoryResult, error) { @@ -221,6 +227,7 @@ func (s *SqlDatabase) GetIssueRepositories( filter = ensureIssueRepositoryFilter(filter) stmt, filterParameters, err := s.buildIssueRepositoryStatement( + ctx, baseQuery, filter, true, @@ -238,6 +245,7 @@ func (s *SqlDatabase) GetIssueRepositories( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -257,7 +265,7 @@ func (s *SqlDatabase) GetIssueRepositories( ) } -func (s *SqlDatabase) CountIssueRepositories(filter *entity.IssueRepositoryFilter) (int64, error) { +func (s *SqlDatabase) CountIssueRepositories(ctx context.Context, filter *entity.IssueRepositoryFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountIssueRepositories", }) @@ -270,6 +278,7 @@ func (s *SqlDatabase) CountIssueRepositories(filter *entity.IssueRepositoryFilte ` stmt, filterParameters, err := s.buildIssueRepositoryStatement( + ctx, baseQuery, filter, false, @@ -286,7 +295,7 @@ func (s *SqlDatabase) CountIssueRepositories(filter *entity.IssueRepositoryFilte } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CreateIssueRepository( diff --git a/internal/database/mariadb/issue_repository_test.go b/internal/database/mariadb/issue_repository_test.go index e6e6292be..c9f87d793 100644 --- a/internal/database/mariadb/issue_repository_test.go +++ b/internal/database/mariadb/issue_repository_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/cloudoperators/heureka/internal/database/mariadb" "github.com/cloudoperators/heureka/internal/database/mariadb/test" "github.com/cloudoperators/heureka/internal/entity" @@ -28,7 +30,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() When("Getting IssueRepositories", Label("GetIssueRepositories"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetIssueRepositories(nil, nil) + res, err := db.GetIssueRepositories(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -45,7 +47,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetIssueRepositories(nil, nil) + res, err := db.GetIssueRepositories(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -94,7 +96,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() row := test.PickOne(seedCollection.IssueRepositoryRows) filter := &entity.IssueRepositoryFilter{Name: []*string{&row.Name.String}} - entries, err := db.GetIssueRepositories(filter, nil) + entries, err := db.GetIssueRepositories(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -124,7 +126,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() ServiceCCRN: []*string{&sRow.CCRN.String}, } - entries, err := db.GetIssueRepositories(filter, nil) + entries, err := db.GetIssueRepositories(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -150,7 +152,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() filter := &entity.IssueRepositoryFilter{ServiceId: []*int64{&sRow.Id.Int64}} - entries, err := db.GetIssueRepositories(filter, nil) + entries, err := db.GetIssueRepositories(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -198,7 +200,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() When("Counting IssueRepositories", Label("CountIssueRepositories"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountIssueRepositories(nil) + c, err := db.CountIssueRepositories(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -219,7 +221,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountIssueRepositories(nil) + c, err := db.CountIssueRepositories(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -238,7 +240,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() After: nil, }, } - c, err := db.CountIssueRepositories(filter) + c, err := db.CountIssueRepositories(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -269,7 +271,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() }, ServiceCCRN: []*string{&sRow.CCRN.String}, } - entries, err := db.CountIssueRepositories(filter) + entries, err := db.CountIssueRepositories(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -309,7 +311,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() Id: []*int64{&issueRepository.Id}, } - ir, err := db.GetIssueRepositories(issueRepositoryFilter, nil) + ir, err := db.GetIssueRepositories(context.Background(), issueRepositoryFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -356,7 +358,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() Id: []*int64{&issueRepository.Id}, } - ir, err := db.GetIssueRepositories(issueRepositoryFilter, nil) + ir, err := db.GetIssueRepositories(context.Background(), issueRepositoryFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -383,7 +385,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() Id: []*int64{&issueRepository.Id}, } - ir, err := db.GetIssueRepositories(issueRepositoryFilter, nil) + ir, err := db.GetIssueRepositories(context.Background(), issueRepositoryFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -417,7 +419,7 @@ var _ = Describe("IssueRepository", Label("database", "IssueRepository"), func() Id: []*int64{&issueRepository.Id}, } - ir, err := db.GetIssueRepositories(issueRepositoryFilter, nil) + ir, err := db.GetIssueRepositories(context.Background(), issueRepositoryFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) diff --git a/internal/database/mariadb/issue_test.go b/internal/database/mariadb/issue_test.go index 5d1afdc17..d05fae8b1 100644 --- a/internal/database/mariadb/issue_test.go +++ b/internal/database/mariadb/issue_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "database/sql" "fmt" "sort" @@ -36,7 +37,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { When("Getting Issues", Label("GetIssues"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetIssues(nil, nil) + res, err := db.GetIssues(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -53,7 +54,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetIssues(nil, nil) + res, err := db.GetIssues(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -112,7 +113,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { } filter := &entity.IssueFilter{ServiceCCRN: []*string{&row.CCRN.String}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -128,7 +129,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { nonExistingName := pkg_util.GenerateRandomString(40, nil) filter := &entity.IssueFilter{ServiceCCRN: []*string{&nonExistingName}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -150,7 +151,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { expectedIssues = lo.Uniq(expectedIssues) filter := &entity.IssueFilter{ServiceCCRN: serviceCcrns} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -163,7 +164,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { row := test.PickOne(seedCollection.IssueRows) filter := &entity.IssueFilter{Id: []*int64{&row.Id.Int64}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -188,7 +189,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { filter := &entity.IssueFilter{ServiceId: []*int64{&serviceRow.Id.Int64}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -226,7 +227,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { filter := &entity.IssueFilter{SupportGroupCCRN: []*string{&sgRow.CCRN.String}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -252,7 +253,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { filter := &entity.IssueFilter{ComponentVersionId: []*int64{&cvRow.Id.Int64}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -286,7 +287,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { filter := &entity.IssueFilter{ComponentId: []*int64{&cRow.Id.Int64}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -306,7 +307,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { IssueVariantId: []*int64{&issueVariantRow.Id.Int64}, } - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) issueIds := []int64{} for _, entry := range entries { @@ -326,7 +327,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { filter := &entity.IssueFilter{Type: []*string{&issueType}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -339,7 +340,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { It("can filter by hasIssueMatches", func() { filter := &entity.IssueFilter{HasIssueMatches: true} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) Expect(err).To(BeNil()) for _, entry := range entries { @@ -367,7 +368,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { IssueMatchSeverity: []*string{new(severity.String())}, } - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) Expect(err).To(BeNil()) for _, entry := range entries { @@ -383,7 +384,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { searchStr := test.CutString(row.PrimaryName.String, 2, 2, 5) filter := &entity.IssueFilter{Search: []*string{&searchStr}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) issueIds := []int64{} for _, entry := range entries { @@ -409,7 +410,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { searchStr := test.CutString(issueVariantRow.SecondaryName.String, 2, 2, 5) filter := &entity.IssueFilter{Search: []*string{&searchStr}} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) issueIds := []int64{} for _, entry := range entries { @@ -451,7 +452,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { It("can filter issue by IssueStatusOpen", func() { filter := &entity.IssueFilter{Status: entity.IssueStatusOpen} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -464,7 +465,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { It("can filter issue by IssueStatusRemediated", func() { filter := &entity.IssueFilter{Status: entity.IssueStatusRemediated} - entries, err := db.GetIssues(filter, nil) + entries, err := db.GetIssues(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -519,7 +520,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { db.CreateIssue(&newIssue) }) It("returns the issues with aggregations", func() { - entriesWithAggregations, err := db.GetIssuesWithAggregations(nil, nil) + entriesWithAggregations, err := db.GetIssuesWithAggregations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -553,7 +554,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { _ = seeder.SeedDbWithNFakeData(10) }) It("returns the issues with aggregations", func() { - entriesWithAggregations, err := db.GetIssuesWithAggregations(nil, nil) + entriesWithAggregations, err := db.GetIssuesWithAggregations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -585,7 +586,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Context("and using no filter", func() { DescribeTable("it returns correct count", func(x int) { _ = seeder.SeedDbWithNFakeData(x) - res, err := db.CountIssues(nil) + res, err := db.CountIssues(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -621,7 +622,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { } } - issueTypeCounts, err := db.CountIssueTypes(nil) + issueTypeCounts, err := db.CountIssueTypes(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -656,7 +657,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { After: &after, }, } - res, err := db.CountIssues(filter) + res, err := db.CountIssues(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -679,7 +680,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { } filter := &entity.IssueFilter{ServiceCCRN: []*string{&row.CCRN.String}} - count, err := db.CountIssues(filter) + count, err := db.CountIssues(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -702,7 +703,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { } filter := &entity.IssueFilter{ServiceId: []*int64{&row.Id.Int64}} - count, err := db.CountIssues(filter) + count, err := db.CountIssues(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -717,7 +718,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { }) When("IssueCounts by Severity", Label("IssueCounts"), func() { testIssueSeverityCount := func(filter *entity.IssueFilter, counts entity.IssueSeverityCounts) { - issueSeverityCounts, err := db.CountIssueRatings(filter) + issueSeverityCounts, err := db.CountIssueRatings(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -879,7 +880,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Id: []*int64{&issue.Id}, } - i, err := db.GetIssues(issueFilter, nil) + i, err := db.GetIssues(context.Background(), issueFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -926,7 +927,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Id: []*int64{&issue.Id}, } - i, err := db.GetIssues(issueFilter, nil) + i, err := db.GetIssues(context.Background(), issueFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -958,7 +959,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Id: []*int64{&issue.Id}, } - i, err := db.GetIssues(issueFilter, nil) + i, err := db.GetIssues(context.Background(), issueFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -999,7 +1000,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { Id: []*int64{&issue.Id}, } - i, err := db.GetIssues(issueFilter, nil) + i, err := db.GetIssues(context.Background(), issueFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -1046,7 +1047,7 @@ var _ = Describe("Issue", Label("database", "Issue"), func() { }, } - issues, err := db.GetIssues(issueFilter, nil) + issues, err := db.GetIssues(context.Background(), issueFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -1079,7 +1080,7 @@ var _ = Describe("Ordering Issues", Label("IssueOrder"), func() { order []entity.Order, verifyFunc func(res []entity.IssueResult), ) { - res, err := db.GetIssues(nil, order) + res, err := db.GetIssues(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) diff --git a/internal/database/mariadb/issue_variant.go b/internal/database/mariadb/issue_variant.go index 0ca1ba41a..98522dd04 100644 --- a/internal/database/mariadb/issue_variant.go +++ b/internal/database/mariadb/issue_variant.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -161,6 +162,7 @@ func ensureIssueVariantFilter(filter *entity.IssueVariantFilter) *entity.IssueVa } func (s *SqlDatabase) buildIssueVariantStatement( + ctx context.Context, baseQuery string, filter *entity.IssueVariantFilter, withCursor bool, @@ -187,7 +189,7 @@ func (s *SqlDatabase) buildIssueVariantStatement( query = fmt.Sprintf(baseQuery, joins, whereClause, ord) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -206,6 +208,7 @@ func (s *SqlDatabase) buildIssueVariantStatement( } func (s *SqlDatabase) GetAllIssueVariantCursors( + ctx context.Context, filter *entity.IssueVariantFilter, order []entity.Order, ) ([]string, error) { @@ -222,7 +225,7 @@ func (s *SqlDatabase) GetAllIssueVariantCursors( filter = ensureIssueVariantFilter(filter) - stmt, filterParameters, err := s.buildIssueVariantStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildIssueVariantStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, fmt.Errorf("failed to build IssueVariant cursor query: %w", err) } @@ -234,6 +237,7 @@ func (s *SqlDatabase) GetAllIssueVariantCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -255,6 +259,7 @@ func (s *SqlDatabase) GetAllIssueVariantCursors( } func (s *SqlDatabase) GetIssueVariants( + ctx context.Context, filter *entity.IssueVariantFilter, order []entity.Order, ) ([]entity.IssueVariantResult, error) { @@ -269,7 +274,7 @@ func (s *SqlDatabase) GetIssueVariants( %s ORDER BY %s LIMIT ? ` - stmt, filterParameters, err := s.buildIssueVariantStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildIssueVariantStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -281,6 +286,7 @@ func (s *SqlDatabase) GetIssueVariants( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -300,7 +306,7 @@ func (s *SqlDatabase) GetIssueVariants( ) } -func (s *SqlDatabase) CountIssueVariants(filter *entity.IssueVariantFilter) (int64, error) { +func (s *SqlDatabase) CountIssueVariants(ctx context.Context, filter *entity.IssueVariantFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountIssueVariants", }) @@ -313,6 +319,7 @@ func (s *SqlDatabase) CountIssueVariants(filter *entity.IssueVariantFilter) (int ` stmt, filterParameters, err := s.buildIssueVariantStatement( + ctx, baseQuery, filter, false, @@ -330,6 +337,7 @@ func (s *SqlDatabase) CountIssueVariants(filter *entity.IssueVariantFilter) (int }() return performCountScan( + ctx, stmt, filterParameters, l, diff --git a/internal/database/mariadb/issue_variant_test.go b/internal/database/mariadb/issue_variant_test.go index 1ba0dafea..6b86ef847 100644 --- a/internal/database/mariadb/issue_variant_test.go +++ b/internal/database/mariadb/issue_variant_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/cloudoperators/heureka/internal/database/mariadb" "github.com/cloudoperators/heureka/internal/database/mariadb/test" "github.com/cloudoperators/heureka/internal/entity" @@ -29,7 +31,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { When("Getting IssueVariants", Label("GetIssueVariants"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetIssueVariants(nil, []entity.Order{}) + res, err := db.GetIssueVariants(context.Background(), nil, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -46,7 +48,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetIssueVariants(nil, []entity.Order{}) + res, err := db.GetIssueVariants(context.Background(), nil, []entity.Order{}) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -102,7 +104,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { Id: []*int64{&issueVariant.Id.Int64}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -126,7 +128,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { IssueId: []*int64{&issueId}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -167,7 +169,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { IssueId: []*int64{&issueId1, &issueId2}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no Error", func() { Expect(err).To(BeNil()) @@ -195,7 +197,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { IssueRepositoryId: []*int64{&ir.IssueRepositoryId.Int64}, } - issueVariants, err := db.GetIssueVariants(filter, []entity.Order{}) + issueVariants, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -222,7 +224,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { ServiceId: []*int64{&service.Id.Int64}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -252,7 +254,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { IssueMatchId: []*int64{&im.Id.Int64}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -280,7 +282,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { SecondaryName: []*string{&iv.SecondaryName.String}, } - entries, err := db.GetIssueVariants(filter, []entity.Order{}) + entries, err := db.GetIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -330,7 +332,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { When("Counting IssueVariants", Label("CountIssueVariants"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountIssueVariants(nil) + c, err := db.CountIssueVariants(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -351,7 +353,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountIssueVariants(nil) + c, err := db.CountIssueVariants(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -370,7 +372,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { After: nil, }, } - c, err := db.CountIssueVariants(filter) + c, err := db.CountIssueVariants(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -395,7 +397,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { }, IssueId: []*int64{&issueId}, } - entries, err := db.CountIssueVariants(filter) + entries, err := db.CountIssueVariants(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -436,7 +438,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { Id: []*int64{&issueVariant.Id}, } - iv, err := db.GetIssueVariants(issueVariantFilter, []entity.Order{}) + iv, err := db.GetIssueVariants(context.Background(), issueVariantFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -493,7 +495,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { Id: []*int64{&issueVariant.Id}, } - iv, err := db.GetIssueVariants(issueVariantFilter, []entity.Order{}) + iv, err := db.GetIssueVariants(context.Background(), issueVariantFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -537,7 +539,7 @@ var _ = Describe("IssueVariant - ", Label("database", "IssueVariant"), func() { Id: []*int64{&issueVariant.Id}, } - iv, err := db.GetIssueVariants(issueVariantFilter, []entity.Order{}) + iv, err := db.GetIssueVariants(context.Background(), issueVariantFilter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) }) diff --git a/internal/database/mariadb/mv_vulnerabilities.go b/internal/database/mariadb/mv_vulnerabilities.go index bf889c9e4..bcc6cf124 100644 --- a/internal/database/mariadb/mv_vulnerabilities.go +++ b/internal/database/mariadb/mv_vulnerabilities.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -40,6 +41,7 @@ func getCountTable(filter *entity.IssueFilter) string { } func (s *SqlDatabase) CountIssueRatings( + ctx context.Context, filter *entity.IssueFilter, ) (*entity.IssueSeverityCounts, error) { l := logrus.WithFields(logrus.Fields{ @@ -93,7 +95,7 @@ func (s *SqlDatabase) CountIssueRatings( query = fmt.Sprintf("%s WHERE %s", query, filterStr) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -113,6 +115,7 @@ func (s *SqlDatabase) CountIssueRatings( }() counts, err := performListScan( + ctx, stmt, filterParameters, l, diff --git a/internal/database/mariadb/mv_vulnerabilities_test.go b/internal/database/mariadb/mv_vulnerabilities_test.go index b81b1072c..eb86423b0 100644 --- a/internal/database/mariadb/mv_vulnerabilities_test.go +++ b/internal/database/mariadb/mv_vulnerabilities_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "database/sql" "fmt" "time" @@ -23,7 +24,7 @@ var _ = Describe("Counting Issues by Severity", Label("IssueCounts"), func() { var seedCollection *test.SeedCollection testIssueSeverityCount := func(filter *entity.IssueFilter, counts entity.IssueSeverityCounts) { - issueSeverityCounts, err := db.CountIssueRatings(filter) + issueSeverityCounts, err := db.CountIssueRatings(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -395,10 +396,11 @@ var _ = Describe("Counting Issues by Severity", Label("IssueCounts"), func() { Expect(err).To(BeNil()) Expect(seeder.RefreshCountIssueRatings()).To(BeNil()) - counts, err := db.CountIssueRatings(&entity.IssueFilter{ - ComponentVersionId: []*int64{&cv.Id.Int64}, - ServiceId: []*int64{&remediation.ServiceId}, - }) + counts, err := db.CountIssueRatings(context.Background(), + &entity.IssueFilter{ + ComponentVersionId: []*int64{&cv.Id.Int64}, + ServiceId: []*int64{&remediation.ServiceId}, + }) Expect(err).To(BeNil()) @@ -409,10 +411,11 @@ var _ = Describe("Counting Issues by Severity", Label("IssueCounts"), func() { Expect(counts.None).To(BeEquivalentTo(0)) Expect(counts.Total).To(BeEquivalentTo(0)) - countsEmpty, err := db.CountIssueRatings(&entity.IssueFilter{ - ComponentVersionId: []*int64{&cv.Id.Int64}, - ServiceId: []*int64{&newCi.ServiceId}, - }) + countsEmpty, err := db.CountIssueRatings(context.Background(), + &entity.IssueFilter{ + ComponentVersionId: []*int64{&cv.Id.Int64}, + ServiceId: []*int64{&newCi.ServiceId}, + }) Expect(err).To(BeNil()) cvId := fmt.Sprintf("%d", cv.Id.Int64) diff --git a/internal/database/mariadb/patch.go b/internal/database/mariadb/patch.go index 37fd57e05..bcd015e24 100644 --- a/internal/database/mariadb/patch.go +++ b/internal/database/mariadb/patch.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -56,6 +57,7 @@ func ensurePatchFilter(filter *entity.PatchFilter) *entity.PatchFilter { } func (s *SqlDatabase) buildPatchStatement( + ctx context.Context, baseQuery string, filter *entity.PatchFilter, withCursor bool, @@ -81,7 +83,7 @@ func (s *SqlDatabase) buildPatchStatement( query = fmt.Sprintf(baseQuery, whereClause, ord) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -100,6 +102,7 @@ func (s *SqlDatabase) buildPatchStatement( } func (s *SqlDatabase) GetPatches( + ctx context.Context, filter *entity.PatchFilter, order []entity.Order, ) ([]entity.PatchResult, error) { @@ -116,7 +119,7 @@ func (s *SqlDatabase) GetPatches( GROUP BY P.patch_id ORDER BY %s LIMIT ? ` - stmt, filterParameters, err := s.buildPatchStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildPatchStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, fmt.Errorf("failed to build Patch query: %w", err) } @@ -128,6 +131,7 @@ func (s *SqlDatabase) GetPatches( }() results, err := performListScan( + ctx, stmt, filterParameters, l, @@ -152,7 +156,7 @@ func (s *SqlDatabase) GetPatches( return results, nil } -func (s *SqlDatabase) CountPatches(filter *entity.PatchFilter) (int64, error) { +func (s *SqlDatabase) CountPatches(ctx context.Context, filter *entity.PatchFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountPatches", "filter": filter, @@ -165,6 +169,7 @@ func (s *SqlDatabase) CountPatches(filter *entity.PatchFilter) (int64, error) { ` stmt, filterParameters, err := s.buildPatchStatement( + ctx, baseQuery, filter, false, @@ -181,7 +186,7 @@ func (s *SqlDatabase) CountPatches(filter *entity.PatchFilter) (int64, error) { } }() - count, err := performCountScan(stmt, filterParameters, l) + count, err := performCountScan(ctx, stmt, filterParameters, l) if err != nil { return -1, fmt.Errorf("failed to count Patches: %w", err) } @@ -190,6 +195,7 @@ func (s *SqlDatabase) CountPatches(filter *entity.PatchFilter) (int64, error) { } func (s *SqlDatabase) GetAllPatchCursors( + ctx context.Context, filter *entity.PatchFilter, order []entity.Order, ) ([]string, error) { @@ -205,7 +211,7 @@ func (s *SqlDatabase) GetAllPatchCursors( filter = ensurePatchFilter(filter) - stmt, filterParameters, err := s.buildPatchStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildPatchStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, fmt.Errorf("failed to build Patch cursor query: %w", err) } @@ -217,6 +223,7 @@ func (s *SqlDatabase) GetAllPatchCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, diff --git a/internal/database/mariadb/patch_test.go b/internal/database/mariadb/patch_test.go index d03c93d18..4111ad8a6 100644 --- a/internal/database/mariadb/patch_test.go +++ b/internal/database/mariadb/patch_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/cloudoperators/heureka/internal/database/mariadb" "github.com/cloudoperators/heureka/internal/database/mariadb/test" "github.com/cloudoperators/heureka/internal/entity" @@ -27,7 +29,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { When("Getting Patches", Label("GetPatches"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetPatches(nil, nil) + res, err := db.GetPatches(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -44,7 +46,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetPatches(nil, nil) + res, err := db.GetPatches(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -97,7 +99,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { row := test.PickOne(seedCollection.PatchRows) filter := &entity.PatchFilter{Id: []*int64{&row.Id.Int64}} - entries, err := db.GetPatches(filter, nil) + entries, err := db.GetPatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -113,7 +115,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { row := test.PickOne(seedCollection.PatchRows) filter := &entity.PatchFilter{ServiceId: []*int64{&row.ServiceId.Int64}} - entries, err := db.GetPatches(filter, nil) + entries, err := db.GetPatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -131,7 +133,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { row := test.PickOne(seedCollection.PatchRows) filter := &entity.PatchFilter{ServiceName: []*string{&row.ServiceName.String}} - entries, err := db.GetPatches(filter, nil) + entries, err := db.GetPatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -151,7 +153,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { ComponentVersionId: []*int64{&row.ComponentVersionId.Int64}, } - entries, err := db.GetPatches(filter, nil) + entries, err := db.GetPatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -173,7 +175,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { ComponentVersionName: []*string{&row.ComponentVersionName.String}, } - entries, err := db.GetPatches(filter, nil) + entries, err := db.GetPatches(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -222,7 +224,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { When("Counting Patches", Label("CountPatches"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountPatches(nil) + c, err := db.CountPatches(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -243,7 +245,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountPatches(nil) + c, err := db.CountPatches(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -263,7 +265,7 @@ var _ = Describe("Patch", Label("database", "Patch"), func() { After: &after, }, } - c, err := db.CountPatches(filter) + c, err := db.CountPatches(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) diff --git a/internal/database/mariadb/quiet_db.go b/internal/database/mariadb/quiet_db.go index 2f069c916..a26626b9e 100644 --- a/internal/database/mariadb/quiet_db.go +++ b/internal/database/mariadb/quiet_db.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "github.com/jmoiron/sqlx" @@ -37,6 +38,10 @@ func (qdb *QuietDb) Preparex(query string) (Stmt, error) { return qdb.db.Preparex(query) } +func (qdb *QuietDb) PreparexContext(ctx context.Context, query string) (Stmt, error) { + return qdb.db.PreparexContext(ctx, query) +} + func (qdb *QuietDb) Select(dest any, query string, args ...any) error { return qdb.db.Select(dest, query, args...) } @@ -45,6 +50,10 @@ func (qdb *QuietDb) Query(query string, args ...any) (SqlRows, error) { return qdb.db.Query(query, args...) } +func (qdb *QuietDb) QueryContext(ctx context.Context, query string, args ...any) (SqlRows, error) { + return qdb.db.QueryContext(ctx, query, args...) +} + func (qdb *QuietDb) QueryRow(query string, args ...any) *sql.Row { return qdb.db.QueryRow(query, args...) } diff --git a/internal/database/mariadb/remediation.go b/internal/database/mariadb/remediation.go index 19519e4eb..242ff0af4 100644 --- a/internal/database/mariadb/remediation.go +++ b/internal/database/mariadb/remediation.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "fmt" "time" @@ -191,6 +192,7 @@ func ensureRemediationFilter(filter *entity.RemediationFilter) *entity.Remediati } func (s *SqlDatabase) buildRemediationStatement( + ctx context.Context, baseQuery string, filter *entity.RemediationFilter, withCursor bool, @@ -216,7 +218,7 @@ func (s *SqlDatabase) buildRemediationStatement( query = fmt.Sprintf(baseQuery, whereClause, ord) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -235,6 +237,7 @@ func (s *SqlDatabase) buildRemediationStatement( } func (s *SqlDatabase) GetRemediations( + ctx context.Context, filter *entity.RemediationFilter, order []entity.Order, ) ([]entity.RemediationResult, error) { @@ -251,7 +254,7 @@ func (s *SqlDatabase) GetRemediations( GROUP BY R.remediation_id ORDER BY %s LIMIT ? ` - stmt, filterParameters, err := s.buildRemediationStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildRemediationStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, fmt.Errorf("failed to build Remediation query: %w", err) } @@ -263,6 +266,7 @@ func (s *SqlDatabase) GetRemediations( }() results, err := performListScan( + ctx, stmt, filterParameters, l, @@ -287,7 +291,7 @@ func (s *SqlDatabase) GetRemediations( return results, nil } -func (s *SqlDatabase) CountRemediations(filter *entity.RemediationFilter) (int64, error) { +func (s *SqlDatabase) CountRemediations(ctx context.Context, filter *entity.RemediationFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountRemediations", "filter": filter, @@ -300,6 +304,7 @@ func (s *SqlDatabase) CountRemediations(filter *entity.RemediationFilter) (int64 ` stmt, filterParameters, err := s.buildRemediationStatement( + ctx, baseQuery, filter, false, @@ -316,7 +321,7 @@ func (s *SqlDatabase) CountRemediations(filter *entity.RemediationFilter) (int64 } }() - count, err := performCountScan(stmt, filterParameters, l) + count, err := performCountScan(ctx, stmt, filterParameters, l) if err != nil { return -1, fmt.Errorf("failed to count Remediations: %w", err) } @@ -325,6 +330,7 @@ func (s *SqlDatabase) CountRemediations(filter *entity.RemediationFilter) (int64 } func (s *SqlDatabase) GetAllRemediationCursors( + ctx context.Context, filter *entity.RemediationFilter, order []entity.Order, ) ([]string, error) { @@ -340,7 +346,7 @@ func (s *SqlDatabase) GetAllRemediationCursors( filter = ensureRemediationFilter(filter) - stmt, filterParameters, err := s.buildRemediationStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildRemediationStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, fmt.Errorf("failed to build Remediation cursor query: %w", err) } @@ -352,6 +358,7 @@ func (s *SqlDatabase) GetAllRemediationCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, diff --git a/internal/database/mariadb/remediation_test.go b/internal/database/mariadb/remediation_test.go index c04538b14..b64637541 100644 --- a/internal/database/mariadb/remediation_test.go +++ b/internal/database/mariadb/remediation_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "database/sql" "time" @@ -32,7 +33,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { When("Getting Remediations", Label("GetRemediations"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetRemediations(nil, nil) + res, err := db.GetRemediations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -49,7 +50,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetRemediations(nil, nil) + res, err := db.GetRemediations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -123,7 +124,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{Id: []*int64{&row.Id.Int64}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -139,7 +140,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { severity := gofakeit.RandomString(entity.AllSeverityValuesString) filter := &entity.RemediationFilter{Severity: []*string{&severity}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -155,7 +156,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{Service: []*string{&row.Service.String}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -173,7 +174,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{ServiceId: []*int64{&row.ServiceId.Int64}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -191,7 +192,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{Component: []*string{&row.Component.String}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -211,7 +212,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { ComponentId: []*int64{&row.ComponentId.Int64}, } - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -229,7 +230,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{Issue: []*string{&row.Issue.String}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -247,7 +248,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{IssueId: []*int64{&row.IssueId.Int64}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -265,7 +266,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { row := test.PickOne(seedCollection.RemediationRows) filter := &entity.RemediationFilter{Type: []*string{&row.Type.String}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -283,7 +284,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { remediationType := entity.RemediationTypeFalsePositive.String() filter := &entity.RemediationFilter{Type: []*string{&remediationType}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -305,7 +306,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { remediationType := entity.RemediationTypeRiskAccepted.String() filter := &entity.RemediationFilter{Type: []*string{&remediationType}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -327,7 +328,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { remediationType := entity.RemediationTypeMitigation.String() filter := &entity.RemediationFilter{Type: []*string{&remediationType}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -349,7 +350,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { remediationType := entity.RemediationTypeRescore.String() filter := &entity.RemediationFilter{Type: []*string{&remediationType}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -373,7 +374,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { searchStr := test.CutString(row.Issue.String, 2, 2, 5) filter := &entity.RemediationFilter{Search: []*string{&searchStr}} - entries, err := db.GetRemediations(filter, nil) + entries, err := db.GetRemediations(context.Background(), filter, nil) ids := []int64{} for _, entry := range entries { @@ -440,7 +441,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { order := []entity.Order{ {By: entity.RemediationIssue, Direction: entity.OrderDirectionAsc}, } - entries, err := db.GetRemediations(nil, order) + entries, err := db.GetRemediations(context.Background(), nil, order) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -456,7 +457,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { order := []entity.Order{ {By: entity.RemediationSeverity, Direction: entity.OrderDirectionDesc}, } - entries, err := db.GetRemediations(nil, order) + entries, err := db.GetRemediations(context.Background(), nil, order) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -472,7 +473,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { order := []entity.Order{ {By: entity.RemediationExpirationDate, Direction: entity.OrderDirectionAsc}, } - entries, err := db.GetRemediations(nil, order) + entries, err := db.GetRemediations(context.Background(), nil, order) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -490,7 +491,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { When("Counting Remediations", Label("CountRemediations"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountRemediations(nil) + c, err := db.CountRemediations(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -511,7 +512,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountRemediations(nil) + c, err := db.CountRemediations(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -531,7 +532,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { After: &after, }, } - c, err := db.CountRemediations(filter) + c, err := db.CountRemediations(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -585,7 +586,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { Id: []*int64{&remediation.Id}, } - r, err := db.GetRemediations(remediationFilter, nil) + r, err := db.GetRemediations(context.Background(), remediationFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -615,7 +616,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { Id: []*int64{&remediation.Id}, } - r, err := db.GetRemediations(remediationFilter, nil) + r, err := db.GetRemediations(context.Background(), remediationFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -647,7 +648,7 @@ var _ = Describe("Remediation", Label("database", "Remediation"), func() { Id: []*int64{&remediation.Id}, } - r, err := db.GetRemediations(remediationFilter, nil) + r, err := db.GetRemediations(context.Background(), remediationFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) diff --git a/internal/database/mariadb/service.go b/internal/database/mariadb/service.go index 496b726a5..b35257bc9 100644 --- a/internal/database/mariadb/service.go +++ b/internal/database/mariadb/service.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "errors" "fmt" @@ -219,6 +220,7 @@ func (s *SqlDatabase) getServiceColumns(filter *entity.ServiceFilter, order []en } func (s *SqlDatabase) buildServiceStatement( + ctx context.Context, baseQuery string, filter *entity.ServiceFilter, withCursor bool, @@ -248,7 +250,7 @@ func (s *SqlDatabase) buildServiceStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -267,7 +269,7 @@ func (s *SqlDatabase) buildServiceStatement( return stmt, filterParameters, nil } -func (s *SqlDatabase) CountServices(filter *entity.ServiceFilter) (int64, error) { +func (s *SqlDatabase) CountServices(ctx context.Context, filter *entity.ServiceFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountServices", }) @@ -280,6 +282,7 @@ func (s *SqlDatabase) CountServices(filter *entity.ServiceFilter) (int64, error) ` stmt, filterParameters, err := s.buildServiceStatement( + ctx, baseQuery, filter, false, @@ -296,10 +299,11 @@ func (s *SqlDatabase) CountServices(filter *entity.ServiceFilter) (int64, error) } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) GetServices( + ctx context.Context, filter *entity.ServiceFilter, order []entity.Order, ) ([]entity.ServiceResult, error) { @@ -318,7 +322,7 @@ func (s *SqlDatabase) GetServices( columns := s.getServiceColumns(filter, order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s", "%s") - stmt, filterParameters, err := s.buildServiceStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildServiceStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -330,6 +334,7 @@ func (s *SqlDatabase) GetServices( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -358,6 +363,7 @@ func (s *SqlDatabase) GetServices( } func (s *SqlDatabase) GetServicesWithAggregations( + ctx context.Context, filter *entity.ServiceFilter, order []entity.Order, ) ([]entity.ServiceResult, error) { @@ -444,7 +450,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( ciQuery := fmt.Sprintf(baseCiQuery, columns, joins, whereClause, cursorQuery, ord) query := fmt.Sprintf(baseQuery, imQuery, ciQuery) - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -471,6 +477,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -501,6 +508,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( } func (s *SqlDatabase) GetAllServiceCursors( + ctx context.Context, filter *entity.ServiceFilter, order []entity.Order, ) ([]string, error) { @@ -519,7 +527,7 @@ func (s *SqlDatabase) GetAllServiceCursors( columns := s.getServiceColumns(filter, order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s") - stmt, filterParameters, err := s.buildServiceStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildServiceStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err } @@ -531,6 +539,7 @@ func (s *SqlDatabase) GetAllServiceCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -703,6 +712,7 @@ func (s *SqlDatabase) RemoveIssueRepositoryFromService( } func (s *SqlDatabase) getServiceAttr( + ctx context.Context, attrName string, filter *entity.ServiceFilter, ) ([]string, error) { @@ -727,7 +737,7 @@ func (s *SqlDatabase) getServiceAttr( } // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildServiceStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildServiceStatement(context.Background(), baseQuery, filter, false, order, l) if err != nil { l.Error("Error preparing statement: ", err) return nil, err @@ -740,7 +750,7 @@ func (s *SqlDatabase) getServiceAttr( }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(context.Background(), filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err @@ -773,8 +783,8 @@ func (s *SqlDatabase) getServiceAttr( return serviceAttrs, nil } -func (s *SqlDatabase) GetServiceCcrns(filter *entity.ServiceFilter) ([]string, error) { - ccrns, err := s.getServiceAttr("ccrn", filter) +func (s *SqlDatabase) GetServiceCcrns(ctx context.Context, filter *entity.ServiceFilter) ([]string, error) { + ccrns, err := s.getServiceAttr(ctx, "ccrn", filter) if err != nil { return nil, fmt.Errorf("failed to get Service ccrns: %w", err) } @@ -782,8 +792,8 @@ func (s *SqlDatabase) GetServiceCcrns(filter *entity.ServiceFilter) ([]string, e return ccrns, nil } -func (s *SqlDatabase) GetServiceDomains(filter *entity.ServiceFilter) ([]string, error) { - domains, err := s.getServiceAttr("domain", filter) +func (s *SqlDatabase) GetServiceDomains(ctx context.Context, filter *entity.ServiceFilter) ([]string, error) { + domains, err := s.getServiceAttr(ctx, "domain", filter) if err != nil { return nil, fmt.Errorf("failed to get Service domains: %w", err) } @@ -791,8 +801,8 @@ func (s *SqlDatabase) GetServiceDomains(filter *entity.ServiceFilter) ([]string, return domains, nil } -func (s *SqlDatabase) GetServiceRegions(filter *entity.ServiceFilter) ([]string, error) { - regions, err := s.getServiceAttr("region", filter) +func (s *SqlDatabase) GetServiceRegions(ctx context.Context, filter *entity.ServiceFilter) ([]string, error) { + regions, err := s.getServiceAttr(ctx, "region", filter) if err != nil { return nil, fmt.Errorf("failed to get Service regions: %w", err) } diff --git a/internal/database/mariadb/service_issue_variant.go b/internal/database/mariadb/service_issue_variant.go index ea79c8e80..50148cff3 100644 --- a/internal/database/mariadb/service_issue_variant.go +++ b/internal/database/mariadb/service_issue_variant.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -45,6 +46,7 @@ func ensureServiceIssueVariantFilter( } func (s *SqlDatabase) buildServiceIssueVariantStatement( + ctx context.Context, baseQuery string, filter *entity.ServiceIssueVariantFilter, withCursor bool, @@ -71,7 +73,7 @@ func (s *SqlDatabase) buildServiceIssueVariantStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -95,6 +97,7 @@ func (s *SqlDatabase) buildServiceIssueVariantStatement( // TODO: adjust this function to fit dbObject func (s *SqlDatabase) GetServiceIssueVariants( + ctx context.Context, filter *entity.ServiceIssueVariantFilter, order []entity.Order, ) ([]entity.ServiceIssueVariantResult, error) { @@ -121,6 +124,7 @@ func (s *SqlDatabase) GetServiceIssueVariants( ` stmt, filterParameters, err := s.buildServiceIssueVariantStatement( + ctx, baseQuery, filter, true, @@ -138,6 +142,7 @@ func (s *SqlDatabase) GetServiceIssueVariants( }() return performListScan( + ctx, stmt, filterParameters, l, diff --git a/internal/database/mariadb/service_issue_variant_test.go b/internal/database/mariadb/service_issue_variant_test.go index 9630510b8..8e29fab3c 100644 --- a/internal/database/mariadb/service_issue_variant_test.go +++ b/internal/database/mariadb/service_issue_variant_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "database/sql" "fmt" @@ -33,7 +34,7 @@ var _ = Describe("ServiceIssueVariant - ", Label("database", "IssueVariant"), fu When("Getting ServiceIssueVariants", Label("GetServiceIssueVariants"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetServiceIssueVariants(nil, nil) + res, err := db.GetServiceIssueVariants(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -50,7 +51,7 @@ var _ = Describe("ServiceIssueVariant - ", Label("database", "IssueVariant"), fu // this should work and give me all combinations back Context("and using no filter", func() { It("Should work", func() { - _, err := db.GetServiceIssueVariants(nil, nil) + _, err := db.GetServiceIssueVariants(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) }) @@ -192,7 +193,7 @@ var _ = Describe("ServiceIssueVariant - ", Label("database", "IssueVariant"), fu Paginated: entity.Paginated{}, ComponentInstanceId: cids, } - res, err := db.GetServiceIssueVariants(filter, []entity.Order{}) + res, err := db.GetServiceIssueVariants(context.Background(), filter, []entity.Order{}) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -228,7 +229,7 @@ var _ = Describe("ServiceIssueVariant - ", Label("database", "IssueVariant"), fu IssueId: []*int64{someId}, } - res, err := db.GetServiceIssueVariants(filter, []entity.Order{}) + res, err := db.GetServiceIssueVariants(context.Background(), filter, []entity.Order{}) Expect(err).To(BeNil()) Expect(res).To(BeEmpty()) @@ -318,7 +319,7 @@ var _ = Describe("ServiceIssueVariant - ", Label("database", "IssueVariant"), fu IssueId: []*int64{new(issue.Id.Int64)}, } - res, err := db.GetServiceIssueVariants(filter, []entity.Order{}) + res, err := db.GetServiceIssueVariants(context.Background(), filter, []entity.Order{}) Expect(err).To(BeNil()) Expect(res).To(HaveLen(5)) // One variant per repository diff --git a/internal/database/mariadb/service_test.go b/internal/database/mariadb/service_test.go index aaaa259a7..d82169010 100644 --- a/internal/database/mariadb/service_test.go +++ b/internal/database/mariadb/service_test.go @@ -4,6 +4,7 @@ package mariadb_test import ( + "context" "sort" "github.com/cloudoperators/heureka/internal/database/mariadb" @@ -35,7 +36,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { When("Getting Services", Label("GetServices"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetServices(nil, nil) + res, err := db.GetServices(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -52,7 +53,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetServices(nil, nil) + res, err := db.GetServices(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -102,7 +103,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { row := test.PickOne(seedCollection.ServiceRows) filter := &entity.ServiceFilter{CCRN: []*string{&row.CCRN.String}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -120,7 +121,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { nonExistingName := pkg_util.GenerateRandomString(40, nil) filter := &entity.ServiceFilter{CCRN: []*string{&nonExistingName}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -137,7 +138,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { } filter := &entity.ServiceFilter{CCRN: serviceCcrns} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -151,7 +152,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { row := test.PickOne(seedCollection.ServiceRows) filter := &entity.ServiceFilter{Domain: []*string{&row.Domain.String}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -169,7 +170,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { row := test.PickOne(seedCollection.ServiceRows) filter := &entity.ServiceFilter{Region: []*string{&row.Region.String}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -187,7 +188,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { row := test.PickOne(seedCollection.ServiceRows) filter := &entity.ServiceFilter{Id: []*int64{&row.Id.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -211,7 +212,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{SupportGroupCCRN: []*string{&sgRow.CCRN.String}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -237,7 +238,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{OwnerName: []*string{&userRow.Name.String}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -263,7 +264,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{OwnerId: []*int64{&ownerRow.UserId.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -289,7 +290,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{IssueRepositoryId: []*int64{&irRow.Id.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -331,7 +332,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{IssueId: []*int64{&imRow.IssueId.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -357,7 +358,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { filter := &entity.ServiceFilter{SupportGroupId: []*int64{&sgRow.Id.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -375,7 +376,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Expect(ok).To(BeTrue()) filter := &entity.ServiceFilter{ComponentInstanceId: []*int64{&ciRow.Id.Int64}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -392,7 +393,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { searchStr := test.CutString(row.CCRN.String, 2, 2, 2) filter := &entity.ServiceFilter{Search: []*string{&searchStr}} - entries, err := db.GetServices(filter, nil) + entries, err := db.GetServices(context.Background(), filter, nil) names := []string{} for _, entry := range entries { @@ -453,7 +454,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { db.CreateService(&newService) }) It("returns the services with aggregations", func() { - entriesWithAggregations, err := db.GetServicesWithAggregations(nil, nil) + entriesWithAggregations, err := db.GetServicesWithAggregations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -481,7 +482,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { _ = seeder.SeedDbWithNFakeData(10) }) It("returns the services with aggs", func() { - entriesWithAggregations, err := db.GetServicesWithAggregations(nil, nil) + entriesWithAggregations, err := db.GetServicesWithAggregations(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -512,7 +513,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { When("Counting Services", Label("CountServices"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountServices(nil) + c, err := db.CountServices(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -533,7 +534,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountServices(nil) + c, err := db.CountServices(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -553,7 +554,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { After: &after, }, } - c, err := db.CountServices(filter) + c, err := db.CountServices(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -585,7 +586,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { }, SupportGroupCCRN: []*string{&sgRow.CCRN.String}, } - entries, err := db.CountServices(filter) + entries, err := db.CountServices(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -625,7 +626,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Id: []*int64{&service.Id}, } - s, err := db.GetServices(serviceFilter, nil) + s, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -672,7 +673,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Id: []*int64{&service.Id}, } - s, err := db.GetServices(serviceFilter, nil) + s, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -704,7 +705,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Id: []*int64{&service.Id}, } - s, err := db.GetServices(serviceFilter, nil) + s, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -743,7 +744,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { OwnerId: []*int64{&owner.Id}, } - s, err := db.GetServices(serviceFilter, nil) + s, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -785,7 +786,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { OwnerId: []*int64{&ownerRow.UserId.Int64}, } - services, err := db.GetServices(serviceFilter, nil) + services, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -829,7 +830,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { IssueRepositoryId: []*int64{&issueRepository.Id}, } - s, err := db.GetServices(serviceFilter, nil) + s, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -874,7 +875,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { IssueRepositoryId: []*int64{&issueRepositoryServiceRow.IssueRepositoryId.Int64}, } - services, err := db.GetServices(serviceFilter, nil) + services, err := db.GetServices(context.Background(), serviceFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -888,7 +889,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { When("Getting ServiceCcrns", Label("GetServiceCcrns"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetServiceCcrns(nil) + res, err := db.GetServiceCcrns(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -905,7 +906,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetServiceCcrns(nil) + res, err := db.GetServiceCcrns(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -943,7 +944,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { } It("can fetch the filtered items correctly", func() { - res, err := db.GetServiceCcrns(filter) + res, err := db.GetServiceCcrns(context.Background(), filter) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -971,7 +972,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { } It("returns an empty list when no services match the filter", func() { - res, err := db.GetServiceCcrns(anotherFilter) + res, err := db.GetServiceCcrns(context.Background(), anotherFilter) Expect(err).Should(BeNil()) Expect(res).Should(BeEmpty()) @@ -992,7 +993,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { When("Getting ServiceDomains", Label("GetServiceDomains"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetServiceDomains(nil) + res, err := db.GetServiceDomains(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -1009,7 +1010,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetServiceDomains(nil) + res, err := db.GetServiceDomains(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1038,7 +1039,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { When("Getting ServiceRegions", Label("GetServiceRegions"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetServiceRegions(nil) + res, err := db.GetServiceRegions(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -1055,7 +1056,7 @@ var _ = Describe("Service", Label("database", "Service"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetServiceRegions(nil) + res, err := db.GetServiceRegions(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1105,7 +1106,7 @@ var _ = Describe("Ordering Services", Label("ServiceOrdering"), func() { order []entity.Order, verifyFunc func(res []entity.ServiceResult), ) { - res, err := db.GetServices(nil, order) + res, err := db.GetServices(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -1176,7 +1177,7 @@ var _ = Describe("Ordering Services", Label("ServiceOrdering"), func() { {By: entity.LowCount, Direction: entity.OrderDirectionDesc}, {By: entity.NoneCount, Direction: entity.OrderDirectionDesc}, } - services, err := db.GetServices(nil, order) + services, err := db.GetServices(context.Background(), nil, order) Expect(err).To(BeNil()) Expect(services[0].Id).To(BeEquivalentTo(1)) Expect(services[1].Id).To(BeEquivalentTo(3)) @@ -1192,7 +1193,7 @@ var _ = Describe("Ordering Services", Label("ServiceOrdering"), func() { {By: entity.LowCount, Direction: entity.OrderDirectionAsc}, {By: entity.NoneCount, Direction: entity.OrderDirectionAsc}, } - services, err := db.GetServices(nil, order) + services, err := db.GetServices(context.Background(), nil, order) Expect(err).To(BeNil()) Expect(services[0].Id).To(BeEquivalentTo(2)) Expect(services[1].Id).To(BeEquivalentTo(5)) diff --git a/internal/database/mariadb/support_group.go b/internal/database/mariadb/support_group.go index e03157e49..75a4022b5 100644 --- a/internal/database/mariadb/support_group.go +++ b/internal/database/mariadb/support_group.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "errors" "fmt" @@ -118,6 +119,7 @@ func ensureSupportGroupFilter(filter *entity.SupportGroupFilter) *entity.Support } func (s *SqlDatabase) buildSupportGroupStatement( + ctx context.Context, baseQuery string, filter *entity.SupportGroupFilter, withCursor bool, @@ -146,7 +148,7 @@ func (s *SqlDatabase) buildSupportGroupStatement( } // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -165,6 +167,7 @@ func (s *SqlDatabase) buildSupportGroupStatement( } func (s *SqlDatabase) GetAllSupportGroupCursors( + ctx context.Context, filter *entity.SupportGroupFilter, order []entity.Order, ) ([]string, error) { @@ -181,12 +184,13 @@ func (s *SqlDatabase) GetAllSupportGroupCursors( filter = ensureSupportGroupFilter(filter) - stmt, filterParameters, err := s.buildSupportGroupStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildSupportGroupStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err } rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -208,6 +212,7 @@ func (s *SqlDatabase) GetAllSupportGroupCursors( } func (s *SqlDatabase) GetSupportGroups( + ctx context.Context, filter *entity.SupportGroupFilter, order []entity.Order, ) ([]entity.SupportGroupResult, error) { @@ -223,7 +228,7 @@ func (s *SqlDatabase) GetSupportGroups( GROUP BY SG.supportgroup_id ORDER BY %s LIMIT ? ` - stmt, filterParameters, err := s.buildSupportGroupStatement(baseQuery, filter, true, order, l) + stmt, filterParameters, err := s.buildSupportGroupStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -235,6 +240,7 @@ func (s *SqlDatabase) GetSupportGroups( }() return performListScan( + ctx, stmt, filterParameters, l, @@ -254,7 +260,7 @@ func (s *SqlDatabase) GetSupportGroups( ) } -func (s *SqlDatabase) CountSupportGroups(filter *entity.SupportGroupFilter) (int64, error) { +func (s *SqlDatabase) CountSupportGroups(ctx context.Context, filter *entity.SupportGroupFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountSupportGroups", }) @@ -267,6 +273,7 @@ func (s *SqlDatabase) CountSupportGroups(filter *entity.SupportGroupFilter) (int ` stmt, filterParameters, err := s.buildSupportGroupStatement( + ctx, baseQuery, filter, false, @@ -283,7 +290,7 @@ func (s *SqlDatabase) CountSupportGroups(filter *entity.SupportGroupFilter) (int } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CreateSupportGroup( @@ -422,7 +429,7 @@ func (s *SqlDatabase) RemoveUserFromSupportGroup(supportGroupId int64, userId in return err } -func (s *SqlDatabase) GetSupportGroupCcrns(filter *entity.SupportGroupFilter) ([]string, error) { +func (s *SqlDatabase) GetSupportGroupCcrns(ctx context.Context, filter *entity.SupportGroupFilter) ([]string, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetSupportGroupCcrns", @@ -446,7 +453,7 @@ func (s *SqlDatabase) GetSupportGroupCcrns(filter *entity.SupportGroupFilter) ([ filter = ensureSupportGroupFilter(filter) // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildSupportGroupStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildSupportGroupStatement(ctx, baseQuery, filter, false, order, l) if err != nil { l.Error("Error preparing statement: ", err) return nil, err @@ -459,7 +466,7 @@ func (s *SqlDatabase) GetSupportGroupCcrns(filter *entity.SupportGroupFilter) ([ }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err diff --git a/internal/database/mariadb/support_group_test.go b/internal/database/mariadb/support_group_test.go index 87fa27e0a..e18a6bc3b 100644 --- a/internal/database/mariadb/support_group_test.go +++ b/internal/database/mariadb/support_group_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/samber/lo" "golang.org/x/text/collate" "golang.org/x/text/language" @@ -32,7 +34,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { When("Getting SupportGroups", Label("GetSupportGroups"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetSupportGroups(nil, nil) + res, err := db.GetSupportGroups(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -49,7 +51,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetSupportGroups(nil, nil) + res, err := db.GetSupportGroups(context.Background(), nil, nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -102,7 +104,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { } } - entries, err := db.GetSupportGroups(filter, nil) + entries, err := db.GetSupportGroups(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -122,7 +124,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { row := test.PickOne(seedCollection.SupportGroupRows) filter := &entity.SupportGroupFilter{Id: []*int64{&row.Id.Int64}} - entries, err := db.GetSupportGroups(filter, nil) + entries, err := db.GetSupportGroups(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -145,7 +147,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { } } - entries, err := db.GetSupportGroups(filter, nil) + entries, err := db.GetSupportGroups(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -166,7 +168,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { filter := &entity.SupportGroupFilter{CCRN: []*string{&row.CCRN.String}} - entries, err := db.GetSupportGroups(filter, nil) + entries, err := db.GetSupportGroups(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -187,7 +189,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { filter := &entity.SupportGroupFilter{IssueId: []*int64{&issueMatchRow.IssueId.Int64}} - entries, err := db.GetSupportGroups(filter, nil) + entries, err := db.GetSupportGroups(context.Background(), filter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -208,7 +210,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { order []entity.Order, verifyFunc func(res []entity.SupportGroupResult), ) { - res, err := db.GetSupportGroups(nil, order) + res, err := db.GetSupportGroups(context.Background(), nil, order) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -267,7 +269,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { When("Counting SupportGroups", Label("CountSupportGroups"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountSupportGroups(nil) + c, err := db.CountSupportGroups(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -288,7 +290,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountSupportGroups(nil) + c, err := db.CountSupportGroups(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -308,7 +310,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { After: &after, }, } - c, err := db.CountSupportGroups(filter) + c, err := db.CountSupportGroups(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -343,7 +345,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { Id: []*int64{&supportGroup.Id}, } - sg, err := db.GetSupportGroups(sgFilter, nil) + sg, err := db.GetSupportGroups(context.Background(), sgFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -377,7 +379,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { Id: []*int64{&supportGroup.Id}, } - sg, err := db.GetSupportGroups(supportGroupFilter, nil) + sg, err := db.GetSupportGroups(context.Background(), supportGroupFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -410,7 +412,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { Id: []*int64{&supportGroup.Id}, } - sg, err := db.GetSupportGroups(sgFilter, nil) + sg, err := db.GetSupportGroups(context.Background(), sgFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -449,7 +451,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { ServiceId: []*int64{&service.Id}, } - sg, err := db.GetSupportGroups(supportGroupFilter, nil) + sg, err := db.GetSupportGroups(context.Background(), supportGroupFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -494,7 +496,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { ServiceId: []*int64{&supportGroupServiceRow.ServiceId.Int64}, } - supportGroups, err := db.GetSupportGroups(supportGroupFilter, nil) + supportGroups, err := db.GetSupportGroups(context.Background(), supportGroupFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -534,7 +536,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { UserId: []*int64{&user.Id}, } - sg, err := db.GetSupportGroups(supportGroupFilter, nil) + sg, err := db.GetSupportGroups(context.Background(), supportGroupFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -579,7 +581,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { UserId: []*int64{&supportGroupUserRow.UserId.Int64}, } - supportGroups, err := db.GetSupportGroups(supportGroupFilter, nil) + supportGroups, err := db.GetSupportGroups(context.Background(), supportGroupFilter, nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -593,7 +595,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { When("Getting SupportGroupCcrns", Label("GetSupportGroupCcrns"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetSupportGroupCcrns(nil) + res, err := db.GetSupportGroupCcrns(context.Background(), nil) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -610,7 +612,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetSupportGroupCcrns(nil) + res, err := db.GetSupportGroupCcrns(context.Background(), nil) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -648,7 +650,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { } It("can fetch the filtered items correctly", func() { - res, err := db.GetSupportGroupCcrns(filter) + res, err := db.GetSupportGroupCcrns(context.Background(), filter) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -678,7 +680,7 @@ var _ = Describe("SupportGroup", Label("database", "SupportGroup"), func() { It( "returns an empty list when no supportGroup match the filter", func() { - res, err := db.GetSupportGroupCcrns(anotherFilter) + res, err := db.GetSupportGroupCcrns(context.Background(), anotherFilter) Expect(err).Should(BeNil()) Expect(res).Should(BeEmpty()) diff --git a/internal/database/mariadb/test/common.go b/internal/database/mariadb/test/common.go index b5a0dc681..d7b34a980 100644 --- a/internal/database/mariadb/test/common.go +++ b/internal/database/mariadb/test/common.go @@ -4,6 +4,7 @@ package test import ( + "context" "database/sql" "encoding/json" "fmt" @@ -20,7 +21,7 @@ import ( // Temporary used until order is used in all entities func TestPaginationOfListWithOrder[F entity.HeurekaFilter, E entity.HeurekaEntity]( - listFunction func(*F, []entity.Order) ([]E, error), + listFunction func(context.Context, *F, []entity.Order) ([]E, error), filterFunction func(*int, *string) *F, order []entity.Order, getAfterFunction func([]E) string, @@ -36,7 +37,7 @@ func TestPaginationOfListWithOrder[F entity.HeurekaFilter, E entity.HeurekaEntit var afterS string for i := expectedPages; i > 0; i-- { - entries, err := listFunction(filterFunction(&pageSize, &afterS), order) + entries, err := listFunction(context.Background(), filterFunction(&pageSize, &afterS), order) Expect(err).To(BeNil()) diff --git a/internal/database/mariadb/trace_db.go b/internal/database/mariadb/trace_db.go index 680ef576c..3fb6ef4c0 100644 --- a/internal/database/mariadb/trace_db.go +++ b/internal/database/mariadb/trace_db.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "database/sql" "time" @@ -72,6 +73,10 @@ func (ts *TraceStmt) Queryx(args ...any) (*sqlx.Rows, error) { return ts.stmt.Queryx(args...) } +func (ts *TraceStmt) QueryxContext(ctx context.Context, args ...any) (*sqlx.Rows, error) { + return ts.stmt.QueryxContext(ctx, args...) +} + // TraceNamedStmt type TraceNamedStmt struct { trace *Trace @@ -134,6 +139,18 @@ func (tdb *TraceDb) Preparex(query string) (Stmt, error) { return &TraceStmt{stmt: stmt, trace: trace}, nil } +func (tdb *TraceDb) PreparexContext(ctx context.Context, query string) (Stmt, error) { + trace := NewTrace("PreparexContext", query) + + stmt, err := tdb.db.PreparexContext(ctx, query) + if err != nil { + trace.errorTrace() + return stmt, err + } + + return &TraceStmt{stmt: stmt, trace: trace}, nil +} + func (tdb *TraceDb) Select(dest any, query string, args ...any) error { defer NewTrace("Select", query).exitTrace() return tdb.db.Select(dest, query, args...) @@ -155,3 +172,8 @@ func (tdb *TraceDb) QueryRow(query string, args ...any) *sql.Row { defer NewTrace("QueryRow", query).exitTrace() return tdb.db.QueryRow(query, args...) } + +func (tdb *TraceDb) QueryContext(ctx context.Context, query string, args ...any) (SqlRows, error) { + defer NewTrace("QueryContext", query).exitTrace() + return tdb.db.QueryContext(ctx, query, args...) +} diff --git a/internal/database/mariadb/uniqueness_test.go b/internal/database/mariadb/uniqueness_test.go index ca8385e6a..296a88f9d 100644 --- a/internal/database/mariadb/uniqueness_test.go +++ b/internal/database/mariadb/uniqueness_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/cloudoperators/heureka/internal/database/mariadb" "github.com/cloudoperators/heureka/internal/database/mariadb/test" "github.com/cloudoperators/heureka/internal/entity" @@ -243,7 +245,7 @@ func (uut *uniquenessUserTemplate) expectDeletedUserCountForUniqueUserId( } func (uut *uniquenessUserTemplate) expectUserCount(uf *entity.UserFilter, cnt int64) { - userCount, err := uut.db.CountUsers(uf) + userCount, err := uut.db.CountUsers(context.Background(), uf) Expect(err).To(BeNil()) Expect(userCount).To(BeEquivalentTo(cnt)) } @@ -305,7 +307,7 @@ func (uct *uniquenessComponentTemplate) expectComponentCount( cf *entity.ComponentFilter, cnt int64, ) { - componentCount, err := uct.db.CountComponents(cf) + componentCount, err := uct.db.CountComponents(context.Background(), cf) Expect(err).To(BeNil()) Expect(componentCount).To(BeEquivalentTo(cnt)) } @@ -392,7 +394,7 @@ func (ucvt *uniquenessComponentVersionTemplate) expectComponentVersionCount( cvf *entity.ComponentVersionFilter, cnt int64, ) { - componentVersionCount, err := ucvt.db.CountComponentVersions(cvf) + componentVersionCount, err := ucvt.db.CountComponentVersions(context.Background(), cvf) Expect(err).To(BeNil()) Expect(componentVersionCount).To(BeEquivalentTo(cnt)) } @@ -451,7 +453,7 @@ func (ust *uniquenessServiceTemplate) expectDeletedServiceCountForCCRN(ccrn stri } func (ust *uniquenessServiceTemplate) expectServiceCount(sf *entity.ServiceFilter, cnt int64) { - serviceCount, err := ust.db.CountServices(sf) + serviceCount, err := ust.db.CountServices(context.Background(), sf) Expect(err).To(BeNil()) Expect(serviceCount).To(BeEquivalentTo(cnt)) } @@ -538,7 +540,7 @@ func (ucit *uniquenessComponentInstanceTemplate) expectComponentInstanceCount( cif *entity.ComponentInstanceFilter, cnt int64, ) { - componentInstanceCount, err := ucit.db.CountComponentInstances(cif) + componentInstanceCount, err := ucit.db.CountComponentInstances(context.Background(), cif) Expect(err).To(BeNil()) Expect(componentInstanceCount).To(BeEquivalentTo(cnt)) } @@ -606,7 +608,7 @@ func (uirt *uniquenessIssueRepositoryTemplate) expectIssueRepositoryCount( irf *entity.IssueRepositoryFilter, cnt int64, ) { - issueRepositoryCount, err := uirt.db.CountIssueRepositories(irf) + issueRepositoryCount, err := uirt.db.CountIssueRepositories(context.Background(), irf) Expect(err).To(BeNil()) Expect(issueRepositoryCount).To(BeEquivalentTo(cnt)) } @@ -668,7 +670,7 @@ func (uit *uniquenessIssueTemplate) expectDeletedIssueCountForPrimaryName( } func (uit *uniquenessIssueTemplate) expectIssueCount(isf *entity.IssueFilter, cnt int64) { - issueCount, err := uit.db.CountIssues(isf) + issueCount, err := uit.db.CountIssues(context.Background(), isf) Expect(err).To(BeNil()) Expect(issueCount).To(BeEquivalentTo(cnt)) } @@ -742,7 +744,7 @@ func (uivt *uniquenessIssueVariantTemplate) expectIssueVariantCount( ivf *entity.IssueVariantFilter, cnt int64, ) { - issueVariantCount, err := uivt.db.CountIssueVariants(ivf) + issueVariantCount, err := uivt.db.CountIssueVariants(context.Background(), ivf) Expect(err).To(BeNil()) Expect(issueVariantCount).To(BeEquivalentTo(cnt)) } diff --git a/internal/database/mariadb/user.go b/internal/database/mariadb/user.go index 4b68a01ae..a52471d99 100644 --- a/internal/database/mariadb/user.go +++ b/internal/database/mariadb/user.go @@ -4,6 +4,7 @@ package mariadb import ( + "context" "fmt" "github.com/cloudoperators/heureka/internal/entity" @@ -107,6 +108,7 @@ func ensureUserFilter(filter *entity.UserFilter) *entity.UserFilter { } func (s *SqlDatabase) buildUserStatement( + ctx context.Context, baseQuery string, filter *entity.UserFilter, withCursor bool, @@ -133,7 +135,7 @@ func (s *SqlDatabase) buildUserStatement( query = fmt.Sprintf(baseQuery, joins, whereClause, ord) } - stmt, err := s.db.Preparex(query) + stmt, err := s.db.PreparexContext(ctx, query) if err != nil { msg := ERROR_MSG_PREPARED_STMT l.WithFields( @@ -151,7 +153,7 @@ func (s *SqlDatabase) buildUserStatement( return stmt, filterParameters, nil } -func (s *SqlDatabase) GetAllUserIds(filter *entity.UserFilter) ([]int64, error) { +func (s *SqlDatabase) GetAllUserIds(ctx context.Context, filter *entity.UserFilter) ([]int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.GetUserIds", }) @@ -164,6 +166,7 @@ func (s *SqlDatabase) GetAllUserIds(filter *entity.UserFilter) ([]int64, error) ` stmt, filterParameters, err := s.buildUserStatement( + ctx, baseQuery, filter, false, @@ -180,10 +183,11 @@ func (s *SqlDatabase) GetAllUserIds(filter *entity.UserFilter) ([]int64, error) } }() - return performIdScan(stmt, filterParameters, l) + return performIdScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) GetAllUserCursors( + ctx context.Context, filter *entity.UserFilter, order []entity.Order, ) ([]string, error) { @@ -200,7 +204,7 @@ func (s *SqlDatabase) GetAllUserCursors( filter = ensureUserFilter(filter) - stmt, filterParameters, err := s.buildUserStatement(baseQuery, filter, false, order, l) + stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, fmt.Errorf("failed to build User cursor query: %w", err) } @@ -212,6 +216,7 @@ func (s *SqlDatabase) GetAllUserCursors( }() rows, err := performListScan( + ctx, stmt, filterParameters, l, @@ -232,7 +237,7 @@ func (s *SqlDatabase) GetAllUserCursors( }), nil } -func (s *SqlDatabase) GetUsers(filter *entity.UserFilter) ([]entity.UserResult, error) { +func (s *SqlDatabase) GetUsers(ctx context.Context, filter *entity.UserFilter) ([]entity.UserResult, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.GetUsers", }) @@ -247,6 +252,7 @@ func (s *SqlDatabase) GetUsers(filter *entity.UserFilter) ([]entity.UserResult, filter = ensureUserFilter(filter) stmt, filterParameters, err := s.buildUserStatement( + ctx, baseQuery, filter, true, @@ -264,6 +270,7 @@ func (s *SqlDatabase) GetUsers(filter *entity.UserFilter) ([]entity.UserResult, }() return performListScan( + ctx, stmt, filterParameters, l, @@ -283,7 +290,7 @@ func (s *SqlDatabase) GetUsers(filter *entity.UserFilter) ([]entity.UserResult, ) } -func (s *SqlDatabase) CountUsers(filter *entity.UserFilter) (int64, error) { +func (s *SqlDatabase) CountUsers(ctx context.Context, filter *entity.UserFilter) (int64, error) { l := logrus.WithFields(logrus.Fields{ "event": "database.CountUsers", }) @@ -296,6 +303,7 @@ func (s *SqlDatabase) CountUsers(filter *entity.UserFilter) (int64, error) { ` stmt, filterParameters, err := s.buildUserStatement( + ctx, baseQuery, filter, false, @@ -312,7 +320,7 @@ func (s *SqlDatabase) CountUsers(filter *entity.UserFilter) (int64, error) { } }() - return performCountScan(stmt, filterParameters, l) + return performCountScan(ctx, stmt, filterParameters, l) } func (s *SqlDatabase) CreateUser(user *entity.User) (*entity.User, error) { @@ -327,7 +335,7 @@ func (s *SqlDatabase) DeleteUser(id int64, userId int64) error { return userObject.Delete(s.db, id, userId) } -func (s *SqlDatabase) GetUserNames(filter *entity.UserFilter) ([]string, error) { +func (s *SqlDatabase) GetUserNames(ctx context.Context, filter *entity.UserFilter) ([]string, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetUserNames", @@ -344,7 +352,7 @@ func (s *SqlDatabase) GetUserNames(filter *entity.UserFilter) ([]string, error) filter = ensureUserFilter(filter) // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildUserStatement(baseQuery, filter, false, []entity.Order{ + stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, []entity.Order{ { By: entity.UserName, }, @@ -361,7 +369,7 @@ func (s *SqlDatabase) GetUserNames(filter *entity.UserFilter) ([]string, error) }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err @@ -394,7 +402,7 @@ func (s *SqlDatabase) GetUserNames(filter *entity.UserFilter) ([]string, error) return userNames, nil } -func (s *SqlDatabase) GetUniqueUserIDs(filter *entity.UserFilter) ([]string, error) { +func (s *SqlDatabase) GetUniqueUserIDs(ctx context.Context, filter *entity.UserFilter) ([]string, error) { l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetUniqueUserIDs", @@ -411,7 +419,7 @@ func (s *SqlDatabase) GetUniqueUserIDs(filter *entity.UserFilter) ([]string, err filter = ensureUserFilter(filter) // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildUserStatement(baseQuery, filter, false, []entity.Order{ + stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, []entity.Order{ { By: entity.UserUniqueUserID, }, @@ -428,7 +436,7 @@ func (s *SqlDatabase) GetUniqueUserIDs(filter *entity.UserFilter) ([]string, err }() // Execute the query - rows, err := stmt.Queryx(filterParameters...) + rows, err := stmt.QueryxContext(ctx, filterParameters...) if err != nil { l.Error("Error executing query: ", err) return nil, err diff --git a/internal/database/mariadb/user_test.go b/internal/database/mariadb/user_test.go index 23a58077b..b096a7f72 100644 --- a/internal/database/mariadb/user_test.go +++ b/internal/database/mariadb/user_test.go @@ -4,6 +4,8 @@ package mariadb_test import ( + "context" + "github.com/cloudoperators/heureka/internal/database/mariadb" "github.com/cloudoperators/heureka/internal/database/mariadb/test" e2e_common "github.com/cloudoperators/heureka/internal/e2e/common" @@ -30,7 +32,7 @@ var _ = Describe("User", Label("database", "User"), func() { When("Getting All User IDs", Label("GetAllUserIds"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetAllUserIds(nil) + res, err := db.GetAllUserIds(context.Background(), nil) res = e2e_common.SubtractSystemUserId(res) By("throwing no error", func() { @@ -53,7 +55,7 @@ var _ = Describe("User", Label("database", "User"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetAllUserIds(nil) + res, err := db.GetAllUserIds(context.Background(), nil) res = e2e_common.SubtractSystemUserId(res) By("throwing no error", func() { @@ -88,7 +90,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&uId}, } - entries, err := db.GetAllUserIds(filter) + entries, err := db.GetAllUserIds(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -109,7 +111,7 @@ var _ = Describe("User", Label("database", "User"), func() { When("Getting Users", Label("GetUsers"), func() { Context("and the database is empty", func() { It("can perform the query", func() { - res, err := db.GetUsers(nil) + res, err := db.GetUsers(context.Background(), nil) res = e2e_common.SubtractSystemUsersEntity(res) By("throwing no error", func() { @@ -127,7 +129,7 @@ var _ = Describe("User", Label("database", "User"), func() { }) Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetUsers(nil) + res, err := db.GetUsers(context.Background(), nil) res = e2e_common.SubtractSystemUsersEntity(res) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -179,7 +181,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id.Int64}, } - entries, err := db.GetUsers(filter) + entries, err := db.GetUsers(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -206,7 +208,7 @@ var _ = Describe("User", Label("database", "User"), func() { } } - entries, err := db.GetUsers(filter) + entries, err := db.GetUsers(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -235,7 +237,7 @@ var _ = Describe("User", Label("database", "User"), func() { } } - entries, err := db.GetUsers(filter) + entries, err := db.GetUsers(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -256,7 +258,7 @@ var _ = Describe("User", Label("database", "User"), func() { filter := &entity.UserFilter{Name: []*string{&row.Name.String}} - entries, err := db.GetUsers(filter) + entries, err := db.GetUsers(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -275,7 +277,7 @@ var _ = Describe("User", Label("database", "User"), func() { filter := &entity.UserFilter{UniqueUserID: []*string{&row.UniqueUserID.String}} - entries, err := db.GetUsers(filter) + entries, err := db.GetUsers(context.Background(), filter) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -293,7 +295,7 @@ var _ = Describe("User", Label("database", "User"), func() { humanUserTypeFilter := &entity.UserFilter{ Type: []entity.UserType{entity.HumanUserType}, } - humanUserEntries, cErr := db.GetUsers(humanUserTypeFilter) + humanUserEntries, cErr := db.GetUsers(context.Background(), humanUserTypeFilter) By("throwing no error when filtering human user type", func() { Expect(cErr).To(BeNil()) }) @@ -309,7 +311,7 @@ var _ = Describe("User", Label("database", "User"), func() { technicalUserTypeFilter := &entity.UserFilter{ Type: []entity.UserType{entity.TechnicalUserType}, } - technicalUserEntries, tErr := db.GetUsers(technicalUserTypeFilter) + technicalUserEntries, tErr := db.GetUsers(context.Background(), technicalUserTypeFilter) By("throwing no error when filtering technical user type", func() { Expect(tErr).To(BeNil()) }) @@ -339,7 +341,7 @@ var _ = Describe("User", Label("database", "User"), func() { When("Counting Users", Label("CountUsers"), func() { Context("and the database is empty", func() { It("can count correctly", func() { - c, err := db.CountUsers(nil) + c, err := db.CountUsers(context.Background(), nil) c = e2e_common.SubtractSystemUsers(c) By("throwing no error", func() { @@ -361,7 +363,7 @@ var _ = Describe("User", Label("database", "User"), func() { }) Context("and using no filter", func() { It("can count", func() { - c, err := db.CountUsers(nil) + c, err := db.CountUsers(context.Background(), nil) c = e2e_common.SubtractSystemUsers(c) By("throwing no error", func() { @@ -381,7 +383,7 @@ var _ = Describe("User", Label("database", "User"), func() { After: nil, }, } - c, err := db.CountUsers(filter) + c, err := db.CountUsers(context.Background(), filter) c = e2e_common.SubtractSystemUsers(c) By("throwing no error", func() { @@ -418,7 +420,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id}, } - u, err := db.GetUsers(userFilter) + u, err := db.GetUsers(context.Background(), userFilter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -466,7 +468,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id}, } - u, err := db.GetUsers(userFilter) + u, err := db.GetUsers(context.Background(), userFilter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -494,7 +496,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id}, } - u, err := db.GetUsers(userFilter) + u, err := db.GetUsers(context.Background(), userFilter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -522,7 +524,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id}, } - u, err := db.GetUsers(userFilter) + u, err := db.GetUsers(context.Background(), userFilter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -557,7 +559,7 @@ var _ = Describe("User", Label("database", "User"), func() { Id: []*int64{&user.Id}, } - u, err := db.GetUsers(userFilter) + u, err := db.GetUsers(context.Background(), userFilter) By("throwing no error", func() { Expect(err).To(BeNil()) }) @@ -570,7 +572,7 @@ var _ = Describe("User", Label("database", "User"), func() { When("Getting UserNames", Label("GetUserNames"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetUserNames(nil) + res, err := db.GetUserNames(context.Background(), nil) res = e2e_common.SubtractSystemUserNameVL(res) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -588,7 +590,7 @@ var _ = Describe("User", Label("database", "User"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetUserNames(nil) + res, err := db.GetUserNames(context.Background(), nil) res = e2e_common.SubtractSystemUserNameVL(res) By("throwing no error", func() { @@ -627,7 +629,7 @@ var _ = Describe("User", Label("database", "User"), func() { } It("can fetch the filtered items correctly", func() { - res, err := db.GetUserNames(filter) + res, err := db.GetUserNames(context.Background(), filter) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -655,7 +657,7 @@ var _ = Describe("User", Label("database", "User"), func() { } It("returns an empty list when no users match the filter", func() { - res, err := db.GetUserNames(anotherFilter) + res, err := db.GetUserNames(context.Background(), anotherFilter) Expect(err).Should(BeNil()) Expect(res).Should(BeEmpty()) @@ -676,7 +678,7 @@ var _ = Describe("User", Label("database", "User"), func() { When("Getting UniqueUserID", Label("GetUniqueUserID"), func() { Context("and the database is empty", func() { It("can perform the list query", func() { - res, err := db.GetUniqueUserIDs(nil) + res, err := db.GetUniqueUserIDs(context.Background(), nil) res = e2e_common.SubtractSystemUserUniqueUserIdVL(res) By("throwing no error", func() { Expect(err).To(BeNil()) @@ -694,7 +696,7 @@ var _ = Describe("User", Label("database", "User"), func() { Context("and using no filter", func() { It("can fetch the items correctly", func() { - res, err := db.GetUniqueUserIDs(nil) + res, err := db.GetUniqueUserIDs(context.Background(), nil) res = e2e_common.SubtractSystemUserUniqueUserIdVL(res) By("throwing no error", func() { @@ -733,7 +735,7 @@ var _ = Describe("User", Label("database", "User"), func() { } It("can fetch the filtered items correctly", func() { - res, err := db.GetUniqueUserIDs(filter) + res, err := db.GetUniqueUserIDs(context.Background(), filter) By("throwing no error", func() { Expect(err).Should(BeNil()) @@ -761,7 +763,7 @@ var _ = Describe("User", Label("database", "User"), func() { } It("returns an empty list when no users match the filter", func() { - res, err := db.GetUniqueUserIDs(anotherFilter) + res, err := db.GetUniqueUserIDs(context.Background(), anotherFilter) Expect(err).Should(BeNil()) Expect(res).Should(BeEmpty()) diff --git a/internal/e2e/siem_alert_test.go b/internal/e2e/siem_alert_test.go index c673b0944..20feab743 100644 --- a/internal/e2e/siem_alert_test.go +++ b/internal/e2e/siem_alert_test.go @@ -4,6 +4,7 @@ package e2e_test import ( + "context" "fmt" e2e_common "github.com/cloudoperators/heureka/internal/e2e/common" @@ -92,6 +93,7 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() Expect(*respData.SIEM.URL).To(Equal(alertURL)) issues, err := db.GetIssues( + context.Background(), &entity.IssueFilter{PrimaryName: []*string{&alertName}}, nil, ) @@ -100,6 +102,7 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() issueId := issues[0].Issue.Id ivs, err := db.GetIssueVariants( + context.Background(), &entity.IssueVariantFilter{IssueId: []*int64{&issueId}}, []entity.Order{}, ) @@ -120,12 +123,12 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() Expect(issueVariantWithSeverity).To(BeTrue()) serviceFilter := &entity.ServiceFilter{CCRN: []*string{&service}} - services, err := db.GetServices(serviceFilter, nil) + services, err := db.GetServices(context.Background(), serviceFilter, nil) Expect(err).To(BeNil()) Expect(len(services)).To(BeNumerically(">=", 1)) sgFilter := &entity.SupportGroupFilter{CCRN: []*string{&supportGroup}} - sgs, err := db.GetSupportGroups(sgFilter, nil) + sgs, err := db.GetSupportGroups(context.Background(), sgFilter, nil) Expect(err).To(BeNil()) Expect(len(sgs)).To(BeNumerically(">=", 1)) @@ -140,6 +143,7 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() ) cis, err := db.GetComponentInstances( + context.Background(), &entity.ComponentInstanceFilter{CCRN: []*string{&ccrn}}, nil, ) @@ -148,6 +152,7 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() ciId := cis[0].Id ims, err := db.GetIssueMatches( + context.Background(), &entity.IssueMatchFilter{IssueId: []*int64{&issueId}}, nil, ) @@ -205,23 +210,26 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() // Verify the IssueRepository in the database repoFilter := &entity.IssueRepositoryFilter{Name: []*string{&source}} - repos, err := db.GetIssueRepositories(repoFilter, []entity.Order{}) + repos, err := db.GetIssueRepositories(context.Background(), repoFilter, []entity.Order{}) Expect(err).To(BeNil()) Expect(len(repos)).To(Equal(1)) Expect(repos[0].Name).To(Equal(source)) // Verify IssueVariant is linked to this repository issues, err := db.GetIssues( + context.Background(), &entity.IssueFilter{PrimaryName: []*string{&alertName}}, nil, ) Expect(err).To(BeNil()) issueId := issues[0].Issue.Id - ivs, err := db.GetIssueVariants(&entity.IssueVariantFilter{ - IssueId: []*int64{&issueId}, - IssueRepositoryId: []*int64{&repos[0].Id}, - }, []entity.Order{}) + ivs, err := db.GetIssueVariants( + context.Background(), + &entity.IssueVariantFilter{ + IssueId: []*int64{&issueId}, + IssueRepositoryId: []*int64{&repos[0].Id}, + }, []entity.Order{}) Expect(err).To(BeNil()) Expect(len(ivs)).To(Equal(1)) }) @@ -269,6 +277,7 @@ var _ = Describe("Creating SIEMAlert via API", Label("e2e", "SIEMAlert"), func() // Verify the alert was not created in the database issues, err := db.GetIssues( + context.Background(), &entity.IssueFilter{PrimaryName: []*string{&alertName}}, nil, ) diff --git a/internal/server/server.go b/internal/server/server.go index eb5502d42..1e271ea9a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -178,8 +178,11 @@ func (s *Server) Start() { func (s *Server) NonBlockingStart() { s.nonBlockingSrv = &http.Server{ - Addr: fmt.Sprintf(":%s", s.config.Port), - Handler: s.router.Handler(), + Addr: fmt.Sprintf(":%s", s.config.Port), + Handler: s.router.Handler(), + ReadTimeout: time.Minute * 20, + WriteTimeout: time.Minute * 20, + IdleTimeout: time.Minute * 20, } util2.FirstListenThenServe(s.nonBlockingSrv, logrus.New())