diff --git a/Makefile b/Makefile index 3effa9c..ca40ad3 100644 --- a/Makefile +++ b/Makefile @@ -80,6 +80,7 @@ dev: postgres root proto-gen: ./scripts/proto-gen.sh "api-specs/v1/proto/agents" ./scripts/proto-gen.sh "api-specs/v1/proto/admin" + ./scripts/proto-gen.sh "api-specs/v1/proto/keys" ./scripts/proto-gen.sh "api-specs/v1/proto" $(MAKE) go-format diff --git a/api-specs/v1/proto/keys/keys.proto b/api-specs/v1/proto/keys/keys.proto new file mode 100644 index 0000000..69f6882 --- /dev/null +++ b/api-specs/v1/proto/keys/keys.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + + +package krypton.v1.keys; + +option go_package = "github.com/openkcm/krypton/pkg/api/v1/proto/keys"; + +service Service { + rpc AnnounceKey(AnnounceKeyRequest) returns (AnnounceKeyResponse); + rpc GetKey(GetKeyRequest) returns (GetKeyResponse); +} + +message Key { + string id = 1; + string name = 2; + string tenant_id = 3; + string kind = 4; + string parent_id = 5; + string managed_by = 6; + map labels = 7; + string state = 8; + int64 created_at = 9; + int64 updated_at = 10; +} + +message AnnounceKeyRequest { + string tenant_id = 1; + string kind = 2; + string name = 3; + string parent_id = 4; + string target_name = 5; + map labels = 6; +} + +message AnnounceKeyResponse { + Key key = 1; +} + +message GetKeyRequest { + string id = 1; + string tenant_id = 2; +} + +message GetKeyResponse { + Key key = 1; +} diff --git a/cmd/root/main.go b/cmd/root/main.go index e655104..cf12355 100644 --- a/cmd/root/main.go +++ b/cmd/root/main.go @@ -22,6 +22,7 @@ import ( "github.com/openkcm/krypton/internal/worker" "github.com/openkcm/krypton/pkg/api/v1/proto/admin" "github.com/openkcm/krypton/pkg/api/v1/proto/agents" + "github.com/openkcm/krypton/pkg/api/v1/proto/keys" "github.com/openkcm/krypton/pkg/store" storesql "github.com/openkcm/krypton/pkg/store/sql" ) @@ -45,16 +46,17 @@ func main() { handleErr(err, "failed to connect to database") defer db.Close() + // run migrations + err = storesql.Migrate(context.Background(), db) + handleErr(err, "failed to run migrations") + // load root configuration cfg := loadConfig() - // tenant store initialization - tenantStore, err := storesql.NewTenantStore(context.Background(), db) - handleErr(err, "failed to initialize store") - - // agent store initialization - agentStore, err := storesql.NewAgentStore(context.Background(), db) - handleErr(err, "failed to initialize store") + // store initialization + tenantStore := storesql.NewTenantStore(db) + agentStore := storesql.NewAgentStore(db) + keyStore := storesql.NewKeyStore(db) // gRPC server setup for admin API grpcServer := grpc.NewServer() @@ -63,6 +65,9 @@ func main() { // gRPC server setup for agent API agents.RegisterServiceServer(grpcServer, agents.NewAgentService(agentStore, *cfg)) + // gRPC server setup for keys API + keys.RegisterServiceServer(grpcServer, keys.NewService(keyStore)) + lis, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", ":"+srvPort) handleErr(err, "failed to listen on gRPC port") diff --git a/integration/registration_test.go b/integration/registration_test.go index 5c6cdf7..26e21e1 100644 --- a/integration/registration_test.go +++ b/integration/registration_test.go @@ -30,8 +30,8 @@ func TestRegistration(t *testing.T) { db, dbConnStr := createDatabase(t) // Create agent store - agentStore, err := sql.NewAgentStore(ctx, db) - require.NoError(t, err, "failed to create agent store") + require.NoError(t, sql.Migrate(ctx, db)) + agentStore := sql.NewAgentStore(db) // Build binaries for root server and agent rootBinary := buildBinary(t, "root", "../cmd/root") @@ -45,7 +45,7 @@ func TestRegistration(t *testing.T) { "DATABASE_URL=" + dbConnStr, "SERVER_PORT=" + rootPort, }) - err = rootCmd.Start() + err := rootCmd.Start() require.NoError(t, err, "failed to start root server process") // Wait for root server to accept connections diff --git a/integration/setup_test.go b/integration/setup_test.go index 79a8486..95942a9 100644 --- a/integration/setup_test.go +++ b/integration/setup_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go/modules/postgres" "google.golang.org/grpc" @@ -126,13 +127,8 @@ func newTestStore(t *testing.T) store.Tenant { ctx := t.Context() testDB, _ := createDatabase(t) - s, err := storesql.NewTenantStore(ctx, testDB) - if err != nil { - testDB.Close() - assert.FailNowf(t, "failed to create test store", "error: %v", err) - } - - return s + require.NoError(t, storesql.Migrate(ctx, testDB)) + return storesql.NewTenantStore(testDB) } // createDatabase creates a new PostgreSQL database for testing and returns a connection to it. diff --git a/internal/reconciler/export_test.go b/internal/reconciler/export_test.go new file mode 100644 index 0000000..85fb637 --- /dev/null +++ b/internal/reconciler/export_test.go @@ -0,0 +1,43 @@ +package reconciler + +import ( + "context" + + "github.com/openkcm/orbital" + + "github.com/openkcm/krypton/internal/config" +) + +const DefaultMaxPendingReconciles = defaultMaxPendingReconciles +const NoJobHandlerRegisteredMessage = noJobHandlerRegisteredMessage + +var BuildTargets = buildTargets +var JobHandlerNotFoundError = jobHandlerNotFoundError + +func (m *Manager) OrbitalManager() *orbital.Manager { + return m.orbitalManager +} + +func (m *Manager) ConfirmJob(ctx context.Context, job orbital.Job) (orbital.JobConfirmerResult, error) { + return m.confirmJob(ctx, job) +} + +func (m *Manager) ResolveTasks(ctx context.Context, job orbital.Job, cursor orbital.TaskResolverCursor) (orbital.TaskResolverResult, error) { + return m.resolveTasks(ctx, job, cursor) +} + +func (m *Manager) JobDone(ctx context.Context, job orbital.Job) error { + return m.jobDone(ctx, job) +} + +func (m *Manager) JobFailed(ctx context.Context, job orbital.Job) error { + return m.jobFailed(ctx, job) +} + +func (m *Manager) JobCanceled(ctx context.Context, job orbital.Job) error { + return m.jobCanceled(ctx, job) +} + +var NewTargetProvider = func(fn func(context.Context, config.ReconcilerTarget) (orbital.Initiator, error)) TargetProvider { + return TargetProvider(fn) +} diff --git a/internal/reconciler/manager_test.go b/internal/reconciler/manager_test.go index 62466b8..b774837 100644 --- a/internal/reconciler/manager_test.go +++ b/internal/reconciler/manager_test.go @@ -1,11 +1,8 @@ -package reconciler +package reconciler_test import ( "context" "errors" - "go/ast" - "go/parser" - "go/token" "testing" "time" @@ -15,11 +12,12 @@ import ( "github.com/stretchr/testify/require" "github.com/openkcm/krypton/internal/config" + "github.com/openkcm/krypton/internal/reconciler" ) func TestNewManager(t *testing.T) { var createdTargets []config.ReconcilerTarget - targetProvider := TargetProvider(func(_ context.Context, target config.ReconcilerTarget) (orbital.Initiator, error) { + targetProvider := reconciler.NewTargetProvider(func(_ context.Context, target config.ReconcilerTarget) (orbital.Initiator, error) { createdTargets = append(createdTargets, target) return &fakeInitiator{}, nil }) @@ -27,45 +25,46 @@ func TestNewManager(t *testing.T) { cfg := config.ReconcilerConfig{MaxReconcileCount: 6} cfg.Targets = []config.ReconcilerTarget{validTarget("agent-aws"), validTarget("agent-gcp")} - manager, err := NewManager(t.Context(), &cfg, newNoopRepo(), targetProvider, []JobHandler{&fakeJobHandler{jobType: "job.type"}}) + manager, err := reconciler.NewManager(t.Context(), &cfg, newNoopRepo(), targetProvider, []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}) require.NoError(t, err) - assert.Equal(t, cfg.MaxReconcileCount, manager.orbitalManager.Config.MaxPendingReconciles) + assert.Equal(t, cfg.MaxReconcileCount, manager.OrbitalManager().Config.MaxPendingReconciles) assert.Len(t, createdTargets, 2) } func TestNewManagerUsesDefaultMaxPendingReconciles(t *testing.T) { - manager, err := NewManager( + manager, err := reconciler.NewManager( t.Context(), new(config.ReconcilerConfig), newNoopRepo(), nil, - []JobHandler{&fakeJobHandler{jobType: "job.type"}}, + []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}, ) require.NoError(t, err) - assert.Equal(t, defaultMaxPendingReconciles, manager.orbitalManager.Config.MaxPendingReconciles) + assert.Equal(t, reconciler.DefaultMaxPendingReconciles, manager.OrbitalManager().Config.MaxPendingReconciles) } func TestNewManagerOptions(t *testing.T) { - manager, err := NewManager( + manager, err := reconciler.NewManager( t.Context(), &config.ReconcilerConfig{}, newNoopRepo(), nil, - []JobHandler{&fakeJobHandler{jobType: "job.type"}}, - WithMaxPendingReconciles(42), - WithConfirmJobAfter(3*time.Second), - WithExecInterval(250*time.Millisecond), + []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}, + reconciler.WithMaxPendingReconciles(42), + reconciler.WithConfirmJobAfter(3*time.Second), + reconciler.WithExecInterval(250*time.Millisecond), ) require.NoError(t, err) - assert.Equal(t, uint64(42), manager.orbitalManager.Config.MaxPendingReconciles) - assert.Equal(t, 3*time.Second, manager.orbitalManager.Config.ConfirmJobAfter) - assert.Equal(t, 250*time.Millisecond, manager.orbitalManager.Config.ConfirmJobWorkerConfig.ExecInterval) - assert.Equal(t, 250*time.Millisecond, manager.orbitalManager.Config.CreateTasksWorkerConfig.ExecInterval) - assert.Equal(t, 250*time.Millisecond, manager.orbitalManager.Config.ReconcileWorkerConfig.ExecInterval) - assert.Equal(t, 250*time.Millisecond, manager.orbitalManager.Config.NotifyWorkerConfig.ExecInterval) + cfg := manager.OrbitalManager().Config + assert.Equal(t, uint64(42), cfg.MaxPendingReconciles) + assert.Equal(t, 3*time.Second, cfg.ConfirmJobAfter) + assert.Equal(t, 250*time.Millisecond, cfg.ConfirmJobWorkerConfig.ExecInterval) + assert.Equal(t, 250*time.Millisecond, cfg.CreateTasksWorkerConfig.ExecInterval) + assert.Equal(t, 250*time.Millisecond, cfg.ReconcileWorkerConfig.ExecInterval) + assert.Equal(t, 250*time.Millisecond, cfg.NotifyWorkerConfig.ExecInterval) } func TestNewManagerValidation(t *testing.T) { @@ -73,54 +72,54 @@ func TestNewManagerValidation(t *testing.T) { name string cfg *config.ReconcilerConfig repo *orbital.Repository - targetProvider TargetProvider - handlers []JobHandler + targetProvider reconciler.TargetProvider + handlers []reconciler.JobHandler wantErr error }{ { name: "nil config", repo: newNoopRepo(), - handlers: []JobHandler{&fakeJobHandler{jobType: "job.type"}}, + handlers: []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}, wantErr: config.ErrReconcilerConfigNil, }, { name: "nil repo", cfg: &config.ReconcilerConfig{}, - handlers: []JobHandler{&fakeJobHandler{jobType: "job.type"}}, - wantErr: ErrRepositoryNil, + handlers: []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}, + wantErr: reconciler.ErrRepositoryNil, }, { name: "target factory required", cfg: configWithTargets(), repo: newNoopRepo(), - handlers: []JobHandler{&fakeJobHandler{jobType: "job.type"}}, - wantErr: ErrTargetFactoryRequired, + handlers: []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}, + wantErr: reconciler.ErrTargetFactoryRequired, }, { name: "handler required", cfg: &config.ReconcilerConfig{}, repo: newNoopRepo(), - wantErr: ErrJobHandlerRequired, + wantErr: reconciler.ErrJobHandlerRequired, }, { name: "nil handler", cfg: &config.ReconcilerConfig{}, repo: newNoopRepo(), - handlers: []JobHandler{nil}, - wantErr: ErrJobHandlerNil, + handlers: []reconciler.JobHandler{nil}, + wantErr: reconciler.ErrJobHandlerNil, }, { name: "empty handler type", cfg: &config.ReconcilerConfig{}, repo: newNoopRepo(), - handlers: []JobHandler{&fakeJobHandler{}}, - wantErr: ErrJobTypeEmpty, + handlers: []reconciler.JobHandler{&fakeJobHandler{}}, + wantErr: reconciler.ErrJobTypeEmpty, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewManager(t.Context(), tt.cfg, tt.repo, tt.targetProvider, tt.handlers) + _, err := reconciler.NewManager(t.Context(), tt.cfg, tt.repo, tt.targetProvider, tt.handlers) assert.ErrorIs(t, err, tt.wantErr) }) } @@ -128,20 +127,20 @@ func TestNewManagerValidation(t *testing.T) { func TestManagerRoutesJobHandler(t *testing.T) { handler := &fakeJobHandler{jobType: "job.type"} - manager, err := NewManager(t.Context(), &config.ReconcilerConfig{}, newNoopRepo(), nil, []JobHandler{handler}) + manager, err := reconciler.NewManager(t.Context(), &config.ReconcilerConfig{}, newNoopRepo(), nil, []reconciler.JobHandler{handler}) require.NoError(t, err) - confirmResult, err := manager.confirmJob(t.Context(), orbital.Job{Type: "job.type"}) + confirmResult, err := manager.ConfirmJob(t.Context(), orbital.Job{Type: "job.type"}) require.NoError(t, err) assert.Equal(t, orbital.CompleteJobConfirmer().Type(), confirmResult.Type()) - resolveResult, err := manager.resolveTasks(t.Context(), orbital.Job{Type: "job.type"}, "") + resolveResult, err := manager.ResolveTasks(t.Context(), orbital.Job{Type: "job.type"}, "") require.NoError(t, err) assert.Equal(t, orbital.CompleteTaskResolver().Type(), resolveResult.Type()) - assert.NoError(t, manager.jobDone(t.Context(), orbital.Job{Type: "job.type"})) - assert.NoError(t, manager.jobFailed(t.Context(), orbital.Job{Type: "job.type"})) - assert.NoError(t, manager.jobCanceled(t.Context(), orbital.Job{Type: "job.type"})) + assert.NoError(t, manager.JobDone(t.Context(), orbital.Job{Type: "job.type"})) + assert.NoError(t, manager.JobFailed(t.Context(), orbital.Job{Type: "job.type"})) + assert.NoError(t, manager.JobCanceled(t.Context(), orbital.Job{Type: "job.type"})) assert.True(t, handler.confirmed) assert.True(t, handler.resolved) @@ -151,27 +150,27 @@ func TestManagerRoutesJobHandler(t *testing.T) { } func TestManagerUnknownJobTypeCancels(t *testing.T) { - manager, err := NewManager(t.Context(), &config.ReconcilerConfig{}, newNoopRepo(), nil, []JobHandler{&fakeJobHandler{jobType: "known"}}) + manager, err := reconciler.NewManager(t.Context(), &config.ReconcilerConfig{}, newNoopRepo(), nil, []reconciler.JobHandler{&fakeJobHandler{jobType: "known"}}) require.NoError(t, err) - confirmResult, err := manager.confirmJob(t.Context(), orbital.Job{Type: "unknown"}) + confirmResult, err := manager.ConfirmJob(t.Context(), orbital.Job{Type: "unknown"}) require.NoError(t, err) assert.Equal(t, orbital.CancelJobConfirmer("missing").Type(), confirmResult.Type()) - resolveResult, err := manager.resolveTasks(t.Context(), orbital.Job{Type: "unknown"}, "") + resolveResult, err := manager.ResolveTasks(t.Context(), orbital.Job{Type: "unknown"}, "") require.NoError(t, err) assert.Equal(t, orbital.CancelTaskResolver("missing").Type(), resolveResult.Type()) - assert.ErrorIs(t, manager.jobDone(t.Context(), orbital.Job{Type: "unknown"}), ErrJobHandlerNotFound) - assert.ErrorIs(t, manager.jobFailed(t.Context(), orbital.Job{Type: "unknown"}), ErrJobHandlerNotFound) - assert.ErrorIs(t, manager.jobCanceled(t.Context(), orbital.Job{Type: "unknown"}), ErrJobHandlerNotFound) - assert.Contains(t, jobHandlerNotFoundError("unknown").Error(), noJobHandlerRegisteredMessage) + assert.ErrorIs(t, manager.JobDone(t.Context(), orbital.Job{Type: "unknown"}), reconciler.ErrJobHandlerNotFound) + assert.ErrorIs(t, manager.JobFailed(t.Context(), orbital.Job{Type: "unknown"}), reconciler.ErrJobHandlerNotFound) + assert.ErrorIs(t, manager.JobCanceled(t.Context(), orbital.Job{Type: "unknown"}), reconciler.ErrJobHandlerNotFound) + assert.Contains(t, reconciler.JobHandlerNotFoundError("unknown").Error(), reconciler.NoJobHandlerRegisteredMessage) } func TestBuildTargetsClosesCreatedClientsOnError(t *testing.T) { first := &fakeInitiator{} providerErr := errors.New("boom") - targetProvider := TargetProvider(func(_ context.Context, target config.ReconcilerTarget) (orbital.Initiator, error) { + targetProvider := reconciler.NewTargetProvider(func(_ context.Context, target config.ReconcilerTarget) (orbital.Initiator, error) { if target.Name == "first" { return first, nil } @@ -179,7 +178,7 @@ func TestBuildTargetsClosesCreatedClientsOnError(t *testing.T) { }) targets := []config.ReconcilerTarget{validTarget("first"), validTarget("second")} - _, err := buildTargets(t.Context(), targets, targetProvider) + _, err := reconciler.BuildTargets(t.Context(), targets, targetProvider) assert.ErrorIs(t, err, providerErr) assert.True(t, first.closed) @@ -187,12 +186,12 @@ func TestBuildTargetsClosesCreatedClientsOnError(t *testing.T) { func TestStopClosesTargetsWhenOrbitalWasNotStarted(t *testing.T) { initiator := &fakeInitiator{} - targetProvider := TargetProvider(func(context.Context, config.ReconcilerTarget) (orbital.Initiator, error) { + targetProvider := reconciler.NewTargetProvider(func(context.Context, config.ReconcilerTarget) (orbital.Initiator, error) { return initiator, nil }) cfg := config.ReconcilerConfig{Targets: []config.ReconcilerTarget{validTarget("agent-aws")}} - manager, err := NewManager(t.Context(), &cfg, newNoopRepo(), targetProvider, []JobHandler{&fakeJobHandler{jobType: "job.type"}}) + manager, err := reconciler.NewManager(t.Context(), &cfg, newNoopRepo(), targetProvider, []reconciler.JobHandler{&fakeJobHandler{jobType: "job.type"}}) require.NoError(t, err) err = manager.Stop(t.Context()) @@ -202,33 +201,6 @@ func TestStopClosesTargetsWhenOrbitalWasNotStarted(t *testing.T) { assert.Equal(t, 1, initiator.closeCount) } -func TestManagerOnlyExposesStartAndStop(t *testing.T) { - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "manager.go", nil, 0) - require.NoError(t, err) - - exported := map[string]struct{}{} - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok || fn.Recv == nil || !fn.Name.IsExported() { - continue - } - if len(fn.Recv.List) == 0 { - continue - } - star, ok := fn.Recv.List[0].Type.(*ast.StarExpr) - if !ok { - continue - } - ident, ok := star.X.(*ast.Ident) - if ok && ident.Name == "Manager" { - exported[fn.Name.Name] = struct{}{} - } - } - - assert.Equal(t, map[string]struct{}{"Start": {}, "Stop": {}}, exported) -} - func configWithTargets() *config.ReconcilerConfig { return &config.ReconcilerConfig{ Targets: []config.ReconcilerTarget{validTarget("agent-aws")}, diff --git a/pkg/api/v1/proto/admin/admin_grpc.pb.go b/pkg/api/v1/proto/admin/admin_grpc.pb.go index 173365c..7e83031 100644 --- a/pkg/api/v1/proto/admin/admin_grpc.pb.go +++ b/pkg/api/v1/proto/admin/admin_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 +// - protoc-gen-go-grpc v1.5.1 // - protoc v7.34.1 // source: admin.proto @@ -90,13 +90,13 @@ type ServiceServer interface { type UnimplementedServiceServer struct{} func (UnimplementedServiceServer) CreateTenant(context.Context, *CreateTenantRequest) (*CreateTenantResponse, error) { - return nil, status.Error(codes.Unimplemented, "method CreateTenant not implemented") + return nil, status.Errorf(codes.Unimplemented, "method CreateTenant not implemented") } func (UnimplementedServiceServer) GetTenant(context.Context, *GetTenantRequest) (*GetTenantResponse, error) { - return nil, status.Error(codes.Unimplemented, "method GetTenant not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetTenant not implemented") } func (UnimplementedServiceServer) ListTenants(context.Context, *ListTenantsRequest) (*ListTenantsResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ListTenants not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ListTenants not implemented") } func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {} func (UnimplementedServiceServer) testEmbeddedByValue() {} @@ -109,7 +109,7 @@ type UnsafeServiceServer interface { } func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) { - // If the following call panics, it indicates UnimplementedServiceServer was + // If the following call pancis, it indicates UnimplementedServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/pkg/api/v1/proto/admin/service_test.go b/pkg/api/v1/proto/admin/service_test.go index 92281c7..ad9f7e2 100644 --- a/pkg/api/v1/proto/admin/service_test.go +++ b/pkg/api/v1/proto/admin/service_test.go @@ -21,8 +21,8 @@ func TestCreateTenant(t *testing.T) { db := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + tenantStore := storesql.NewTenantStore(db) t.Run("creates tenant successfully", func(t *testing.T) { // given @@ -56,11 +56,11 @@ func TestCreateTenant(t *testing.T) { // given tmpDB := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + tenantStore := storesql.NewTenantStore(tmpDB) // Drop the tenants table to simulate a database error - _, err = tmpDB.ExecContext(ctx, "DROP TABLE tenants") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE tenants CASCADE") require.NoError(t, err) cli := setupServerAndClient(t, tenantStore) @@ -86,8 +86,8 @@ func TestGetTenant(t *testing.T) { db := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + tenantStore := storesql.NewTenantStore(db) t.Run("should get tenant successfully", func(t *testing.T) { // given @@ -133,11 +133,11 @@ func TestGetTenant(t *testing.T) { // given tmpDB := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + tenantStore := storesql.NewTenantStore(tmpDB) // Drop the tenants table to simulate a database error - _, err = tmpDB.ExecContext(ctx, "DROP TABLE tenants") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE tenants CASCADE") require.NoError(t, err) cli := setupServerAndClient(t, tenantStore) @@ -161,8 +161,8 @@ func TestListTenants(t *testing.T) { db := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + tenantStore := storesql.NewTenantStore(db) t.Run("should list tenants successfully", func(t *testing.T) { // given @@ -201,11 +201,11 @@ func TestListTenants(t *testing.T) { // given tmpDB := createDatabase(t) - tenantStore, err := storesql.NewTenantStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + tenantStore := storesql.NewTenantStore(tmpDB) // Drop the tenants table to simulate a database error - _, err = tmpDB.ExecContext(ctx, "DROP TABLE tenants") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE tenants CASCADE") require.NoError(t, err) cli := setupServerAndClient(t, tenantStore) diff --git a/pkg/api/v1/proto/agents/agents_grpc.pb.go b/pkg/api/v1/proto/agents/agents_grpc.pb.go index 1ac4495..bf692ab 100644 --- a/pkg/api/v1/proto/agents/agents_grpc.pb.go +++ b/pkg/api/v1/proto/agents/agents_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 +// - protoc-gen-go-grpc v1.5.1 // - protoc v7.34.1 // source: agents.proto @@ -90,13 +90,13 @@ type ServiceServer interface { type UnimplementedServiceServer struct{} func (UnimplementedServiceServer) Register(context.Context, *RegisterAgentRequest) (*RegisterAgentResponse, error) { - return nil, status.Error(codes.Unimplemented, "method Register not implemented") + return nil, status.Errorf(codes.Unimplemented, "method Register not implemented") } func (UnimplementedServiceServer) SendHeartbeat(context.Context, *SendHeartbeatRequest) (*SendHeartbeatResponse, error) { - return nil, status.Error(codes.Unimplemented, "method SendHeartbeat not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SendHeartbeat not implemented") } func (UnimplementedServiceServer) Deregister(context.Context, *DeregisterAgentRequest) (*DeregisterAgentResponse, error) { - return nil, status.Error(codes.Unimplemented, "method Deregister not implemented") + return nil, status.Errorf(codes.Unimplemented, "method Deregister not implemented") } func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {} func (UnimplementedServiceServer) testEmbeddedByValue() {} @@ -109,7 +109,7 @@ type UnsafeServiceServer interface { } func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) { - // If the following call panics, it indicates UnimplementedServiceServer was + // If the following call pancis, it indicates UnimplementedServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/pkg/api/v1/proto/agents/deregister_test.go b/pkg/api/v1/proto/agents/deregister_test.go index 532051d..9226a0c 100644 --- a/pkg/api/v1/proto/agents/deregister_test.go +++ b/pkg/api/v1/proto/agents/deregister_test.go @@ -23,8 +23,8 @@ func TestDeregister(t *testing.T) { db := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + agentStore := storesql.NewAgentStore(db) t.Run("should update the status of the registered to deregistered", func(t *testing.T) { // given @@ -76,7 +76,7 @@ func TestDeregister(t *testing.T) { expInstanceID := uuid.NewString() cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) - _, err = cli.SendHeartbeat(ctx, &agents.SendHeartbeatRequest{ + _, err := cli.SendHeartbeat(ctx, &agents.SendHeartbeatRequest{ AgentName: expAgentName, InstanceId: expInstanceID, }) @@ -172,7 +172,7 @@ func TestDeregister(t *testing.T) { expInstanceID := uuid.NewString() cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) - _, err = agentStore.Get(ctx, store.GetAgentQuery{ + _, err := agentStore.Get(ctx, store.GetAgentQuery{ Name: expAgentName, InstanceID: expInstanceID, }) @@ -226,11 +226,11 @@ func TestDeregister(t *testing.T) { expInstanceID := uuid.NewString() tmpDB := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + agentStore := storesql.NewAgentStore(tmpDB) // drop the table to cause an error in the agent store during deregister processing - _, err = tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") require.NoError(t, err) cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) diff --git a/pkg/api/v1/proto/agents/heartbeat_test.go b/pkg/api/v1/proto/agents/heartbeat_test.go index 3728e10..da326cc 100644 --- a/pkg/api/v1/proto/agents/heartbeat_test.go +++ b/pkg/api/v1/proto/agents/heartbeat_test.go @@ -23,8 +23,8 @@ func TestSendHeartbeat(t *testing.T) { db := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + agentStore := storesql.NewAgentStore(db) t.Run("should update the status of the registered to healthy", func(t *testing.T) { // given @@ -77,7 +77,7 @@ func TestSendHeartbeat(t *testing.T) { expInstanceID := uuid.NewString() cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) - _, err = agentStore.Get(ctx, store.GetAgentQuery{ + _, err := agentStore.Get(ctx, store.GetAgentQuery{ Name: expAgentName, InstanceID: expInstanceID, }) @@ -156,11 +156,11 @@ func TestSendHeartbeat(t *testing.T) { expInstanceID := uuid.NewString() tmpDB := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + agentStore := storesql.NewAgentStore(tmpDB) // drop the table to cause an error in the agent store during heartbeat processing - _, err = tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") require.NoError(t, err) cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) diff --git a/pkg/api/v1/proto/agents/register_test.go b/pkg/api/v1/proto/agents/register_test.go index e5145fe..d8a5165 100644 --- a/pkg/api/v1/proto/agents/register_test.go +++ b/pkg/api/v1/proto/agents/register_test.go @@ -24,8 +24,8 @@ func TestRegister(t *testing.T) { db := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + agentStore := storesql.NewAgentStore(db) t.Run("should register agent successfully", func(t *testing.T) { // given @@ -130,11 +130,11 @@ func TestRegister(t *testing.T) { expInstanceID := uuid.NewString() tmpDB := createDatabase(t) - agentStore, err := storesql.NewAgentStore(ctx, tmpDB) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + agentStore := storesql.NewAgentStore(tmpDB) // drop the table to cause an error in the agent store during register processing - _, err = tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") + _, err := tmpDB.ExecContext(ctx, "DROP TABLE agent_registrations") require.NoError(t, err) cli := setupServerAndClient(t, agentStore, validRootConfig(expAgentName)) diff --git a/pkg/api/v1/proto/keys/convert.go b/pkg/api/v1/proto/keys/convert.go new file mode 100644 index 0000000..bbf14fa --- /dev/null +++ b/pkg/api/v1/proto/keys/convert.go @@ -0,0 +1,45 @@ +package keys + +import ( + "github.com/openkcm/krypton/internal/clock" + "github.com/openkcm/krypton/pkg/model" +) + +func KeyToProto(k model.Key) *Key { + var parentID string + if k.ParentID != nil { + parentID = *k.ParentID + } + return &Key{ + Id: k.ID, + Name: k.Name, + TenantId: k.TenantID, + Kind: k.Kind, + ParentId: parentID, + ManagedBy: k.ManagedBy, + Labels: k.Labels, + State: string(k.State), + CreatedAt: int64(k.CreatedAt), + UpdatedAt: int64(k.UpdatedAt), + } +} + +func KeyFromProto(k *Key) model.Key { + var parentID *string + if k.GetParentId() != "" { + p := k.GetParentId() + parentID = &p + } + return model.Key{ + ID: k.GetId(), + Name: k.GetName(), + TenantID: k.GetTenantId(), + Kind: k.GetKind(), + ParentID: parentID, + ManagedBy: k.GetManagedBy(), + Labels: k.GetLabels(), + State: model.KeyState(k.GetState()), + CreatedAt: clock.UnixNano(k.GetCreatedAt()), + UpdatedAt: clock.UnixNano(k.GetUpdatedAt()), + } +} diff --git a/pkg/api/v1/proto/keys/keys.pb.go b/pkg/api/v1/proto/keys/keys.pb.go new file mode 100644 index 0000000..306719f --- /dev/null +++ b/pkg/api/v1/proto/keys/keys.pb.go @@ -0,0 +1,471 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.1 +// source: keys.proto + +package keys + +import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Key struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + TenantId string `protobuf:"bytes,3,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + Kind string `protobuf:"bytes,4,opt,name=kind,proto3" json:"kind,omitempty"` + ParentId string `protobuf:"bytes,5,opt,name=parent_id,json=parentId,proto3" json:"parent_id,omitempty"` + ManagedBy string `protobuf:"bytes,6,opt,name=managed_by,json=managedBy,proto3" json:"managed_by,omitempty"` + Labels map[string]string `protobuf:"bytes,7,rep,name=labels,proto3" json:"labels,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + State string `protobuf:"bytes,8,opt,name=state,proto3" json:"state,omitempty"` + CreatedAt int64 `protobuf:"varint,9,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + UpdatedAt int64 `protobuf:"varint,10,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Key) Reset() { + *x = Key{} + mi := &file_keys_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Key) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Key) ProtoMessage() {} + +func (x *Key) ProtoReflect() protoreflect.Message { + mi := &file_keys_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Key.ProtoReflect.Descriptor instead. +func (*Key) Descriptor() ([]byte, []int) { + return file_keys_proto_rawDescGZIP(), []int{0} +} + +func (x *Key) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Key) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *Key) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +func (x *Key) GetKind() string { + if x != nil { + return x.Kind + } + return "" +} + +func (x *Key) GetParentId() string { + if x != nil { + return x.ParentId + } + return "" +} + +func (x *Key) GetManagedBy() string { + if x != nil { + return x.ManagedBy + } + return "" +} + +func (x *Key) GetLabels() map[string]string { + if x != nil { + return x.Labels + } + return nil +} + +func (x *Key) GetState() string { + if x != nil { + return x.State + } + return "" +} + +func (x *Key) GetCreatedAt() int64 { + if x != nil { + return x.CreatedAt + } + return 0 +} + +func (x *Key) GetUpdatedAt() int64 { + if x != nil { + return x.UpdatedAt + } + return 0 +} + +type AnnounceKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TenantId string `protobuf:"bytes,1,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + Kind string `protobuf:"bytes,2,opt,name=kind,proto3" json:"kind,omitempty"` + Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` + ParentId string `protobuf:"bytes,4,opt,name=parent_id,json=parentId,proto3" json:"parent_id,omitempty"` + TargetName string `protobuf:"bytes,5,opt,name=target_name,json=targetName,proto3" json:"target_name,omitempty"` + Labels map[string]string `protobuf:"bytes,6,rep,name=labels,proto3" json:"labels,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AnnounceKeyRequest) Reset() { + *x = AnnounceKeyRequest{} + mi := &file_keys_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AnnounceKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AnnounceKeyRequest) ProtoMessage() {} + +func (x *AnnounceKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_keys_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AnnounceKeyRequest.ProtoReflect.Descriptor instead. +func (*AnnounceKeyRequest) Descriptor() ([]byte, []int) { + return file_keys_proto_rawDescGZIP(), []int{1} +} + +func (x *AnnounceKeyRequest) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +func (x *AnnounceKeyRequest) GetKind() string { + if x != nil { + return x.Kind + } + return "" +} + +func (x *AnnounceKeyRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *AnnounceKeyRequest) GetParentId() string { + if x != nil { + return x.ParentId + } + return "" +} + +func (x *AnnounceKeyRequest) GetTargetName() string { + if x != nil { + return x.TargetName + } + return "" +} + +func (x *AnnounceKeyRequest) GetLabels() map[string]string { + if x != nil { + return x.Labels + } + return nil +} + +type AnnounceKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Key *Key `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AnnounceKeyResponse) Reset() { + *x = AnnounceKeyResponse{} + mi := &file_keys_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AnnounceKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AnnounceKeyResponse) ProtoMessage() {} + +func (x *AnnounceKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_keys_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AnnounceKeyResponse.ProtoReflect.Descriptor instead. +func (*AnnounceKeyResponse) Descriptor() ([]byte, []int) { + return file_keys_proto_rawDescGZIP(), []int{2} +} + +func (x *AnnounceKeyResponse) GetKey() *Key { + if x != nil { + return x.Key + } + return nil +} + +type GetKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + TenantId string `protobuf:"bytes,2,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetKeyRequest) Reset() { + *x = GetKeyRequest{} + mi := &file_keys_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetKeyRequest) ProtoMessage() {} + +func (x *GetKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_keys_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetKeyRequest.ProtoReflect.Descriptor instead. +func (*GetKeyRequest) Descriptor() ([]byte, []int) { + return file_keys_proto_rawDescGZIP(), []int{3} +} + +func (x *GetKeyRequest) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *GetKeyRequest) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +type GetKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Key *Key `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetKeyResponse) Reset() { + *x = GetKeyResponse{} + mi := &file_keys_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetKeyResponse) ProtoMessage() {} + +func (x *GetKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_keys_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetKeyResponse.ProtoReflect.Descriptor instead. +func (*GetKeyResponse) Descriptor() ([]byte, []int) { + return file_keys_proto_rawDescGZIP(), []int{4} +} + +func (x *GetKeyResponse) GetKey() *Key { + if x != nil { + return x.Key + } + return nil +} + +var File_keys_proto protoreflect.FileDescriptor + +const file_keys_proto_rawDesc = "" + + "\n" + + "\n" + + "keys.proto\x12\x0fkrypton.v1.keys\"\xdf\x02\n" + + "\x03Key\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n" + + "\ttenant_id\x18\x03 \x01(\tR\btenantId\x12\x12\n" + + "\x04kind\x18\x04 \x01(\tR\x04kind\x12\x1b\n" + + "\tparent_id\x18\x05 \x01(\tR\bparentId\x12\x1d\n" + + "\n" + + "managed_by\x18\x06 \x01(\tR\tmanagedBy\x128\n" + + "\x06labels\x18\a \x03(\v2 .krypton.v1.keys.Key.LabelsEntryR\x06labels\x12\x14\n" + + "\x05state\x18\b \x01(\tR\x05state\x12\x1d\n" + + "\n" + + "created_at\x18\t \x01(\x03R\tcreatedAt\x12\x1d\n" + + "\n" + + "updated_at\x18\n" + + " \x01(\x03R\tupdatedAt\x1a9\n" + + "\vLabelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\x9b\x02\n" + + "\x12AnnounceKeyRequest\x12\x1b\n" + + "\ttenant_id\x18\x01 \x01(\tR\btenantId\x12\x12\n" + + "\x04kind\x18\x02 \x01(\tR\x04kind\x12\x12\n" + + "\x04name\x18\x03 \x01(\tR\x04name\x12\x1b\n" + + "\tparent_id\x18\x04 \x01(\tR\bparentId\x12\x1f\n" + + "\vtarget_name\x18\x05 \x01(\tR\n" + + "targetName\x12G\n" + + "\x06labels\x18\x06 \x03(\v2/.krypton.v1.keys.AnnounceKeyRequest.LabelsEntryR\x06labels\x1a9\n" + + "\vLabelsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"=\n" + + "\x13AnnounceKeyResponse\x12&\n" + + "\x03key\x18\x01 \x01(\v2\x14.krypton.v1.keys.KeyR\x03key\"<\n" + + "\rGetKeyRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x1b\n" + + "\ttenant_id\x18\x02 \x01(\tR\btenantId\"8\n" + + "\x0eGetKeyResponse\x12&\n" + + "\x03key\x18\x01 \x01(\v2\x14.krypton.v1.keys.KeyR\x03key2\xae\x01\n" + + "\aService\x12X\n" + + "\vAnnounceKey\x12#.krypton.v1.keys.AnnounceKeyRequest\x1a$.krypton.v1.keys.AnnounceKeyResponse\x12I\n" + + "\x06GetKey\x12\x1e.krypton.v1.keys.GetKeyRequest\x1a\x1f.krypton.v1.keys.GetKeyResponseB2Z0github.com/openkcm/krypton/pkg/api/v1/proto/keysb\x06proto3" + +var ( + file_keys_proto_rawDescOnce sync.Once + file_keys_proto_rawDescData []byte +) + +func file_keys_proto_rawDescGZIP() []byte { + file_keys_proto_rawDescOnce.Do(func() { + file_keys_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_keys_proto_rawDesc), len(file_keys_proto_rawDesc))) + }) + return file_keys_proto_rawDescData +} + +var file_keys_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_keys_proto_goTypes = []any{ + (*Key)(nil), // 0: krypton.v1.keys.Key + (*AnnounceKeyRequest)(nil), // 1: krypton.v1.keys.AnnounceKeyRequest + (*AnnounceKeyResponse)(nil), // 2: krypton.v1.keys.AnnounceKeyResponse + (*GetKeyRequest)(nil), // 3: krypton.v1.keys.GetKeyRequest + (*GetKeyResponse)(nil), // 4: krypton.v1.keys.GetKeyResponse + nil, // 5: krypton.v1.keys.Key.LabelsEntry + nil, // 6: krypton.v1.keys.AnnounceKeyRequest.LabelsEntry +} +var file_keys_proto_depIdxs = []int32{ + 5, // 0: krypton.v1.keys.Key.labels:type_name -> krypton.v1.keys.Key.LabelsEntry + 6, // 1: krypton.v1.keys.AnnounceKeyRequest.labels:type_name -> krypton.v1.keys.AnnounceKeyRequest.LabelsEntry + 0, // 2: krypton.v1.keys.AnnounceKeyResponse.key:type_name -> krypton.v1.keys.Key + 0, // 3: krypton.v1.keys.GetKeyResponse.key:type_name -> krypton.v1.keys.Key + 1, // 4: krypton.v1.keys.Service.AnnounceKey:input_type -> krypton.v1.keys.AnnounceKeyRequest + 3, // 5: krypton.v1.keys.Service.GetKey:input_type -> krypton.v1.keys.GetKeyRequest + 2, // 6: krypton.v1.keys.Service.AnnounceKey:output_type -> krypton.v1.keys.AnnounceKeyResponse + 4, // 7: krypton.v1.keys.Service.GetKey:output_type -> krypton.v1.keys.GetKeyResponse + 6, // [6:8] is the sub-list for method output_type + 4, // [4:6] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_keys_proto_init() } +func file_keys_proto_init() { + if File_keys_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_keys_proto_rawDesc), len(file_keys_proto_rawDesc)), + NumEnums: 0, + NumMessages: 7, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_keys_proto_goTypes, + DependencyIndexes: file_keys_proto_depIdxs, + MessageInfos: file_keys_proto_msgTypes, + }.Build() + File_keys_proto = out.File + file_keys_proto_goTypes = nil + file_keys_proto_depIdxs = nil +} diff --git a/pkg/api/v1/proto/keys/keys_grpc.pb.go b/pkg/api/v1/proto/keys/keys_grpc.pb.go new file mode 100644 index 0000000..1b7aa32 --- /dev/null +++ b/pkg/api/v1/proto/keys/keys_grpc.pb.go @@ -0,0 +1,160 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v7.34.1 +// source: keys.proto + +package keys + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Service_AnnounceKey_FullMethodName = "/krypton.v1.keys.Service/AnnounceKey" + Service_GetKey_FullMethodName = "/krypton.v1.keys.Service/GetKey" +) + +// ServiceClient is the client API for Service service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ServiceClient interface { + AnnounceKey(ctx context.Context, in *AnnounceKeyRequest, opts ...grpc.CallOption) (*AnnounceKeyResponse, error) + GetKey(ctx context.Context, in *GetKeyRequest, opts ...grpc.CallOption) (*GetKeyResponse, error) +} + +type serviceClient struct { + cc grpc.ClientConnInterface +} + +func NewServiceClient(cc grpc.ClientConnInterface) ServiceClient { + return &serviceClient{cc} +} + +func (c *serviceClient) AnnounceKey(ctx context.Context, in *AnnounceKeyRequest, opts ...grpc.CallOption) (*AnnounceKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AnnounceKeyResponse) + err := c.cc.Invoke(ctx, Service_AnnounceKey_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serviceClient) GetKey(ctx context.Context, in *GetKeyRequest, opts ...grpc.CallOption) (*GetKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetKeyResponse) + err := c.cc.Invoke(ctx, Service_GetKey_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ServiceServer is the server API for Service service. +// All implementations must embed UnimplementedServiceServer +// for forward compatibility. +type ServiceServer interface { + AnnounceKey(context.Context, *AnnounceKeyRequest) (*AnnounceKeyResponse, error) + GetKey(context.Context, *GetKeyRequest) (*GetKeyResponse, error) + mustEmbedUnimplementedServiceServer() +} + +// UnimplementedServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedServiceServer struct{} + +func (UnimplementedServiceServer) AnnounceKey(context.Context, *AnnounceKeyRequest) (*AnnounceKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method AnnounceKey not implemented") +} +func (UnimplementedServiceServer) GetKey(context.Context, *GetKeyRequest) (*GetKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetKey not implemented") +} +func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {} +func (UnimplementedServiceServer) testEmbeddedByValue() {} + +// UnsafeServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ServiceServer will +// result in compilation errors. +type UnsafeServiceServer interface { + mustEmbedUnimplementedServiceServer() +} + +func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) { + // If the following call pancis, it indicates UnimplementedServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Service_ServiceDesc, srv) +} + +func _Service_AnnounceKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AnnounceKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServiceServer).AnnounceKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Service_AnnounceKey_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServiceServer).AnnounceKey(ctx, req.(*AnnounceKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Service_GetKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServiceServer).GetKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Service_GetKey_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServiceServer).GetKey(ctx, req.(*GetKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Service_ServiceDesc is the grpc.ServiceDesc for Service service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Service_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "krypton.v1.keys.Service", + HandlerType: (*ServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "AnnounceKey", + Handler: _Service_AnnounceKey_Handler, + }, + { + MethodName: "GetKey", + Handler: _Service_GetKey_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "keys.proto", +} diff --git a/pkg/api/v1/proto/keys/service.go b/pkg/api/v1/proto/keys/service.go new file mode 100644 index 0000000..b761808 --- /dev/null +++ b/pkg/api/v1/proto/keys/service.go @@ -0,0 +1,67 @@ +package keys + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/openkcm/krypton/pkg/api/v1/proto" + "github.com/openkcm/krypton/pkg/model" + "github.com/openkcm/krypton/pkg/store" +) + +type Service struct { + UnimplementedServiceServer + + keyStore store.Key +} + +func NewService(keyStore store.Key) *Service { + return &Service{keyStore: keyStore} +} + +func (s *Service) AnnounceKey(ctx context.Context, req *AnnounceKeyRequest) (*AnnounceKeyResponse, error) { + var parentID *string + if req.GetParentId() != "" { + p := req.GetParentId() + parentID = &p + } + + key := model.NewKey( + req.GetTenantId(), + req.GetName(), + req.GetKind(), + parentID, + req.GetTargetName(), + req.GetLabels(), + ) + + if err := s.keyStore.CreateKey(ctx, key); err != nil { + return nil, proto.ErrDetailsWithCode( + status.New(codes.Internal, "failed to announce key"), + proto.Code_ERROR_CODE_RETRY, + ) + } + + return &AnnounceKeyResponse{Key: KeyToProto(key)}, nil +} + +func (s *Service) GetKey(ctx context.Context, req *GetKeyRequest) (*GetKeyResponse, error) { + key, err := s.keyStore.GetKeyByID(ctx, req.GetId(), req.GetTenantId()) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, proto.ErrDetailsWithCode( + status.New(codes.NotFound, "key not found"), + proto.Code_ERROR_CODE_ABORT, + ) + } + return nil, proto.ErrDetailsWithCode( + status.New(codes.Internal, "failed to get key"), + proto.Code_ERROR_CODE_RETRY, + ) + } + + return &GetKeyResponse{Key: KeyToProto(*key)}, nil +} diff --git a/pkg/api/v1/proto/keys/service_test.go b/pkg/api/v1/proto/keys/service_test.go new file mode 100644 index 0000000..484384d --- /dev/null +++ b/pkg/api/v1/proto/keys/service_test.go @@ -0,0 +1,163 @@ +package keys_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/openkcm/krypton/pkg/api/v1/proto" + "github.com/openkcm/krypton/pkg/api/v1/proto/keys" + storesql "github.com/openkcm/krypton/pkg/store/sql" +) + +func TestAnnounceKey(t *testing.T) { + ctx := t.Context() + db := createDatabase(t) + + require.NoError(t, storesql.Migrate(ctx, db)) + keyStore := storesql.NewKeyStore(db) + + tenant := createTenant(t, db) + + t.Run("should create key successfully", func(t *testing.T) { + cli := setupServerAndClient(t, keyStore) + + res, err := cli.AnnounceKey(ctx, &keys.AnnounceKeyRequest{ + TenantId: tenant.ID, + Kind: "K0", + Name: "root-key-" + uuid.NewString(), + TargetName: "root", + Labels: map[string]string{"env": "prod"}, + }) + + assert.NoError(t, err) + assert.NotEmpty(t, res.GetKey().GetId()) + assert.Equal(t, "K0", res.GetKey().GetKind()) + assert.Equal(t, "root", res.GetKey().GetManagedBy()) + assert.Equal(t, "pending", res.GetKey().GetState()) + assert.Equal(t, tenant.ID, res.GetKey().GetTenantId()) + assert.Equal(t, "prod", res.GetKey().GetLabels()["env"]) + assert.NotZero(t, res.GetKey().GetCreatedAt()) + assert.NotZero(t, res.GetKey().GetUpdatedAt()) + }) + + t.Run("should create key with parent", func(t *testing.T) { + cli := setupServerAndClient(t, keyStore) + + parentRes, err := cli.AnnounceKey(ctx, &keys.AnnounceKeyRequest{ + TenantId: tenant.ID, + Kind: "K0", + Name: "parent-" + uuid.NewString(), + TargetName: "root", + }) + require.NoError(t, err) + + res, err := cli.AnnounceKey(ctx, &keys.AnnounceKeyRequest{ + TenantId: tenant.ID, + Kind: "K1", + Name: "child-" + uuid.NewString(), + ParentId: parentRes.GetKey().GetId(), + TargetName: "root", + }) + + assert.NoError(t, err) + assert.Equal(t, parentRes.GetKey().GetId(), res.GetKey().GetParentId()) + assert.Equal(t, "root", res.GetKey().GetManagedBy()) + }) + + t.Run("should return internal error on database failure", func(t *testing.T) { + tmpDB := createDatabase(t) + + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + tmpKeyStore := storesql.NewKeyStore(tmpDB) + + _, err := tmpDB.ExecContext(ctx, "DROP TABLE keys") + require.NoError(t, err) + + cli := setupServerAndClient(t, tmpKeyStore) + + resp, err := cli.AnnounceKey(ctx, &keys.AnnounceKeyRequest{ + TenantId: uuid.NewString(), + Kind: "K0", + Name: "will-fail", + TargetName: "root", + }) + + assert.Error(t, err) + assert.Nil(t, resp) + assert.Equal(t, codes.Internal, status.Code(err)) + }) +} + +func TestGetKeyService(t *testing.T) { + ctx := t.Context() + db := createDatabase(t) + + require.NoError(t, storesql.Migrate(ctx, db)) + keyStore := storesql.NewKeyStore(db) + + tenant := createTenant(t, db) + + t.Run("should get key successfully", func(t *testing.T) { + cli := setupServerAndClient(t, keyStore) + + created, err := cli.AnnounceKey(ctx, &keys.AnnounceKeyRequest{ + TenantId: tenant.ID, + Kind: "K0", + Name: "get-me-" + uuid.NewString(), + TargetName: "root", + Labels: map[string]string{"env": "staging"}, + }) + require.NoError(t, err) + + res, err := cli.GetKey(ctx, &keys.GetKeyRequest{ + Id: created.GetKey().GetId(), + TenantId: tenant.ID, + }) + + assert.NoError(t, err) + assert.Equal(t, created.GetKey().GetId(), res.GetKey().GetId()) + assert.Equal(t, created.GetKey().GetName(), res.GetKey().GetName()) + assert.Equal(t, "staging", res.GetKey().GetLabels()["env"]) + }) + + t.Run("should return not found for nonexistent key", func(t *testing.T) { + cli := setupServerAndClient(t, keyStore) + + resp, err := cli.GetKey(ctx, &keys.GetKeyRequest{ + Id: uuid.NewString(), + TenantId: tenant.ID, + }) + + assert.Error(t, err) + assert.Nil(t, resp) + assert.Equal(t, codes.NotFound, status.Code(err)) + assertErrorDetails(t, proto.Code_ERROR_CODE_ABORT, err) + }) + + t.Run("should return internal error on database failure", func(t *testing.T) { + tmpDB := createDatabase(t) + + require.NoError(t, storesql.Migrate(ctx, tmpDB)) + tmpKeyStore := storesql.NewKeyStore(tmpDB) + + _, err := tmpDB.ExecContext(ctx, "DROP TABLE keys") + require.NoError(t, err) + + cli := setupServerAndClient(t, tmpKeyStore) + + resp, err := cli.GetKey(ctx, &keys.GetKeyRequest{ + Id: uuid.NewString(), + TenantId: uuid.NewString(), + }) + + assert.Error(t, err) + assert.Nil(t, resp) + assert.Equal(t, codes.Internal, status.Code(err)) + assertErrorDetails(t, proto.Code_ERROR_CODE_RETRY, err) + }) +} diff --git a/pkg/api/v1/proto/keys/setup_test.go b/pkg/api/v1/proto/keys/setup_test.go new file mode 100644 index 0000000..335a8ef --- /dev/null +++ b/pkg/api/v1/proto/keys/setup_test.go @@ -0,0 +1,152 @@ +package keys_test + +import ( + "context" + "database/sql" + "net" + "os" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + _ "github.com/lib/pq" + + "github.com/openkcm/krypton/pkg/api/v1/proto" + "github.com/openkcm/krypton/pkg/api/v1/proto/keys" + "github.com/openkcm/krypton/pkg/model" + "github.com/openkcm/krypton/pkg/store" + storesql "github.com/openkcm/krypton/pkg/store/sql" +) + +var pgConnStr string + +func TestMain(m *testing.M) { + pgCleanup, err := setupPostgres() + if err != nil { + os.Exit(1) + } + + exitCode := m.Run() + pgCleanup() + os.Exit(exitCode) +} + +func setupPostgres() (func(), error) { + ctx := context.Background() + + pgContainer, err := postgres.Run(ctx, + "postgres:18-alpine", + postgres.WithDatabase("postgres"), + postgres.WithUsername("testuser"), + postgres.WithPassword("testpass"), + postgres.BasicWaitStrategies(), + ) + if err != nil { + return nil, err + } + cleanUp := func() { _ = pgContainer.Terminate(ctx) } + + pgConnStr, err = pgContainer.ConnectionString(ctx, "sslmode=disable") + + return cleanUp, err +} + +func setupServerAndClient(t *testing.T, keyStore store.Key) keys.ServiceClient { + t.Helper() + + srv := grpc.NewServer() + keys.RegisterServiceServer(srv, keys.NewService(keyStore)) + + const bufSize = 1024 * 1024 + lis := bufconn.Listen(bufSize) + go func() { + if err := srv.Serve(lis); err != nil { + assert.Fail(t, "keys service server error", err) + } + }() + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + t.Cleanup(func() { + srv.GracefulStop() + }) + + conn, err := grpc.NewClient( + "passthrough:///bufconn", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + t.Cleanup(func() { + conn.Close() + }) + + return keys.NewServiceClient(conn) +} + +func createDatabase(t *testing.T) *sql.DB { + t.Helper() + ctx := t.Context() + + db, err := sql.Open("postgres", pgConnStr) + if err != nil { + assert.FailNowf(t, "failed to connect to PostgreSQL", "error: %v", err) + } + + dbName := "test_" + strings.ReplaceAll(uuid.NewString(), "-", "") + _, err = db.ExecContext(ctx, "CREATE DATABASE "+dbName) + if err != nil { + db.Close() + assert.FailNowf(t, "failed to create test database", "error: %v", err) + } + db.Close() + + pgConStr := strings.Replace(pgConnStr, "/postgres?", "/"+dbName+"?", 1) + sqlDB, err := sql.Open("postgres", pgConStr) + if err != nil { + assert.FailNowf(t, "failed to connect to test database", "error: %v", err) + } + + t.Cleanup(func() { + sqlDB.Close() + + db, err := sql.Open("postgres", pgConnStr) + if err == nil { + _, _ = db.ExecContext(context.Background(), "DROP DATABASE "+dbName) + db.Close() + } + }) + return sqlDB +} + +func assertErrorDetails(t *testing.T, expCode proto.Code, actErr error) { + t.Helper() + + st := status.Convert(actErr) + dts := st.Details() + require.Len(t, dts, 1, "expected 1 error detail") + + dt, ok := dts[0].(*proto.ErrorDetails) + require.True(t, ok, "expected error details of type proto.ErrorDetails") + assert.Equal(t, expCode, dt.GetCode()) +} + +func createTenant(t *testing.T, db *sql.DB) model.Tenant { + t.Helper() + ctx := t.Context() + tenantStore := storesql.NewTenantStore(db) + tenant := model.NewTenant("test-tenant-"+uuid.NewString(), nil) + result, err := tenantStore.CreateTenant(ctx, store.CreateTenantQuery{Tenant: tenant}) + require.NoError(t, err) + return result.Tenant +} diff --git a/pkg/model/key.go b/pkg/model/key.go new file mode 100644 index 0000000..6710e07 --- /dev/null +++ b/pkg/model/key.go @@ -0,0 +1,44 @@ +package model + +import ( + "github.com/google/uuid" + + "github.com/openkcm/krypton/internal/clock" +) + +type KeyState string + +const ( + KeyStatePending KeyState = "pending" + KeyStateAnnounced KeyState = "announced" + KeyStateFailed KeyState = "failed" +) + +type Key struct { + ID string `json:"id"` + Name string `json:"name"` + TenantID string `json:"tenant_id"` + Kind string `json:"kind"` + ParentID *string `json:"parent_id"` + ManagedBy string `json:"managed_by"` + Labels Labels `json:"labels"` + State KeyState `json:"state"` + CreatedAt clock.UnixNano `json:"created_at"` + UpdatedAt clock.UnixNano `json:"updated_at"` +} + +func NewKey(tenantID, name string, kind string, parentID *string, managedBy string, labels Labels) Key { + now := clock.Now() + return Key{ + ID: uuid.NewString(), + Name: name, + TenantID: tenantID, + Kind: kind, + ParentID: parentID, + ManagedBy: managedBy, + Labels: labels, + State: KeyStatePending, + CreatedAt: now, + UpdatedAt: now, + } +} diff --git a/pkg/store/key.go b/pkg/store/key.go new file mode 100644 index 0000000..ac49405 --- /dev/null +++ b/pkg/store/key.go @@ -0,0 +1,15 @@ +package store + +import ( + "context" + "errors" + + "github.com/openkcm/krypton/pkg/model" +) + +var ErrKeyNotFound = errors.New("key not found") + +type Key interface { + CreateKey(ctx context.Context, key model.Key) error + GetKeyByID(ctx context.Context, id, tenantID string) (*model.Key, error) +} diff --git a/pkg/store/sql/agent.go b/pkg/store/sql/agent.go index 3b2fa57..ab3d1ea 100644 --- a/pkg/store/sql/agent.go +++ b/pkg/store/sql/agent.go @@ -17,29 +17,8 @@ type AgentStore struct { var _ store.Agent = &AgentStore{} -func NewAgentStore(ctx context.Context, db *sql.DB) (*AgentStore, error) { - reg := &AgentStore{ - db: db, - } - - stmt := ` - CREATE TABLE IF NOT EXISTS agent_registrations ( - name TEXT NOT NULL, - instance_id UUID, - status TEXT NOT NULL, - last_heartbeat BIGINT NOT NULL, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL, - PRIMARY KEY (name, instance_id) - ); - ` - - _, err := reg.db.ExecContext(ctx, stmt) - if err != nil { - return nil, err - } - - return reg, nil +func NewAgentStore(db *sql.DB) *AgentStore { + return &AgentStore{db: db} } func (s *AgentStore) Register(ctx context.Context, q store.RegisterAgentQuery) (store.RegisterAgentResult, error) { diff --git a/pkg/store/sql/agent_test.go b/pkg/store/sql/agent_test.go index c939016..90884cb 100644 --- a/pkg/store/sql/agent_test.go +++ b/pkg/store/sql/agent_test.go @@ -44,8 +44,8 @@ func TestRegister(t *testing.T) { db.Close() }) - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + subj := storesql.NewAgentStore(db) t.Run("should insert new agent registration", func(t *testing.T) { // given @@ -119,8 +119,7 @@ func TestGet(t *testing.T) { t.Run("should get existing agent registration", func(t *testing.T) { // given - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + subj := storesql.NewAgentStore(db) registration := core.AgentRegistration{ Name: agentName(), @@ -152,8 +151,7 @@ func TestGet(t *testing.T) { t.Run("should return error if agent registration not found", func(t *testing.T) { // given - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + subj := storesql.NewAgentStore(db) // when getResult, err := subj.Get(ctx, store.GetAgentQuery{ @@ -178,8 +176,8 @@ func TestUpdateStatus(t *testing.T) { db.Close() }) - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + subj := storesql.NewAgentStore(db) t.Run("should return error if the query does not have required fields", func(t *testing.T) { // given @@ -713,8 +711,8 @@ func TestDelete(t *testing.T) { t.Cleanup(func() { db.Close() }) - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + subj := storesql.NewAgentStore(db) t.Run("should return error if the query does not have required fields", func(t *testing.T) { // given @@ -911,8 +909,8 @@ func TestList(t *testing.T) { t.Cleanup(func() { db.Close() }) - subj, err := storesql.NewAgentStore(ctx, db) - require.NoError(t, err) + require.NoError(t, storesql.Migrate(ctx, db)) + subj := storesql.NewAgentStore(db) name1 := uuid.New().String() name2 := uuid.New().String() diff --git a/pkg/store/sql/key.go b/pkg/store/sql/key.go new file mode 100644 index 0000000..f2d168e --- /dev/null +++ b/pkg/store/sql/key.go @@ -0,0 +1,93 @@ +package sql + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + + "github.com/openkcm/krypton/pkg/model" + "github.com/openkcm/krypton/pkg/store" +) + +type KeyStore struct { + db *sql.DB +} + +var _ store.Key = &KeyStore{} + +func NewKeyStore(db *sql.DB) *KeyStore { + return &KeyStore{db: db} +} + +func (ks *KeyStore) CreateKey(ctx context.Context, key model.Key) error { + stmt := ` + INSERT INTO keys (id, tenant_id, kind, name, parent_id, managed_by, labels, state, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ` + + labelsJSON, err := json.Marshal(key.Labels) + if err != nil { + return err + } + + _, err = ks.db.ExecContext(ctx, stmt, + key.ID, + key.TenantID, + key.Kind, + key.Name, + key.ParentID, + key.ManagedBy, + labelsJSON, + key.State, + key.CreatedAt, + key.UpdatedAt, + ) + return err +} + +func (ks *KeyStore) GetKeyByID(ctx context.Context, id, tenantID string) (*model.Key, error) { + stmt := ` + SELECT id, tenant_id, kind, name, parent_id, managed_by, labels, state, created_at, updated_at + FROM keys + WHERE id = $1 AND tenant_id = $2 + ` + row := ks.db.QueryRowContext(ctx, stmt, id, tenantID) + + return scanKey(row) +} + +func scanKey(row interface{ Scan(...any) error }) (*model.Key, error) { + var key model.Key + var kind string + var labelsData []byte + + err := row.Scan( + &key.ID, + &key.TenantID, + &kind, + &key.Name, + &key.ParentID, + &key.ManagedBy, + &labelsData, + &key.State, + &key.CreatedAt, + &key.UpdatedAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, store.ErrKeyNotFound + } + return nil, err + } + + key.Kind = kind + + if len(labelsData) > 0 { + if err := json.Unmarshal(labelsData, &key.Labels); err != nil { + return nil, err + } + } + + return &key, nil +} diff --git a/pkg/store/sql/key_test.go b/pkg/store/sql/key_test.go new file mode 100644 index 0000000..2b1b7e2 --- /dev/null +++ b/pkg/store/sql/key_test.go @@ -0,0 +1,158 @@ +package sql_test + +import ( + "database/sql" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + _ "github.com/lib/pq" + + "github.com/openkcm/krypton/pkg/model" + "github.com/openkcm/krypton/pkg/store" + storesql "github.com/openkcm/krypton/pkg/store/sql" +) + +func TestCreateKey(t *testing.T) { + ctx := t.Context() + db, err := sql.Open("postgres", pgConnStr) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + tenantStore := storesql.NewTenantStore(db) + + require.NoError(t, storesql.Migrate(ctx, db)) + keyStore := storesql.NewKeyStore(db) + + tenant := createTenant(t, tenantStore) + + t.Run("should create key without parent", func(t *testing.T) { + key := model.NewKey(tenant.ID, "root-key", "K0", nil, "root", model.Labels{"env": "prod"}) + + err := keyStore.CreateKey(ctx, key) + require.NoError(t, err) + + got, err := keyStore.GetKeyByID(ctx, key.ID, key.TenantID) + assert.NoError(t, err) + assert.Equal(t, key.ID, got.ID) + assert.Equal(t, key.Name, got.Name) + assert.Equal(t, key.TenantID, got.TenantID) + assert.Equal(t, key.Kind, got.Kind) + assert.Nil(t, got.ParentID) + assert.Equal(t, "root", got.ManagedBy) + assert.Equal(t, model.KeyStatePending, got.State) + assert.Equal(t, "prod", got.Labels["env"]) + assert.NotZero(t, got.CreatedAt) + assert.NotZero(t, got.UpdatedAt) + }) + + t.Run("should create key with parent", func(t *testing.T) { + parent := model.NewKey(tenant.ID, "parent-key", "K0", nil, "root", nil) + require.NoError(t, keyStore.CreateKey(ctx, parent)) + + key := model.NewKey(tenant.ID, "child-key", "K1", &parent.ID, "root", model.Labels{"team": "security"}) + + err := keyStore.CreateKey(ctx, key) + require.NoError(t, err) + + got, err := keyStore.GetKeyByID(ctx, key.ID, key.TenantID) + assert.NoError(t, err) + assert.Equal(t, key.ID, got.ID) + require.NotNil(t, got.ParentID) + assert.Equal(t, parent.ID, *got.ParentID) + assert.Equal(t, "security", got.Labels["team"]) + }) + + t.Run("should create key with nil labels", func(t *testing.T) { + key := model.NewKey(tenant.ID, "no-labels-key", "K0", nil, "root", nil) + + err := keyStore.CreateKey(ctx, key) + assert.NoError(t, err) + }) + + t.Run("should fail with invalid parent reference", func(t *testing.T) { + badParent := uuid.NewString() + key := model.NewKey(tenant.ID, "orphan-key", "K1", &badParent, "root", nil) + + err := keyStore.CreateKey(ctx, key) + assert.Error(t, err) + }) + + t.Run("should fail with invalid tenant reference", func(t *testing.T) { + key := model.NewKey(uuid.NewString(), "bad-tenant-key", "K0", nil, "root", nil) + + err := keyStore.CreateKey(ctx, key) + assert.Error(t, err) + }) +} + +func TestGetKey(t *testing.T) { + ctx := t.Context() + db, err := sql.Open("postgres", pgConnStr) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + tenantStore := storesql.NewTenantStore(db) + + require.NoError(t, storesql.Migrate(ctx, db)) + keyStore := storesql.NewKeyStore(db) + + tenant := createTenant(t, tenantStore) + + t.Run("should get existing key", func(t *testing.T) { + key := model.NewKey(tenant.ID, "find-me", "K0", nil, "root", model.Labels{"env": "staging"}) + require.NoError(t, keyStore.CreateKey(ctx, key)) + + got, err := keyStore.GetKeyByID(ctx, key.ID, tenant.ID) + + assert.NoError(t, err) + assert.Equal(t, key.ID, got.ID) + assert.Equal(t, key.Name, got.Name) + assert.Equal(t, key.TenantID, got.TenantID) + assert.Equal(t, key.Kind, got.Kind) + assert.Nil(t, got.ParentID) + assert.Equal(t, "root", got.ManagedBy) + assert.Equal(t, model.KeyStatePending, got.State) + assert.Equal(t, "staging", got.Labels["env"]) + assert.Equal(t, key.CreatedAt, got.CreatedAt) + assert.Equal(t, key.UpdatedAt, got.UpdatedAt) + }) + + t.Run("should get key with parent", func(t *testing.T) { + parent := model.NewKey(tenant.ID, "parent", "K0", nil, "root", nil) + require.NoError(t, keyStore.CreateKey(ctx, parent)) + + child := model.NewKey(tenant.ID, "child", "K1", &parent.ID, "agent-aws", nil) + require.NoError(t, keyStore.CreateKey(ctx, child)) + + got, err := keyStore.GetKeyByID(ctx, child.ID, tenant.ID) + + assert.NoError(t, err) + require.NotNil(t, got.ParentID) + assert.Equal(t, parent.ID, *got.ParentID) + assert.Equal(t, "agent-aws", got.ManagedBy) + }) + + t.Run("should return not found for nonexistent key", func(t *testing.T) { + _, err := keyStore.GetKeyByID(ctx, uuid.NewString(), tenant.ID) + assert.ErrorIs(t, err, store.ErrKeyNotFound) + }) + + t.Run("should return not found for wrong tenant", func(t *testing.T) { + key := model.NewKey(tenant.ID, "wrong-tenant-key", "K0", nil, "root", nil) + require.NoError(t, keyStore.CreateKey(ctx, key)) + + _, err := keyStore.GetKeyByID(ctx, key.ID, uuid.NewString()) + assert.ErrorIs(t, err, store.ErrKeyNotFound) + }) +} + +func createTenant(t *testing.T, s *storesql.TenantStore) model.Tenant { + t.Helper() + tenant := model.NewTenant("test-tenant-"+uuid.NewString(), nil) + result, err := s.CreateTenant(t.Context(), store.CreateTenantQuery{Tenant: tenant}) + require.NoError(t, err) + return result.Tenant +} diff --git a/pkg/store/sql/migrate.go b/pkg/store/sql/migrate.go new file mode 100644 index 0000000..d1646cb --- /dev/null +++ b/pkg/store/sql/migrate.go @@ -0,0 +1,62 @@ +package sql + +import ( + "context" + "database/sql" +) + +const createTenantsTable = ` +CREATE TABLE IF NOT EXISTS tenants ( + id UUID PRIMARY KEY, + name TEXT NOT NULL, + labels JSONB, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL +); +` + +const createAgentRegistrationsTable = ` +CREATE TABLE IF NOT EXISTS agent_registrations ( + name TEXT NOT NULL, + instance_id UUID, + status TEXT NOT NULL, + last_heartbeat BIGINT NOT NULL, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + PRIMARY KEY (name, instance_id) +); +` + +const createKeysTable = ` +CREATE TABLE IF NOT EXISTS keys ( + id UUID PRIMARY KEY, + tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, + kind TEXT NOT NULL, + name TEXT NOT NULL, + parent_id UUID NULL, + managed_by TEXT NOT NULL, + labels JSONB, + state TEXT NOT NULL, + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + + UNIQUE (tenant_id, id), + FOREIGN KEY (tenant_id, parent_id) REFERENCES keys(tenant_id, id) +); +` + +func Migrate(ctx context.Context, db *sql.DB) error { + stmts := []string{ + createTenantsTable, + createAgentRegistrationsTable, + createKeysTable, + } + + for _, stmt := range stmts { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/store/sql/tenant.go b/pkg/store/sql/tenant.go index 44b8cc5..fa50f99 100644 --- a/pkg/store/sql/tenant.go +++ b/pkg/store/sql/tenant.go @@ -15,27 +15,8 @@ type TenantStore struct { var _ store.Tenant = &TenantStore{} -func NewTenantStore(ctx context.Context, db *sql.DB) (*TenantStore, error) { - ps := &TenantStore{ - db: db, - } - - stmt := ` - CREATE TABLE IF NOT EXISTS tenants ( - id UUID PRIMARY KEY, - name TEXT NOT NULL, - labels JSONB, - created_at BIGINT NOT NULL, - updated_at BIGINT NOT NULL - ); - ` - - _, err := ps.db.ExecContext(ctx, stmt) - if err != nil { - return nil, err - } - - return ps, nil +func NewTenantStore(db *sql.DB) *TenantStore { + return &TenantStore{db: db} } func (ps *TenantStore) CreateTenant(ctx context.Context, query store.CreateTenantQuery) (store.CreateTenantResult, error) {