diff --git a/internal/api/graphql/graph/resolver/mutation.go b/internal/api/graphql/graph/resolver/mutation.go index aca755e5..6ee76628 100644 --- a/internal/api/graphql/graph/resolver/mutation.go +++ b/internal/api/graphql/graph/resolver/mutation.go @@ -1032,7 +1032,7 @@ func (r *mutationResolver) CreateRemediation(ctx context.Context, input model.Re &entity.ServiceFilter{CCRN: []*string{input.Service}}, nil, ) - if err != nil || len(serviceResult.Elements) == 0 || len(serviceResult.Elements) > 1 { + if err != nil || len(serviceResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "CreateRemediationMutationResolver", "Internal Error - when creating remediation - service id not found", @@ -1041,35 +1041,62 @@ func (r *mutationResolver) CreateRemediation(ctx context.Context, input model.Re remediation.ServiceId = serviceResult.Elements[0].Id - // fetch component id for given component name - componentResult, err := r.App.ListComponents( + // fetch issue id for given issue name + issueResult, err := r.App.ListIssues( ctx, - &entity.ComponentFilter{Repository: []*string{input.Image}}, + &entity.IssueFilter{ + PrimaryName: []*string{input.Vulnerability}, + }, nil, ) - if err != nil || len(componentResult.Elements) == 0 || len(componentResult.Elements) > 1 { + if err != nil || len(issueResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "CreateRemediationMutationResolver", - "Internal Error - when creating remediation - component id not found", + "Internal Error - when creating remediation - issue id not found", ) } - remediation.ComponentId = componentResult.Elements[0].Id + remediation.IssueId = issueResult.Elements[0].Issue.Id - // fetch issue id for given issue name - issueResult, err := r.App.ListIssues( + // fetch component version id + componentVersionResult, err := r.App.ListComponentVersions(ctx, &entity.ComponentVersionFilter{ + IssueId: []*int64{ + &issueResult.Elements[0].Issue.Id, + }, + ServiceCCRN: []*string{ + input.Service, + }, + Repository: []*string{ + input.Image, + }, + }, &entity.ListOptions{}) + if err != nil || len(componentVersionResult.Elements) != 1 { + return nil, baseResolver.NewResolverError( + "CreateRemediationMutationResolver", + "Internal Error - when creating remediation - component version not found", + ) + } + + componentVersionID := componentVersionResult.Elements[0].Id + + // fetch component id for given component name + componentResult, err := r.App.ListComponents( ctx, - &entity.IssueFilter{PrimaryName: []*string{input.Vulnerability}}, + &entity.ComponentFilter{ + Repository: []*string{input.Image}, + ComponentVersionId: []*int64{&componentVersionID}, + ServiceCCRN: []*string{input.Service}, + }, nil, ) - if err != nil || len(issueResult.Elements) == 0 || len(issueResult.Elements) > 1 { + if err != nil || len(componentResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "CreateRemediationMutationResolver", - "Internal Error - when creating remediation - issue id not found", + "Internal Error - when creating remediation - component id not found", ) } - remediation.IssueId = issueResult.Elements[0].Issue.Id + remediation.ComponentId = componentResult.Elements[0].Id if input.RemediatedBy != nil { userUniqueUserIDs, err := r.App.ListUniqueUserIDs(ctx, &entity.UserFilter{ @@ -1132,7 +1159,7 @@ func (r *mutationResolver) UpdateRemediation(ctx context.Context, id string, inp &entity.ServiceFilter{CCRN: []*string{input.Service}}, nil, ) - if err != nil || len(serviceResult.Elements) == 0 || len(serviceResult.Elements) > 1 { + if err != nil || len(serviceResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "UpdateRemediationMutationResolver", "Internal Error - when updating remediation - service id not found", @@ -1150,7 +1177,7 @@ func (r *mutationResolver) UpdateRemediation(ctx context.Context, id string, inp &entity.ComponentFilter{Repository: []*string{input.Image}}, nil, ) - if err != nil || len(componentResult.Elements) == 0 || len(componentResult.Elements) > 1 { + if err != nil || len(componentResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "UpdateRemediationMutationResolver", "Internal Error - when updating remediation - component id not found", @@ -1167,7 +1194,7 @@ func (r *mutationResolver) UpdateRemediation(ctx context.Context, id string, inp &entity.IssueFilter{PrimaryName: []*string{input.Vulnerability}}, nil, ) - if err != nil || len(issueResult.Elements) == 0 || len(issueResult.Elements) > 1 { + if err != nil || len(issueResult.Elements) != 1 { return nil, baseResolver.NewResolverError( "UpdateRemediationMutationResolver", "Internal Error - when updating remediation - issue id not found", diff --git a/internal/database/mariadb/test/fixture.go b/internal/database/mariadb/test/fixture.go index ce980554..df0d08bf 100644 --- a/internal/database/mariadb/test/fixture.go +++ b/internal/database/mariadb/test/fixture.go @@ -301,6 +301,41 @@ func (s *SeedCollection) FindMatchingComponentVersionAndIssueVariant() ( return mariadb.ComponentVersionRow{}, mariadb.IssueVariantRow{}, false } +func (s *SeedCollection) FindLinkedRemediationData() (mariadb.BaseServiceRow, mariadb.ComponentRow, mariadb.IssueRow, bool) { + for _, cvi := range s.ComponentVersionIssueRows { + ci, ok := lo.Find(s.ComponentInstanceRows, func(ci mariadb.ComponentInstanceRow) bool { + return ci.ComponentVersionId.Int64 == cvi.ComponentVersionId.Int64 + }) + if !ok { + continue + } + + service, ok := s.GetServiceById(ci.ServiceId.Int64) + if !ok { + continue + } + + cv, ok := s.GetComponentVersionById(cvi.ComponentVersionId.Int64) + if !ok { + continue + } + + component, ok := s.GetComponentById(cv.ComponentId.Int64) + if !ok { + continue + } + + issue, ok := s.GetIssueById(cvi.IssueId.Int64) + if !ok { + continue + } + + return service, component, issue, true + } + + return mariadb.BaseServiceRow{}, mariadb.ComponentRow{}, mariadb.IssueRow{}, false +} + func (s *SeedCollection) FindMatchingSupportGroupAndIssueMatch() (mariadb.SupportGroupRow, mariadb.IssueMatchRow, bool) { for _, imRow := range s.IssueMatchRows { ciRow, ok := s.GetComponentInstanceById(imRow.ComponentInstanceId.Int64) @@ -893,6 +928,7 @@ func (s *DatabaseSeeder) SeedComponentVersions( component := components[randomIndex] componentVersion.ComponentId = component.Id componentVersion.EndOfLife = sql.NullBool{Bool: []bool{true, false}[i%2], Valid: true} + componentVersion.Repository = component.Repository componentVersionId, err := s.InsertFakeComponentVersion(componentVersion) if err != nil { @@ -914,10 +950,12 @@ func (s *DatabaseSeeder) SeedComponentInstances( ) []mariadb.ComponentInstanceRow { var componentInstances []mariadb.ComponentInstanceRow - for range num { + limit := min(len(services), min(len(componentVersions), num)) + + for i := range limit { componentInstance := NewFakeComponentInstance() - componentInstance.ComponentVersionId = PickOne(componentVersions).Id - componentInstance.ServiceId = PickOne(services).Id + componentInstance.ComponentVersionId = componentVersions[i].Id + componentInstance.ServiceId = services[i].Id componentInstanceId, err := s.InsertFakeComponentInstance(componentInstance) if err != nil { @@ -982,8 +1020,8 @@ func (s *DatabaseSeeder) SeedComponentVersionIssues( cviList := make([]mariadb.ComponentVersionIssueRow, num) for i := range num { cvi := NewFakeComponentVersionIssue() - cvi.IssueId = PickOne(issues).Id - cvi.ComponentVersionId = PickOne(componentVersions).Id + cvi.IssueId = issues[i].Id + cvi.ComponentVersionId = componentVersions[i].Id _, err := s.InsertFakeComponentVersionIssue(cvi) if err != nil { diff --git a/internal/e2e/remediation_query_test.go b/internal/e2e/remediation_query_test.go index 8c47c1dc..20f49411 100644 --- a/internal/e2e/remediation_query_test.go +++ b/internal/e2e/remediation_query_test.go @@ -169,9 +169,11 @@ var _ = Describe("Creating Remediation via API", Label("e2e", "Remediations"), f BeforeEach(func() { seedCollection = seeder.SeedDbWithNFakeData(10) remediation = testentity.NewFakeRemediationEntity() - remediation.Service = seedCollection.ServiceRows[0].CCRN.String - remediation.Component = seedCollection.ComponentRows[0].Repository.String - remediation.Issue = seedCollection.IssueRows[0].PrimaryName.String + service, component, issue, ok := seedCollection.FindLinkedRemediationData() + Expect(ok).To(BeTrue(), "linked service/component/issue data must exist in seed") + remediation.Service = service.CCRN.String + remediation.Component = component.Repository.String + remediation.Issue = issue.PrimaryName.String }) Context("and a mutation query is performed", Label("create.graphql"), func() {