diff --git a/internal/database/mariadb/component.go b/internal/database/mariadb/component.go index 2ce60de0..ac988944 100644 --- a/internal/database/mariadb/component.go +++ b/internal/database/mariadb/component.go @@ -142,14 +142,6 @@ var componentObject = DbObject[*entity.Component]{ }, } -func ensureComponentFilter(filter *entity.ComponentFilter) *entity.ComponentFilter { - if filter == nil { - filter = &entity.ComponentFilter{} - } - - return EnsurePagination(filter) -} - func needSingleComponentByServiceVulnerabilityCounts(filter *entity.ComponentFilter, order *Order) bool { return order.ByCount() && (len(filter.Id) > 0 && (len(filter.ServiceCCRN) > 0)) } @@ -187,44 +179,20 @@ func (s *SqlDatabase) buildComponentStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureComponentFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode Remediation cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.ComponentId, Direction: entity.OrderDirectionAsc}) - joins := componentObject.GetJoins(filter, ord) - whereClause, hasFilter := componentObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := componentObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) + statement := Statement{ + Db: s.db, + L: l, + Obj: &componentObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.ComponentId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := componentObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllComponentCursors( @@ -243,7 +211,7 @@ func (s *SqlDatabase) GetAllComponentCursors( %s GROUP BY C.component_id ORDER BY %s ` - filter = ensureComponentFilter(filter) + filter = EnsureFilter(filter) columns := s.getComponentColumns(order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s") @@ -301,7 +269,7 @@ func (s *SqlDatabase) GetComponents( %s GROUP BY C.component_id ORDER BY %s LIMIT ? ` - filter = ensureComponentFilter(filter) + filter = EnsureFilter(filter) columns := s.getComponentColumns(order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s", "%s") @@ -387,7 +355,7 @@ func (s *SqlDatabase) CountComponentVulnerabilities( filterParameters []any ) - filter = ensureComponentFilter(filter) + filter = EnsureFilter(filter) query := ` SELECT CVR.critical_count, CVR.high_count, CVR.medium_count, CVR.low_count, CVR.none_count FROM %s AS CVR @@ -441,8 +409,7 @@ func (s *SqlDatabase) CountComponentVulnerabilities( "error": err, "query": query, "stmt": stmt, - }, - ).Error(msg) + }).Error(msg) return nil, fmt.Errorf("%s", msg) } @@ -490,7 +457,7 @@ func (s *SqlDatabase) GetComponentCcrns(ctx context.Context, filter *entity.Comp ` // Ensure the filter is initialized - filter = ensureComponentFilter(filter) + filter = EnsureFilter(filter) order := []entity.Order{ { By: entity.ComponentCcrn, diff --git a/internal/database/mariadb/component_instance.go b/internal/database/mariadb/component_instance.go index f1db9347..50d52d9f 100644 --- a/internal/database/mariadb/component_instance.go +++ b/internal/database/mariadb/component_instance.go @@ -267,16 +267,6 @@ var componentInstanceObject = DbObject[*entity.ComponentInstance]{ }, } -func ensureComponentInstanceFilter( - filter *entity.ComponentInstanceFilter, -) *entity.ComponentInstanceFilter { - if filter == nil { - filter = &entity.ComponentInstanceFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildComponentInstanceStatement( ctx context.Context, baseQuery string, @@ -285,48 +275,20 @@ func (s *SqlDatabase) buildComponentInstanceStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureComponentInstanceFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.ComponentInstanceId, Direction: entity.OrderDirectionAsc}) - joins := componentInstanceObject.GetJoins(filter, ord) - whereClause, hasFilter := componentInstanceObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := componentInstanceObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare ComponentInstance statement: %w", err) + statement := Statement{ + Db: s.db, + L: l, + Obj: &componentInstanceObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.ComponentInstanceId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := componentInstanceObject.GetFilterParameters( - filter, - withCursor, - cursorFields, - ) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetComponentInstances( @@ -516,7 +478,7 @@ func (s *SqlDatabase) getComponentInstanceAttr( baseQuery = fmt.Sprintf(baseQuery, attrName, "%s", "%s", "%s") // Ensure the filter is initialized - filter = ensureComponentInstanceFilter(filter) + filter = EnsureFilter(filter) order := []entity.Order{ {By: entity.ComponentInstanceCcrn, Direction: entity.OrderDirectionAsc}, diff --git a/internal/database/mariadb/component_version.go b/internal/database/mariadb/component_version.go index 68e2ba0e..85c030f0 100644 --- a/internal/database/mariadb/component_version.go +++ b/internal/database/mariadb/component_version.go @@ -195,16 +195,6 @@ var componentVersionObject = DbObject[*entity.ComponentVersion]{ }, } -func ensureComponentVersionFilter( - filter *entity.ComponentVersionFilter, -) *entity.ComponentVersionFilter { - if filter == nil { - filter = &entity.ComponentVersionFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) getComponentVersionColumns(order []entity.Order) string { columns := "" @@ -249,46 +239,20 @@ func (s *SqlDatabase) buildComponentVersionStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureComponentVersionFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, err + statement := Statement{ + Db: s.db, + L: l, + Obj: &componentVersionObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.ComponentVersionId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: false, + CheckCursor: true, + CheckFilter: false, + Aggregated: true, } - ord := NewOrder(order, entity.Order{By: entity.ComponentVersionId, Direction: entity.OrderDirectionAsc}) - joins := componentVersionObject.GetJoins(filter, ord) - whereClause, _ := componentVersionObject.GetFilterWhereClause(filter, false) - cursorQuery := componentVersionObject.GetCursorQuery(nil, cursorFields, &withCursor, true) - - var query string - - columns := s.getComponentVersionColumns(order) - if withCursor { - query = fmt.Sprintf(baseQuery, columns, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, columns, joins, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) - } - - filterParameters := componentVersionObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllComponentVersionCursors( @@ -301,11 +265,11 @@ func (s *SqlDatabase) GetAllComponentVersionCursors( "event": "database.GetAllComponentVersionCursors", }) - baseQuery := ` - SELECT CV.* %s FROM ComponentVersion CV + baseQuery := fmt.Sprintf(` + SELECT CV.* %s FROM ComponentVersion CV %s %s GROUP BY CV.componentversion_id ORDER BY %s - ` + `, s.getComponentVersionColumns(order), "%s", "%s", "%s") stmt, filterParameters, err := s.buildComponentVersionStatement( ctx, @@ -355,14 +319,14 @@ func (s *SqlDatabase) GetComponentVersions( "event": "database.GetComponentVersions", }) - baseQuery := ` - SELECT CV.* %s FROM ComponentVersion CV + baseQuery := fmt.Sprintf(` + SELECT CV.* %s FROM ComponentVersion CV %s %s GROUP BY CV.componentversion_id %s ORDER BY %s LIMIT ? - ` + `, s.getComponentVersionColumns(order), "%s", "%s", "%s", "%s") - filter = ensureComponentVersionFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildComponentVersionStatement( ctx, @@ -415,7 +379,7 @@ func (s *SqlDatabase) CountComponentVersions(ctx context.Context, filter *entity }) baseQuery := ` - SELECT count(distinct CV.componentversion_id) %s FROM ComponentVersion CV + SELECT count(distinct CV.componentversion_id) FROM ComponentVersion CV %s %s ORDER BY %s diff --git a/internal/database/mariadb/db_object.go b/internal/database/mariadb/db_object.go index d354960c..1940ad08 100644 --- a/internal/database/mariadb/db_object.go +++ b/internal/database/mariadb/db_object.go @@ -4,8 +4,10 @@ package mariadb import ( + "context" "database/sql" "fmt" + "reflect" "strings" "github.com/cloudoperators/heureka/internal/database" @@ -65,8 +67,7 @@ func (do *DbObject[ET]) GetFilterParameters( paginated := filter.GetPaginated() filterParameters = append( filterParameters, - GetCursorQueryParameters(paginated.First, cursorFields)..., - ) + GetCursorQueryParameters(paginated.First, cursorFields)...) } return filterParameters @@ -166,27 +167,24 @@ func (do *DbObject[ET]) GetJoins(filter any, order *Order) string { return NewJoinResolver(do.JoinDefs).Build(filter, order) } -func (do *DbObject[ET]) GetFilterWhereClause(filter any, withCursor bool) (string, bool) { - filterStr := do.GetFilterQuery(filter) - - hasFilter := filterStr != "" - if hasFilter || withCursor { - return fmt.Sprintf("WHERE %s", filterStr), hasFilter +func (do *DbObject[ET]) GetFilterWhereClause(filter any, withCursor bool) string { + if filterStr := do.GetFilterQuery(filter); filterStr != "" || withCursor { + return fmt.Sprintf("WHERE %s", filterStr) } - return "", false + return "" } -func (do *DbObject[ET]) GetCursorQuery(hasFilter *bool, cursorFields []Field, withCursor *bool, aggregated bool) string { +func (do *DbObject[ET]) GetCursorQuery(filter any, cursorFields []Field, withCursor bool, checkCursor bool, aggregated bool) string { cursorQuery := CreateCursorQuery("", cursorFields) if aggregated { - if (withCursor == nil || *withCursor) && (hasFilter == nil || *hasFilter) && cursorQuery != "" { + if (!checkCursor || withCursor) && (IsNil(filter) || do.GetFilterQuery(filter) != "") && cursorQuery != "" { cursorQuery = fmt.Sprintf("HAVING (%s)", cursorQuery) } } else { - if hasFilter != nil { - if *hasFilter && *withCursor && cursorQuery != "" { + if !IsNil(filter) { + if do.GetFilterQuery(filter) != "" && withCursor && cursorQuery != "" { cursorQuery = fmt.Sprintf(" AND (%s)", cursorQuery) } } else { @@ -197,6 +195,77 @@ func (do *DbObject[ET]) GetCursorQuery(hasFilter *bool, cursorFields []Field, wi return cursorQuery } +type Object interface { + GetJoins(any, *Order) string + GetFilterWhereClause(any, bool) string + GetCursorQuery(any, []Field, bool, bool, bool) string + GetFilterParameters(entity.HasPagination, bool, []Field) []any +} + +type Statement struct { + Db Db + L *logrus.Entry + Obj Object + BaseQuery string + Order *Order + WithCursor bool + CheckCursorInWhere bool + CheckCursor bool + CheckFilter bool + Aggregated bool +} + +func BuildStatement[ + T any, + PT interface { + *T + entity.HasPagination + }, +](ctx context.Context, s Statement, filter PT) (Stmt, []any, error) { + filter = EnsureFilter(filter) + s.L.WithFields(logrus.Fields{"filter": filter}) + + joins := s.Obj.GetJoins(filter, s.Order) + whereClause := s.Obj.GetFilterWhereClause(filter, s.CheckCursorInWhere && s.WithCursor) + + var f PT + if s.CheckFilter { + f = filter + } + + cursorFields, err := DecodeCursor(filter.GetPaginated().After) + if err != nil { + return nil, nil, fmt.Errorf("failed to decode cursor: %w", err) + } + + cursorQuery := s.Obj.GetCursorQuery(f, cursorFields, s.WithCursor, s.CheckCursor, s.Aggregated) + + var query string + if s.WithCursor { + query = fmt.Sprintf(s.BaseQuery, joins, whereClause, cursorQuery, s.Order) + } else { + query = fmt.Sprintf(s.BaseQuery, joins, whereClause, s.Order) + } + + // construct prepared statement and if where clause does exist add parameters + stmt, err := s.Db.PreparexContext(ctx, query) + if err != nil { + msg := ERROR_MSG_PREPARED_STMT + s.L.WithFields( + logrus.Fields{ + "error": err, + "query": query, + "stmt": stmt, + }).Error(msg) + + return nil, nil, fmt.Errorf("%s", msg) + } + + filterParameters := s.Obj.GetFilterParameters(filter, s.WithCursor, cursorFields) + + return stmt, filterParameters, nil +} + // Property const NoUpdate = false @@ -370,8 +439,7 @@ func (jr *JoinResolver) Build(filter any, order *Order) string { uniqTableName[j.Table] = struct{}{} - joinSQL := fmt.Sprintf( - "%s %s ON %s", + joinSQL := fmt.Sprintf("%s %s ON %s", j.Type, j.Table, j.On, @@ -384,7 +452,41 @@ func (jr *JoinResolver) Build(filter any, order *Order) string { } // DB helpers -func EnsurePagination[T entity.HasPagination](filter T) T { +func EnsureFilter[ + T any, + PT interface { + *T + entity.HasPagination + }, +](filter PT) PT { + if IsNil(filter) { + filter = PT(new(T)) + } + + return ensurePagination(filter) +} + +func IsNil(v any) bool { + if v == nil { + return true + } + + rv := reflect.ValueOf(v) + + switch rv.Kind() { + case reflect.Ptr, + reflect.Interface, + reflect.Map, + reflect.Slice, + reflect.Func, + reflect.Chan: + return rv.IsNil() + } + + return false +} + +func ensurePagination[T entity.HasPagination](filter T) T { first := 1000 after := "" diff --git a/internal/database/mariadb/issue.go b/internal/database/mariadb/issue.go index 8db13698..46de610f 100644 --- a/internal/database/mariadb/issue.go +++ b/internal/database/mariadb/issue.go @@ -316,14 +316,6 @@ var issueObject = DbObject[*entity.Issue]{ }, } -func ensureIssueFilter(filter *entity.IssueFilter) *entity.IssueFilter { - if filter == nil { - filter = &entity.IssueFilter{} - } - - return EnsurePagination(filter) -} - func getIssueColumns(order []entity.Order) string { columns := "" @@ -340,112 +332,25 @@ func getIssueColumns(order []entity.Order) string { return columns } -func getIssueQueryWithCursor( - baseQuery string, - order []entity.Order, - filter *entity.IssueFilter, - cursorFields []Field, -) string { - issueColumns := getIssueColumns(order) - ord := NewOrder(order, entity.Order{By: entity.IssueId, Direction: entity.OrderDirectionAsc}) - joins := issueObject.GetJoins(filter, ord) - whereClause, hasFilter := issueObject.GetFilterWhereClause(filter, false) - issueCursor := issueObject.GetCursorQuery(&hasFilter, cursorFields, nil, true) - - return fmt.Sprintf(baseQuery, issueColumns, joins, whereClause, issueCursor, ord) -} - -func getIssueQuery(baseQuery string, order []entity.Order, filter *entity.IssueFilter) string { - issueColumns := getIssueColumns(order) - ord := NewOrder(order, entity.Order{By: entity.IssueId, Direction: entity.OrderDirectionAsc}) - joins := issueObject.GetJoins(filter, ord) - whereClause, _ := issueObject.GetFilterWhereClause(filter, false) - - return fmt.Sprintf(baseQuery, issueColumns, joins, whereClause, ord) -} - -func (s *SqlDatabase) buildIssueStatementWithCursor( - ctx context.Context, - baseQuery string, - filter *entity.IssueFilter, - order []entity.Order, - l *logrus.Entry, -) (Stmt, []any, error) { - ifilter := ensureIssueFilter(filter) - l.WithFields(logrus.Fields{"filter": ifilter}) - - cursorFields, err := DecodeCursor(ifilter.After) - if err != nil { - return nil, nil, err - } - - query := getIssueQueryWithCursor(baseQuery, order, ifilter, cursorFields) - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) - } - - // adding parameters - filterParameters := issueObject.GetFilterParameters(ifilter, true, cursorFields) - - return stmt, filterParameters, nil -} - -func (s *SqlDatabase) buildIssueStatement( - ctx context.Context, - baseQuery string, - filter *entity.IssueFilter, - order []entity.Order, - l *logrus.Entry, -) (Stmt, []any, error) { - ifilter := ensureIssueFilter(filter) - l.WithFields(logrus.Fields{"filter": ifilter}) - - cursorFields, err := DecodeCursor(ifilter.After) - if err != nil { - return nil, nil, err +func (s *SqlDatabase) buildIssueStatement(ctx context.Context, baseQuery string, filter *entity.IssueFilter, withCursor bool, order []entity.Order, l *logrus.Entry) (Stmt, []any, error) { + statement := Statement{ + Db: s.db, + L: l, + Obj: &issueObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.IssueId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: false, + CheckCursor: true, + CheckFilter: true, + Aggregated: true, } - query := getIssueQuery(baseQuery, order, ifilter) - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) - } - - // adding parameters - filterParameters := issueObject.GetFilterParameters(ifilter, false, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } -func (s *SqlDatabase) GetIssuesWithAggregations( - ctx context.Context, - filter *entity.IssueFilter, - order []entity.Order, -) ([]entity.IssueResult, error) { - filter = ensureIssueFilter(filter) +func (s *SqlDatabase) GetIssuesWithAggregations(ctx context.Context, filter *entity.IssueFilter, order []entity.Order) ([]entity.IssueResult, error) { + filter = EnsureFilter(filter) l := logrus.WithFields(logrus.Fields{ "filter": filter, "event": "database.GetIssuesWithAggregations", @@ -495,7 +400,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( JOIN Aggs A ON CIC.issue_id = A.issue_id; ` - filter = ensureIssueFilter(filter) + filter = EnsureFilter(filter) joins := issueObject.GetJoins(filter, NewOrder(order, entity.Order{})) // It seems that this join is redundant for baseAppQuery // We should improve testing and remove redundant joins from query @@ -507,7 +412,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( columns := getIssueColumns(order) ord := NewOrder(order, entity.Order{By: entity.IssueId, Direction: entity.OrderDirectionAsc}) - whereClause, _ := issueObject.GetFilterWhereClause(filter, false) + whereClause := issueObject.GetFilterWhereClause(filter, false) cursorQuery := CreateCursorQuery("", cursorFields) @@ -528,8 +433,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( "error": err, "query": query, "stmt": stmt, - }, - ).Error(msg) + }).Error(msg) return nil, fmt.Errorf("%s", msg) } @@ -537,10 +441,7 @@ func (s *SqlDatabase) GetIssuesWithAggregations( // parameters for component instance query filterParameters := issueObject.GetFilterParameters(filter, true, cursorFields) // parameters for agg query - filterParameters = append( - filterParameters, - issueObject.GetFilterParameters(filter, true, cursorFields)..., - ) + filterParameters = append(filterParameters, issueObject.GetFilterParameters(filter, true, cursorFields)...) defer func() { if err := stmt.Close(); err != nil { @@ -586,13 +487,13 @@ func (s *SqlDatabase) CountIssues(ctx context.Context, filter *entity.IssueFilte }) baseQuery := ` - SELECT COUNT(distinct I.issue_id) %s FROM Issue I + SELECT COUNT(distinct I.issue_id) FROM Issue I %s %s ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, []entity.Order{}, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, false, []entity.Order{}, l) if err != nil { return -1, err } @@ -612,13 +513,13 @@ func (s *SqlDatabase) CountIssueTypes(ctx context.Context, filter *entity.IssueF }) baseQuery := ` - SELECT I.issue_type AS issue_value, COUNT(distinct I.issue_id) as issue_count %s FROM Issue I + SELECT I.issue_type AS issue_value, COUNT(distinct I.issue_id) as issue_count FROM Issue I %s %s GROUP BY I.issue_type ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, []entity.Order{}, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, false, []entity.Order{}, l) if err != nil { return nil, err } @@ -674,7 +575,10 @@ func (s *SqlDatabase) GetAllIssueCursors( %s GROUP BY I.issue_id ORDER BY %s ` - stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, order, l) + issueColumns := getIssueColumns(order) + baseQuery = fmt.Sprintf(baseQuery, issueColumns, "%s", "%s", "%s") + + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err } @@ -728,9 +632,12 @@ func (s *SqlDatabase) GetIssues( GROUP BY I.issue_id %s ORDER BY %s LIMIT ? ` - filter = ensureIssueFilter(filter) + filter = EnsureFilter(filter) + + issueColumns := getIssueColumns(order) + baseQuery = fmt.Sprintf(baseQuery, issueColumns, "%s", "%s", "%s", "%s") - stmt, filterParameters, err := s.buildIssueStatementWithCursor(ctx, baseQuery, filter, order, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err } @@ -818,10 +725,7 @@ func (s *SqlDatabase) AddComponentVersionToIssue(issueId int64, componentVersion return nil } -func (s *SqlDatabase) RemoveComponentVersionFromIssue( - issueId int64, - componentVersionId int64, -) error { +func (s *SqlDatabase) RemoveComponentVersionFromIssue(issueId int64, componentVersionId int64) error { l := logrus.WithFields(logrus.Fields{ "issueId": issueId, "componentVersionId": componentVersionId, @@ -884,10 +788,10 @@ func (s *SqlDatabase) GetIssueNames(ctx context.Context, filter *entity.IssueFil } // Ensure the filter is initialized - filter = ensureIssueFilter(filter) + filter = EnsureFilter(filter) // Builds full statement with possible joins and filters - stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, order, l) + stmt, filterParameters, err := s.buildIssueStatement(ctx, baseQuery, filter, false, order, l) if err != nil { l.Error("Error preparing statement: ", err) return nil, err diff --git a/internal/database/mariadb/issue_match.go b/internal/database/mariadb/issue_match.go index 0d937d86..406eb7d3 100644 --- a/internal/database/mariadb/issue_match.go +++ b/internal/database/mariadb/issue_match.go @@ -254,14 +254,6 @@ var issueMatchObject = DbObject[*entity.IssueMatch]{ }, } -func ensureIssueMatchFilter(filter *entity.IssueMatchFilter) *entity.IssueMatchFilter { - if filter == nil { - filter = &entity.IssueMatchFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) getIssueMatchColumns(order []entity.Order) string { columns := "" @@ -285,45 +277,20 @@ func (s *SqlDatabase) buildIssueMatchStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureIssueMatchFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, err - } - - ord := NewOrder(order, entity.Order{By: entity.IssueMatchId, Direction: entity.OrderDirectionAsc}) - columns := s.getIssueMatchColumns(ord.Sequence()) - joins := issueMatchObject.GetJoins(filter, ord) - whereClause, hasFilter := issueMatchObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := issueMatchObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, columns, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, columns, joins, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) + statement := Statement{ + Db: s.db, + L: l, + Obj: &issueMatchObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.IssueMatchId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := issueMatchObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllIssueMatchIds(ctx context.Context, filter *entity.IssueMatchFilter) ([]int64, error) { @@ -333,7 +300,7 @@ func (s *SqlDatabase) GetAllIssueMatchIds(ctx context.Context, filter *entity.Is }) baseQuery := ` - SELECT IM.issuematch_id %s FROM IssueMatch IM + SELECT IM.issuematch_id FROM IssueMatch IM %s %s GROUP BY IM.issuematch_id ORDER BY %s ` @@ -364,11 +331,14 @@ func (s *SqlDatabase) GetAllIssueMatchCursors( }) baseQuery := ` - SELECT IM.* %s FROM IssueMatch IM + SELECT IM.* %s FROM IssueMatch IM %s %s GROUP BY IM.issuematch_id ORDER BY %s ` + columns := s.getIssueMatchColumns(order) + baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s") + stmt, filterParameters, err := s.buildIssueMatchStatement(ctx, baseQuery, filter, false, order, l) if err != nil { return nil, err @@ -414,11 +384,14 @@ func (s *SqlDatabase) GetIssueMatches( }) baseQuery := ` - SELECT IM.* %s FROM IssueMatch IM + SELECT IM.* %s FROM IssueMatch IM %s %s %s GROUP BY IM.issuematch_id ORDER BY %s LIMIT ? ` + columns := s.getIssueMatchColumns(order) + baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s", "%s") + stmt, filterParameters, err := s.buildIssueMatchStatement(ctx, baseQuery, filter, true, order, l) if err != nil { return nil, err @@ -466,7 +439,7 @@ func (s *SqlDatabase) CountIssueMatches(ctx context.Context, filter *entity.Issu }) baseQuery := ` - SELECT count(distinct IM.issuematch_id) %s FROM IssueMatch IM + SELECT count(distinct IM.issuematch_id) FROM IssueMatch IM %s %s ORDER BY %s diff --git a/internal/database/mariadb/issue_repository.go b/internal/database/mariadb/issue_repository.go index 0f348a4e..2fe483c5 100644 --- a/internal/database/mariadb/issue_repository.go +++ b/internal/database/mariadb/issue_repository.go @@ -94,16 +94,6 @@ var issueRepositoryObject = DbObject[*entity.IssueRepository]{ }, } -func ensureIssueRepositoryFilter( - filter *entity.IssueRepositoryFilter, -) *entity.IssueRepositoryFilter { - if filter == nil { - filter = &entity.IssueRepositoryFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildIssueRepositoryStatement( ctx context.Context, baseQuery string, @@ -112,43 +102,20 @@ func (s *SqlDatabase) buildIssueRepositoryStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureIssueRepositoryFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode IssueRepository cursor: %w", err) + statement := Statement{ + Db: s.db, + L: l, + Obj: &issueRepositoryObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.IssueRepositoryID, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - ord := NewOrder(order, entity.Order{By: entity.IssueRepositoryID, Direction: entity.OrderDirectionAsc}) - joins := issueRepositoryObject.GetJoins(filter, ord) - whereClause, hasFilter := issueRepositoryObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := issueRepositoryObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) - } - - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare IssueRepository statement: %w", err) - } - - filterParameters := issueRepositoryObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllIssueRepositoryCursors( @@ -167,7 +134,7 @@ func (s *SqlDatabase) GetAllIssueRepositoryCursors( %s GROUP BY IR.issuerepository_id ORDER BY %s ` - filter = ensureIssueRepositoryFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildIssueRepositoryStatement( ctx, @@ -225,7 +192,7 @@ func (s *SqlDatabase) GetIssueRepositories( %s GROUP BY IR.issuerepository_id ORDER BY %s LIMIT ? ` - filter = ensureIssueRepositoryFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildIssueRepositoryStatement( ctx, diff --git a/internal/database/mariadb/issue_variant.go b/internal/database/mariadb/issue_variant.go index 5f852f18..f9f3e0c1 100644 --- a/internal/database/mariadb/issue_variant.go +++ b/internal/database/mariadb/issue_variant.go @@ -153,14 +153,6 @@ var issueVariantObject = DbObject[*entity.IssueVariant]{ }, } -func ensureIssueVariantFilter(filter *entity.IssueVariantFilter) *entity.IssueVariantFilter { - if filter == nil { - filter = &entity.IssueVariantFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildIssueVariantStatement( ctx context.Context, baseQuery string, @@ -169,43 +161,20 @@ func (s *SqlDatabase) buildIssueVariantStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureIssueVariantFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode IssueVariant cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.IssueVariantID, Direction: entity.OrderDirectionAsc}) - joins := issueVariantObject.GetJoins(filter, ord) - whereClause, hasFilter := issueVariantObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := issueVariantObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) + statement := Statement{ + Db: s.db, + L: l, + Obj: &issueVariantObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.IssueVariantID, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare IssueVariant statement: %w", err) - } - - filterParameters := issueVariantObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllIssueVariantCursors( @@ -224,7 +193,7 @@ func (s *SqlDatabase) GetAllIssueVariantCursors( %s GROUP BY IV.issuevariant_id ORDER BY %s ` - filter = ensureIssueVariantFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildIssueVariantStatement(ctx, baseQuery, filter, false, order, l) if err != nil { diff --git a/internal/database/mariadb/mv_vulnerabilities.go b/internal/database/mariadb/mv_vulnerabilities.go index d74bac50..1267ad17 100644 --- a/internal/database/mariadb/mv_vulnerabilities.go +++ b/internal/database/mariadb/mv_vulnerabilities.go @@ -53,7 +53,7 @@ func (s *SqlDatabase) CountIssueRatings( filterParameters []any ) - filter = ensureIssueFilter(filter) + filter = EnsureFilter(filter) baseQuery := ` SELECT CIR.critical_count, CIR.high_count, CIR.medium_count, CIR.low_count, CIR.none_count FROM %s AS CIR diff --git a/internal/database/mariadb/patch.go b/internal/database/mariadb/patch.go index 7a5356b5..826b1252 100644 --- a/internal/database/mariadb/patch.go +++ b/internal/database/mariadb/patch.go @@ -13,7 +13,6 @@ import ( ) var patchObject = DbObject[*entity.Patch]{ - Properties: []*Property{}, FilterProperties: []*FilterProperty{ NewFilterProperty( "P.patch_id = ?", @@ -48,14 +47,6 @@ var patchObject = DbObject[*entity.Patch]{ }, } -func ensurePatchFilter(filter *entity.PatchFilter) *entity.PatchFilter { - if filter == nil { - filter = &entity.PatchFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildPatchStatement( ctx context.Context, baseQuery string, @@ -64,42 +55,20 @@ func (s *SqlDatabase) buildPatchStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensurePatchFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode Patch cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.PatchId, Direction: entity.OrderDirectionAsc}) - whereClause, hasFilter := patchObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := patchObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, whereClause, ord) - } - - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare Patch statement: %w", err) + statement := Statement{ + Db: s.db, + L: l, + Obj: &patchObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.PatchId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := patchObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetPatches( @@ -117,6 +86,7 @@ func (s *SqlDatabase) GetPatches( SELECT P.* FROM Patch P %s %s + %s GROUP BY P.patch_id ORDER BY %s LIMIT ? ` @@ -166,6 +136,7 @@ func (s *SqlDatabase) CountPatches(ctx context.Context, filter *entity.PatchFilt baseQuery := ` SELECT count(distinct P.patch_id) FROM Patch P %s + %s ORDER BY %s ` @@ -207,10 +178,11 @@ func (s *SqlDatabase) GetAllPatchCursors( baseQuery := ` SELECT P.* FROM Patch P + %s %s GROUP BY P.patch_id ORDER BY %s ` - filter = ensurePatchFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildPatchStatement(ctx, baseQuery, filter, false, order, l) if err != nil { diff --git a/internal/database/mariadb/remediation.go b/internal/database/mariadb/remediation.go index 0680e8a3..5983f49b 100644 --- a/internal/database/mariadb/remediation.go +++ b/internal/database/mariadb/remediation.go @@ -183,14 +183,6 @@ var remediationObject = DbObject[*entity.Remediation]{ }, } -func ensureRemediationFilter(filter *entity.RemediationFilter) *entity.RemediationFilter { - if filter == nil { - filter = &entity.RemediationFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildRemediationStatement( ctx context.Context, baseQuery string, @@ -199,42 +191,20 @@ func (s *SqlDatabase) buildRemediationStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureRemediationFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode Remediation cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.RemediationId, Direction: entity.OrderDirectionAsc}) - whereClause, hasFilter := remediationObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := remediationObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, whereClause, ord) - } - - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare Remediation statement: %w", err) + statement := Statement{ + Db: s.db, + L: l, + Obj: &remediationObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.RemediationId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := remediationObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetRemediations( @@ -252,6 +222,7 @@ func (s *SqlDatabase) GetRemediations( SELECT R.* FROM Remediation R %s %s + %s GROUP BY R.remediation_id ORDER BY %s LIMIT ? ` @@ -301,6 +272,7 @@ func (s *SqlDatabase) CountRemediations(ctx context.Context, filter *entity.Reme baseQuery := ` SELECT count(distinct R.remediation_id) FROM Remediation R %s + %s ORDER BY %s ` @@ -342,10 +314,11 @@ func (s *SqlDatabase) GetAllRemediationCursors( baseQuery := ` SELECT R.* FROM Remediation R + %s %s GROUP BY R.remediation_id ORDER BY %s ` - filter = ensureRemediationFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildRemediationStatement(ctx, baseQuery, filter, false, order, l) if err != nil { diff --git a/internal/database/mariadb/scanner_run.go b/internal/database/mariadb/scanner_run.go index 6120d9e3..76784764 100644 --- a/internal/database/mariadb/scanner_run.go +++ b/internal/database/mariadb/scanner_run.go @@ -111,7 +111,7 @@ func (s *SqlDatabase) ScannerRunByUUID(uuid string) (*entity.ScannerRun, error) } func (s *SqlDatabase) GetScannerRuns(filter *entity.ScannerRunFilter) ([]entity.ScannerRun, error) { - filter = ensureScannerRunFilter(filter) + filter = EnsureFilter(filter) baseQuery := ` SELECT * FROM ScannerRun @@ -197,14 +197,6 @@ func applyScannerRunFilter(baseQuery string, filter *entity.ScannerRunFilter) ([ return queryArgs, baseQuery } -func ensureScannerRunFilter(filter *entity.ScannerRunFilter) *entity.ScannerRunFilter { - if filter == nil { - filter = &entity.ScannerRunFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) GetScannerRunTags() ([]string, error) { query := `SELECT DISTINCT scannerrun_tag diff --git a/internal/database/mariadb/service.go b/internal/database/mariadb/service.go index fc4f6049..9c10e0c5 100644 --- a/internal/database/mariadb/service.go +++ b/internal/database/mariadb/service.go @@ -187,14 +187,6 @@ var serviceObject = DbObject[*entity.Service]{ }, } -func ensureServiceFilter(filter *entity.ServiceFilter) *entity.ServiceFilter { - if filter == nil { - filter = &entity.ServiceFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) getServiceColumns(filter *entity.ServiceFilter, order []entity.Order) string { columns := "S.*" if len(filter.IssueRepositoryId) > 0 { @@ -227,47 +219,20 @@ func (s *SqlDatabase) buildServiceStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - var query string - - filter = ensureServiceFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, err - } - - ord := NewOrder(order, entity.Order{By: entity.ServiceId, Direction: entity.OrderDirectionAsc}) - joins := serviceObject.GetJoins(filter, ord) - whereClause, hasFilter := serviceObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := serviceObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, true) - - // construct final query - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) + statement := Statement{ + Db: s.db, + L: l, + Obj: &serviceObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.ServiceId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: true, } - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) - } - - // adding parameters - filterParameters := serviceObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) CountServices(ctx context.Context, filter *entity.ServiceFilter) (int64, error) { @@ -319,7 +284,7 @@ func (s *SqlDatabase) GetServices( GROUP BY S.service_id %s ORDER BY %s LIMIT ? ` - filter = ensureServiceFilter(filter) + filter = EnsureFilter(filter) columns := s.getServiceColumns(filter, order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s", "%s") @@ -424,7 +389,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( FROM ComponentInstanceCounts CIC JOIN IssueMatchCounts IMC ON CIC.service_id = IMC.service_id; ` - filter = ensureServiceFilter(filter) + filter = EnsureFilter(filter) ord := NewOrder(order, entity.Order{By: entity.ServiceId, Direction: entity.OrderDirectionAsc}) joins := serviceObject.GetJoins(filter, ord) columns := s.getServiceColumns(filter, ord.Sequence()) @@ -459,8 +424,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( "error": err, "query": query, "stmt": stmt, - }, - ).Error(msg) + }).Error(msg) return nil, fmt.Errorf("%s", msg) } @@ -470,8 +434,7 @@ func (s *SqlDatabase) GetServicesWithAggregations( // parameters for component instance query filterParameters = append( filterParameters, - serviceObject.GetFilterParameters(filter, true, cursorFields)..., - ) + serviceObject.GetFilterParameters(filter, true, cursorFields)...) defer func() { if err := stmt.Close(); err != nil { @@ -526,7 +489,7 @@ func (s *SqlDatabase) GetAllServiceCursors( %s GROUP BY S.service_id ORDER BY %s ` - filter = ensureServiceFilter(filter) + filter = EnsureFilter(filter) columns := s.getServiceColumns(filter, order) baseQuery = fmt.Sprintf(baseQuery, columns, "%s", "%s", "%s") @@ -734,7 +697,7 @@ func (s *SqlDatabase) getServiceAttr( baseQuery = fmt.Sprintf(baseQuery, attrName, "%s", "%s", "%s") // Ensure the filter is initialized - filter = ensureServiceFilter(filter) + filter = EnsureFilter(filter) order := []entity.Order{ {By: entity.ServiceCcrn, Direction: entity.OrderDirectionAsc}, } diff --git a/internal/database/mariadb/service_issue_variant.go b/internal/database/mariadb/service_issue_variant.go index 534e672a..d8683528 100644 --- a/internal/database/mariadb/service_issue_variant.go +++ b/internal/database/mariadb/service_issue_variant.go @@ -5,7 +5,6 @@ package mariadb import ( "context" - "fmt" "github.com/cloudoperators/heureka/internal/entity" "github.com/sirupsen/logrus" @@ -35,16 +34,6 @@ var serviceIssueVariantObject = DbObject[*entity.ServiceIssueVariant]{ }, } -func ensureServiceIssueVariantFilter( - filter *entity.ServiceIssueVariantFilter, -) *entity.ServiceIssueVariantFilter { - if filter == nil { - filter = &entity.ServiceIssueVariantFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildServiceIssueVariantStatement( ctx context.Context, baseQuery string, @@ -53,47 +42,20 @@ func (s *SqlDatabase) buildServiceIssueVariantStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureServiceIssueVariantFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode ServiceIssueVariant cursor: %w", err) - } - - ord := NewOrder(order, entity.Order{By: entity.ServiceIssueVariantID, Direction: entity.OrderDirectionAsc}) - whereClause, hasFilter := serviceIssueVariantObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := serviceIssueVariantObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) + statement := Statement{ + Db: s.db, + L: l, + Obj: &serviceIssueVariantObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.ServiceIssueVariantID, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - filterParameters := serviceIssueVariantObject.GetFilterParameters( - filter, - withCursor, - cursorFields, - ) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } // TODO: adjust this function to fit dbObject @@ -121,6 +83,7 @@ func (s *SqlDatabase) GetServiceIssueVariants( # Join to from repo and issue to IssueVariant INNER JOIN IssueVariant IV on I.issue_id = IV.issuevariant_issue_id and IV.issuevariant_repository_id = IR.issuerepository_id %s + %s %s ORDER BY %s LIMIT ? ` diff --git a/internal/database/mariadb/support_group.go b/internal/database/mariadb/support_group.go index 13d7dacf..8c62fa07 100644 --- a/internal/database/mariadb/support_group.go +++ b/internal/database/mariadb/support_group.go @@ -6,7 +6,6 @@ package mariadb import ( "context" "errors" - "fmt" "github.com/cloudoperators/heureka/internal/database" "github.com/cloudoperators/heureka/internal/entity" @@ -110,14 +109,6 @@ var supportGroupObject = DbObject[*entity.SupportGroup]{ }, } -func ensureSupportGroupFilter(filter *entity.SupportGroupFilter) *entity.SupportGroupFilter { - if filter == nil { - filter = &entity.SupportGroupFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildSupportGroupStatement( ctx context.Context, baseQuery string, @@ -126,45 +117,20 @@ func (s *SqlDatabase) buildSupportGroupStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - var query string - - filter = ensureSupportGroupFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, err + statement := Statement{ + Db: s.db, + L: l, + Obj: &supportGroupObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.SupportGroupId, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - ord := NewOrder(order, entity.Order{By: entity.SupportGroupId, Direction: entity.OrderDirectionAsc}) - joins := supportGroupObject.GetJoins(filter, ord) - whereClause, hasFilter := supportGroupObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := supportGroupObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) - } - - // construct prepared statement and if where clause does exist add parameters - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("%s", msg) - } - - filterParameters := supportGroupObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllSupportGroupCursors( @@ -183,7 +149,7 @@ func (s *SqlDatabase) GetAllSupportGroupCursors( %s GROUP BY SG.supportgroup_id ORDER BY %s ` - filter = ensureSupportGroupFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildSupportGroupStatement(ctx, baseQuery, filter, false, order, l) if err != nil { @@ -451,7 +417,7 @@ func (s *SqlDatabase) GetSupportGroupCcrns(ctx context.Context, filter *entity.S } // Ensure the filter is initialized - filter = ensureSupportGroupFilter(filter) + filter = EnsureFilter(filter) // Builds full statement with possible joins and filters stmt, filterParameters, err := s.buildSupportGroupStatement(ctx, baseQuery, filter, false, order, l) diff --git a/internal/database/mariadb/user.go b/internal/database/mariadb/user.go index ee59498a..3329915e 100644 --- a/internal/database/mariadb/user.go +++ b/internal/database/mariadb/user.go @@ -99,14 +99,6 @@ var userObject = DbObject[*entity.User]{ }, } -func ensureUserFilter(filter *entity.UserFilter) *entity.UserFilter { - if filter == nil { - return &entity.UserFilter{} - } - - return EnsurePagination(filter) -} - func (s *SqlDatabase) buildUserStatement( ctx context.Context, baseQuery string, @@ -115,43 +107,20 @@ func (s *SqlDatabase) buildUserStatement( order []entity.Order, l *logrus.Entry, ) (Stmt, []any, error) { - filter = ensureUserFilter(filter) - l.WithFields(logrus.Fields{"filter": filter}) - - cursorFields, err := DecodeCursor(filter.After) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode User cursor: %w", err) + statement := Statement{ + Db: s.db, + L: l, + Obj: &userObject, + BaseQuery: baseQuery, + Order: NewOrder(order, entity.Order{By: entity.UserID, Direction: entity.OrderDirectionAsc}), + WithCursor: withCursor, + CheckCursorInWhere: true, + CheckCursor: true, + CheckFilter: true, + Aggregated: false, } - ord := NewOrder(order, entity.Order{By: entity.UserID, Direction: entity.OrderDirectionAsc}) - joins := userObject.GetJoins(filter, ord) - whereClause, hasFilter := userObject.GetFilterWhereClause(filter, withCursor) - cursorQuery := userObject.GetCursorQuery(&hasFilter, cursorFields, &withCursor, false) - - var query string - if withCursor { - query = fmt.Sprintf(baseQuery, joins, whereClause, cursorQuery, ord) - } else { - query = fmt.Sprintf(baseQuery, joins, whereClause, ord) - } - - stmt, err := s.db.PreparexContext(ctx, query) - if err != nil { - msg := ERROR_MSG_PREPARED_STMT - l.WithFields( - logrus.Fields{ - "error": err, - "query": query, - "stmt": stmt, - }, - ).Error(msg) - - return nil, nil, fmt.Errorf("failed to prepare User statement: %w", err) - } - - filterParameters := userObject.GetFilterParameters(filter, withCursor, cursorFields) - - return stmt, filterParameters, nil + return BuildStatement(ctx, statement, filter) } func (s *SqlDatabase) GetAllUserIds(ctx context.Context, filter *entity.UserFilter) ([]int64, error) { @@ -203,7 +172,7 @@ func (s *SqlDatabase) GetAllUserCursors( %s GROUP BY U.user_id ORDER BY %s ` - filter = ensureUserFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, order, l) if err != nil { @@ -250,7 +219,7 @@ func (s *SqlDatabase) GetUsers(ctx context.Context, filter *entity.UserFilter) ( %s GROUP BY U.user_id ORDER BY %s LIMIT ? ` - filter = ensureUserFilter(filter) + filter = EnsureFilter(filter) stmt, filterParameters, err := s.buildUserStatement( ctx, @@ -350,7 +319,7 @@ func (s *SqlDatabase) GetUserNames(ctx context.Context, filter *entity.UserFilte ` // Ensure the filter is initialized - filter = ensureUserFilter(filter) + filter = EnsureFilter(filter) // Builds full statement with possible joins and filters stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, []entity.Order{ @@ -417,7 +386,7 @@ func (s *SqlDatabase) GetUniqueUserIDs(ctx context.Context, filter *entity.UserF ` // Ensure the filter is initialized - filter = ensureUserFilter(filter) + filter = EnsureFilter(filter) // Builds full statement with possible joins and filters stmt, filterParameters, err := s.buildUserStatement(ctx, baseQuery, filter, false, []entity.Order{