diff --git a/internal/controlplane/handlers_ruletype.go b/internal/controlplane/handlers_ruletype.go index f3ab082425..21c879682a 100644 --- a/internal/controlplane/handlers_ruletype.go +++ b/internal/controlplane/handlers_ruletype.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + datasourcesvc "github.com/mindersec/minder/internal/datasources/service" "github.com/mindersec/minder/internal/db" "github.com/mindersec/minder/internal/engine/engcontext" "github.com/mindersec/minder/internal/flags" @@ -175,12 +176,13 @@ func (s *Server) CreateRuleType( return nil, util.UserVisibleError(codes.InvalidArgument, "%s", err) } - ds := crt.GetRuleType().GetDef().GetEval().GetDataSources() - if len(ds) > 0 && !flags.Bool(ctx, s.featureFlags, flags.DataSources) { - return nil, status.Errorf(codes.Unavailable, "DataSources feature is disabled") - } - newRuleType, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minderv1.RuleType, error) { + ruleDS := crt.GetRuleType().GetDef().GetEval().GetDataSources() + if err := s.validateDataSources(ctx, projectID, ruleDS, qtx); err != nil { + // We expect the error to be a user visible error + return nil, err + } + return s.ruleTypes.CreateRuleType(ctx, projectID, uuid.Nil, crt.GetRuleType(), qtx) }) if err != nil { @@ -220,12 +222,13 @@ func (s *Server) UpdateRuleType( return nil, util.UserVisibleError(codes.InvalidArgument, "%s", err) } - ds := urt.GetRuleType().GetDef().GetEval().GetDataSources() - if len(ds) > 0 && !flags.Bool(ctx, s.featureFlags, flags.DataSources) { - return nil, status.Errorf(codes.Unavailable, "DataSources feature is disabled") - } - updatedRuleType, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minderv1.RuleType, error) { + ruleDS := urt.GetRuleType().GetDef().GetEval().GetDataSources() + if err := s.validateDataSources(ctx, projectID, ruleDS, qtx); err != nil { + // We expect the error to be a user visible error + return nil, err + } + return s.ruleTypes.UpdateRuleType(ctx, projectID, uuid.Nil, urt.GetRuleType(), qtx) }) if err != nil { @@ -373,3 +376,34 @@ func validateMarkdown(md string) error { return nil } + +func (s *Server) validateDataSources( + ctx context.Context, + projectID uuid.UUID, + ruleDS []*minderv1.DataSourceReference, + qtx db.ExtendQuerier, +) error { + // Short circuiting to avoid accessing the database. + if len(ruleDS) == 0 { + return nil + } + + if len(ruleDS) > 0 && !flags.Bool(ctx, s.featureFlags, flags.DataSources) { + return status.Errorf(codes.Unavailable, "DataSources feature is disabled") + } + + opts := datasourcesvc.ReadBuilder().WithTransaction(qtx) + for _, requested := range ruleDS { + _, err := s.dataSourcesService.GetByName( + ctx, + requested.Name, + projectID, + opts, + ) + if err != nil { + return util.UserVisibleError(codes.Internal, "failed retrieving data sources") + } + } + + return nil +} diff --git a/internal/controlplane/handlers_ruletype_test.go b/internal/controlplane/handlers_ruletype_test.go index f36facbab5..ee00a5f60b 100644 --- a/internal/controlplane/handlers_ruletype_test.go +++ b/internal/controlplane/handlers_ruletype_test.go @@ -14,7 +14,10 @@ import ( mockdb "github.com/mindersec/minder/database/mock" df "github.com/mindersec/minder/database/mock/fixtures" + dsf "github.com/mindersec/minder/internal/datasources/service/mock/fixtures" db "github.com/mindersec/minder/internal/db" + "github.com/mindersec/minder/internal/engine/engcontext" + "github.com/mindersec/minder/internal/flags" minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" sf "github.com/mindersec/minder/pkg/ruletypes/mock/fixtures" ) @@ -22,18 +25,21 @@ import ( func TestCreateRuleType(t *testing.T) { t.Parallel() + projectID := uuid.New() tests := []struct { - name string - mockStoreFunc df.MockStoreBuilder - ruleTypeServiceFunc sf.RuleTypeSvcMockBuilder - request *minderv1.CreateRuleTypeRequest - error bool + name string + mockStoreFunc df.MockStoreBuilder + ruleTypeServiceFunc sf.RuleTypeSvcMockBuilder + dataSourcesServiceFunc dsf.DataSourcesSvcMockBuilder + features map[string]any + request *minderv1.CreateRuleTypeRequest + error bool }{ { name: "happy path", mockStoreFunc: df.NewMockStore( df.WithTransaction(), - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock( sf.WithSuccessfulCreateRuleType, @@ -45,7 +51,7 @@ func TestCreateRuleType(t *testing.T) { { name: "guidance sanitize error", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.CreateRuleTypeRequest{ @@ -58,7 +64,7 @@ func TestCreateRuleType(t *testing.T) { { name: "guidance not utf-8", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.CreateRuleTypeRequest{ @@ -71,7 +77,7 @@ func TestCreateRuleType(t *testing.T) { { name: "guidance too long", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.CreateRuleTypeRequest{ @@ -81,6 +87,122 @@ func TestCreateRuleType(t *testing.T) { }, error: true, }, + + // data sources validation + { + name: "available data sources", + mockStoreFunc: df.NewMockStore( + df.WithTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock( + sf.WithSuccessfulCreateRuleType, + ), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithSuccessfulGetByName( + projectID, + &minderv1.DataSource{ + Name: "foo", + }, + ), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.CreateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + }, + { + name: "no data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithNotFoundGetByName(projectID), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.CreateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, + { + name: "failed data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithFailedGetByName(), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.CreateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, + { + name: "disabled data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock(), + features: map[string]any{}, + request: &minderv1.CreateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, } for _, tt := range tests { @@ -103,10 +225,27 @@ func TestCreateRuleType(t *testing.T) { mockSvc = tt.ruleTypeServiceFunc(ctrl) } + var mockDsSvc dsf.DataSourcesSvcMock + if tt.dataSourcesServiceFunc != nil { + mockDsSvc = tt.dataSourcesServiceFunc(ctrl) + } + + featureClient := &flags.FakeClient{} + if tt.features != nil { + featureClient.Data = tt.features + } + srv := newDefaultServer(t, mockStore, nil, nil, nil) srv.ruleTypes = mockSvc + srv.dataSourcesService = mockDsSvc + srv.featureFlags = featureClient - resp, err := srv.CreateRuleType(context.Background(), tt.request) + ctx := context.Background() + ctx = engcontext.WithEntityContext(ctx, &engcontext.EntityContext{ + Project: engcontext.Project{ID: projectID}, + Provider: engcontext.Provider{Name: "testing"}, + }) + resp, err := srv.CreateRuleType(ctx, tt.request) if tt.error { require.Error(t, err) require.Nil(t, resp) @@ -122,18 +261,21 @@ func TestCreateRuleType(t *testing.T) { func TestUpdateRuleType(t *testing.T) { t.Parallel() + projectID := uuid.New() tests := []struct { - name string - mockStoreFunc df.MockStoreBuilder - ruleTypeServiceFunc sf.RuleTypeSvcMockBuilder - request *minderv1.UpdateRuleTypeRequest - error bool + name string + mockStoreFunc df.MockStoreBuilder + ruleTypeServiceFunc sf.RuleTypeSvcMockBuilder + dataSourcesServiceFunc dsf.DataSourcesSvcMockBuilder + features map[string]any + request *minderv1.UpdateRuleTypeRequest + error bool }{ { name: "happy path", mockStoreFunc: df.NewMockStore( df.WithTransaction(), - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock( sf.WithSuccessfulUpdateRuleType, @@ -145,7 +287,7 @@ func TestUpdateRuleType(t *testing.T) { { name: "guidance sanitize error", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.UpdateRuleTypeRequest{ @@ -158,7 +300,7 @@ func TestUpdateRuleType(t *testing.T) { { name: "guidance not utf-8", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.UpdateRuleTypeRequest{ @@ -171,7 +313,7 @@ func TestUpdateRuleType(t *testing.T) { { name: "guidance too long", mockStoreFunc: df.NewMockStore( - WithSuccessfulGetProjectByID(uuid.Nil), + WithSuccessfulGetProjectByID(projectID), ), ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), request: &minderv1.UpdateRuleTypeRequest{ @@ -181,6 +323,122 @@ func TestUpdateRuleType(t *testing.T) { }, error: true, }, + + // data sources validation + { + name: "available data sources", + mockStoreFunc: df.NewMockStore( + df.WithTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock( + sf.WithSuccessfulUpdateRuleType, + ), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithSuccessfulGetByName( + projectID, + &minderv1.DataSource{ + Name: "foo", + }, + ), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.UpdateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + }, + { + name: "no data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithNotFoundGetByName(projectID), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.UpdateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, + { + name: "failed data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock( + dsf.WithFailedGetByName(), + ), + features: map[string]any{ + "data_sources": true, + }, + request: &minderv1.UpdateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, + { + name: "disabled data sources", + mockStoreFunc: df.NewMockStore( + df.WithRollbackTransaction(), + WithSuccessfulGetProjectByID(projectID), + ), + ruleTypeServiceFunc: sf.NewRuleTypeServiceMock(), + dataSourcesServiceFunc: dsf.NewDataSourcesServiceMock(), + features: map[string]any{}, + request: &minderv1.UpdateRuleTypeRequest{ + RuleType: &minderv1.RuleType{ + Def: &minderv1.RuleType_Definition{ + Eval: &minderv1.RuleType_Definition_Eval{ + DataSources: []*minderv1.DataSourceReference{ + { + Name: "foo", + }, + }, + }, + }, + }, + }, + error: true, + }, } for _, tt := range tests { @@ -203,10 +461,27 @@ func TestUpdateRuleType(t *testing.T) { mockSvc = tt.ruleTypeServiceFunc(ctrl) } + var mockDsSvc dsf.DataSourcesSvcMock + if tt.dataSourcesServiceFunc != nil { + mockDsSvc = tt.dataSourcesServiceFunc(ctrl) + } + + featureClient := &flags.FakeClient{} + if tt.features != nil { + featureClient.Data = tt.features + } + srv := newDefaultServer(t, mockStore, nil, nil, nil) srv.ruleTypes = mockSvc + srv.dataSourcesService = mockDsSvc + srv.featureFlags = featureClient - resp, err := srv.UpdateRuleType(context.Background(), tt.request) + ctx := context.Background() + ctx = engcontext.WithEntityContext(ctx, &engcontext.EntityContext{ + Project: engcontext.Project{ID: projectID}, + Provider: engcontext.Provider{Name: "testing"}, + }) + resp, err := srv.UpdateRuleType(ctx, tt.request) if tt.error { require.Error(t, err) require.Nil(t, resp) diff --git a/internal/datasources/service/mock/fixtures/service.go b/internal/datasources/service/mock/fixtures/service.go new file mode 100644 index 0000000000..8bb3f00174 --- /dev/null +++ b/internal/datasources/service/mock/fixtures/service.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package fixtures contains code for creating DataSourceService +// fixtures and is used in various parts of the code. For testing use +// only. +// +//nolint:all +package fixtures + +import ( + "errors" + + "github.com/google/uuid" + mockdssvc "github.com/mindersec/minder/internal/datasources/service/mock" + minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" + "go.uber.org/mock/gomock" +) + +type ( + DataSourcesSvcMock = *mockdssvc.MockDataSourcesService + DataSourcesSvcMockBuilder = func(*gomock.Controller) DataSourcesSvcMock +) + +func NewDataSourcesServiceMock(opts ...func(mock DataSourcesSvcMock)) DataSourcesSvcMockBuilder { + return func(ctrl *gomock.Controller) DataSourcesSvcMock { + mock := mockdssvc.NewMockDataSourcesService(ctrl) + for _, opt := range opts { + opt(mock) + } + return mock + } +} + +var ( + errDefault = errors.New("error during data sources service operation") +) + +func WithSuccessfulListDataSources(datasources ...*minderv1.DataSource) func(DataSourcesSvcMock) { + return func(mock DataSourcesSvcMock) { + mock.EXPECT(). + List(gomock.Any(), gomock.Any(), gomock.Any()). + Return(datasources, nil) + } +} + +func WithFailedListDataSources() func(DataSourcesSvcMock) { + return func(mock DataSourcesSvcMock) { + mock.EXPECT(). + List(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errDefault) + } +} + +func WithSuccessfulGetByName(projectID uuid.UUID, datasource *minderv1.DataSource) func(DataSourcesSvcMock) { + return func(mock DataSourcesSvcMock) { + mock.EXPECT(). + GetByName(gomock.Any(), datasource.Name, projectID, gomock.Any()). + Return(datasource, nil) + } +} + +func WithNotFoundGetByName(projectID uuid.UUID) func(DataSourcesSvcMock) { + return func(mock DataSourcesSvcMock) { + mock.EXPECT(). + GetByName(gomock.Any(), gomock.Any(), projectID, gomock.Any()). + Return(&minderv1.DataSource{}, errDefault) + } +} + +func WithFailedGetByName() func(DataSourcesSvcMock) { + return func(mock DataSourcesSvcMock) { + mock.EXPECT(). + GetByName(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errDefault) + } +}