diff --git a/cmd/registry/main.go b/cmd/registry/main.go index 4dda03d..212c502 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -50,7 +50,7 @@ func main() { initOTLP(ctx, cfg) // Status server initialization - go startStatusServer(cfg, ctx) + go startStatusServer(ctx, cfg) db := initDB(ctx, cfg) @@ -185,7 +185,7 @@ func loadConfig() *config.Config { return cfg } -func startStatusServer(cfg *config.Config, ctx context.Context) { +func startStatusServer(ctx context.Context, cfg *config.Config) { liveness := status.WithLiveness( health.NewHandler( health.NewChecker(health.WithDisabledAutostart()), diff --git a/integration/operatortest/operatortest.go b/integration/operatortest/operatortest.go index d338c64..617b391 100644 --- a/integration/operatortest/operatortest.go +++ b/integration/operatortest/operatortest.go @@ -69,7 +69,7 @@ func New(ctx context.Context) (*orbital.Operator, error) { } client, err := amqp.NewClient(ctx, codec.Proto{}, amqp.ConnectionInfo{ - URL: target.Connection.AMQP.Url, + URL: target.Connection.AMQP.URL, Target: target.Connection.AMQP.Source, Source: target.Connection.AMQP.Target, }, option) diff --git a/internal/config/config.go b/internal/config/config.go index 23dab68..5d494af 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -151,9 +151,9 @@ func (o *Orbital) Validate() error { } func (o *Orbital) GetWorker(workerName string) *Worker { - for _, worker := range o.Workers { - if worker.Name == workerName { - return &worker + for i := range o.Workers { + if o.Workers[i].Name == workerName { + return &o.Workers[i] } } @@ -237,13 +237,13 @@ func (c *Connection) validate() error { } type AMQP struct { - Url string `yaml:"url" json:"url"` + URL string `yaml:"url" json:"url"` Source string `yaml:"source" json:"source"` Target string `yaml:"target" json:"target"` } func (a *AMQP) validate() error { - if a.Url == "" { + if a.URL == "" { return ErrEmptyURL } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7fbb358..4b7673b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -15,7 +15,7 @@ func TestValidateTarget(t *testing.T) { Connection: &config.Connection{ Type: config.ConnectionTypeAMQP, AMQP: &config.AMQP{ - Url: "amqp://localhost:5672", + URL: "amqp://localhost:5672", Source: "source", Target: "target", }, @@ -90,7 +90,7 @@ func TestValidateTarget(t *testing.T) { name: "missing AMQP URL", patchTarget: func(t config.Target) config.Target { t = deepCopyTarget(t) - t.Connection.AMQP.Url = "" + t.Connection.AMQP.URL = "" return t }, expErr: config.ErrEmptyURL, @@ -376,7 +376,7 @@ func deepCopyTarget(t config.Target) config.Target { Connection: &config.Connection{ Type: t.Connection.Type, AMQP: &config.AMQP{ - Url: t.Connection.AMQP.Url, + URL: t.Connection.AMQP.URL, Source: t.Connection.AMQP.Source, Target: t.Connection.AMQP.Target, }, diff --git a/internal/interceptor/metrics.go b/internal/interceptor/metrics.go index d43f39a..be4733a 100644 --- a/internal/interceptor/metrics.go +++ b/internal/interceptor/metrics.go @@ -89,8 +89,8 @@ func (m *Meters) StreamInterceptor(srv any, stream grpc.ServerStream, info *grpc attribute.String("status", statusCode), )..., ) - m.requestDurations.Record(context.Background(), elapsedTime, attrs) - m.requestCounts.Add(context.Background(), 1, attrs) + m.requestDurations.Record(stream.Context(), elapsedTime, attrs) + m.requestCounts.Add(stream.Context(), 1, attrs) return err } diff --git a/internal/interceptor/metrics_test.go b/internal/interceptor/metrics_test.go index 0f6b5a2..dd98209 100644 --- a/internal/interceptor/metrics_test.go +++ b/internal/interceptor/metrics_test.go @@ -11,6 +11,7 @@ import ( "go.opentelemetry.io/otel/sdk/metric/metricdata" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" sdkmetric "go.opentelemetry.io/otel/sdk/metric" @@ -18,6 +19,18 @@ import ( "github.com/openkcm/registry/internal/interceptor" ) +// mockServerStream is a minimal grpc.ServerStream for testing. +type mockServerStream struct { + ctxFunc func() context.Context +} + +func (m *mockServerStream) Context() context.Context { return m.ctxFunc() } +func (m *mockServerStream) SetHeader(metadata.MD) error { return nil } +func (m *mockServerStream) SendHeader(metadata.MD) error { return nil } +func (m *mockServerStream) SetTrailer(metadata.MD) {} +func (m *mockServerStream) SendMsg(any) error { return nil } +func (m *mockServerStream) RecvMsg(any) error { return nil } + func TestMetricsUnaryInterceptor(t *testing.T) { ctx := t.Context() app := &commoncfg.Application{} @@ -103,7 +116,7 @@ func TestMetricsStreamInterceptor(t *testing.T) { err = met.StreamInterceptor( nil, - nil, + &mockServerStream{ctxFunc: t.Context}, &grpc.StreamServerInfo{FullMethod: "/test.method"}, handler, ) diff --git a/internal/interceptor/recover_test.go b/internal/interceptor/recover_test.go index d0838cb..daba8a3 100644 --- a/internal/interceptor/recover_test.go +++ b/internal/interceptor/recover_test.go @@ -54,17 +54,11 @@ func TestServerPanic(t *testing.T) { ) // registering server servicetest.RegisterTestServiceServer(srv, serviceTest) + t.Cleanup(srv.Stop) - go func(t *testing.T, srv *grpc.Server, ls *bufconn.Listener) { - t.Helper() - - defer srv.Stop() - - err := srv.Serve(ls) - if err != nil { - assert.NoError(t, err, "server could not be started") - } - }(t, srv, ls) + go func() { + _ = srv.Serve(ls) + }() // creating client connection conn, err := grpc.NewClient("passthrough://bufnet", @@ -126,17 +120,11 @@ func TestServerPanic(t *testing.T) { // registering server servicetest.RegisterTestServiceServer(srv, serviceTest) + t.Cleanup(srv.Stop) - go func(t *testing.T, srv *grpc.Server, ls *bufconn.Listener) { - t.Helper() - - defer srv.Stop() - - err := srv.Serve(ls) - if err != nil { - assert.NoError(t, err, "server could not be started") - } - }(t, srv, ls) + go func() { + _ = srv.Serve(ls) + }() // creating client connection conn, err := grpc.NewClient("passthrough://bufnet", diff --git a/internal/model/auth.go b/internal/model/auth.go index 60316a3..72103ac 100644 --- a/internal/model/auth.go +++ b/internal/model/auth.go @@ -96,6 +96,15 @@ type AuthStatusConstraint struct{} var validAuthStatuses map[string]struct{} +func init() { + validAuthStatuses = make(map[string]struct{}, len(pb.AuthStatus_name)-1) + for _, v := range pb.AuthStatus_name { + if v != pb.AuthStatus_AUTH_STATUS_UNSPECIFIED.String() { + validAuthStatuses[v] = struct{}{} + } + } +} + // Validate checks if the provided value is a valid Auth status. // Auth status must be one of the defined enum values in pb.AuthStatus. func (c AuthStatusConstraint) Validate(value any) error { @@ -103,15 +112,6 @@ func (c AuthStatusConstraint) Validate(value any) error { if !ok { return fmt.Errorf("%w: %T", validation.ErrWrongType, value) } - // lazy initialization of validAuthStatuses - if validAuthStatuses == nil { - validAuthStatuses = make(map[string]struct{}, len(pb.AuthStatus_name)-1) - for _, v := range pb.AuthStatus_name { - if v != pb.AuthStatus_AUTH_STATUS_UNSPECIFIED.String() { - validAuthStatuses[v] = struct{}{} - } - } - } if _, ok := validAuthStatuses[statusValue]; !ok { return validation.ErrValueNotAllowed diff --git a/internal/model/regional_system.go b/internal/model/regional_system.go index 0d3a58e..ded3671 100644 --- a/internal/model/regional_system.go +++ b/internal/model/regional_system.go @@ -129,6 +129,15 @@ type RegionalSystemStatusConstraint struct{} var validSystemStatuses map[string]struct{} +func init() { + validSystemStatuses = make(map[string]struct{}, len(typespb.Status_name)-1) + for _, v := range typespb.Status_name { + if v != typespb.Status_STATUS_UNSPECIFIED.String() { + validSystemStatuses[v] = struct{}{} + } + } +} + // Validate checks if the provided system status is valid. func (c RegionalSystemStatusConstraint) Validate(value any) error { status, ok := value.(string) @@ -136,16 +145,6 @@ func (c RegionalSystemStatusConstraint) Validate(value any) error { return validation.ErrWrongType } - // lazy initialization of valid system statuses - if validSystemStatuses == nil { - validSystemStatuses = make(map[string]struct{}) - for _, v := range typespb.Status_name { - if v != typespb.Status_STATUS_UNSPECIFIED.String() { - validSystemStatuses[v] = struct{}{} - } - } - } - if _, exists := validSystemStatuses[status]; !exists { return validation.ErrValueNotAllowed } diff --git a/internal/model/tenant.go b/internal/model/tenant.go index 4444554..fc5d5fa 100644 --- a/internal/model/tenant.go +++ b/internal/model/tenant.go @@ -95,6 +95,15 @@ type TenantRoleConstraint struct{} var validTenantRoles map[string]struct{} +func init() { + validTenantRoles = make(map[string]struct{}, len(tenantgrpc.Role_name)-1) + for _, v := range tenantgrpc.Role_name { + if v != tenantgrpc.Role_ROLE_UNSPECIFIED.String() { + validTenantRoles[v] = struct{}{} + } + } +} + // Validate checks if the provided value is a valid Tenant role. // Tenant role must be one of the defined enum values in tenant proto Role. func (t TenantRoleConstraint) Validate(value any) error { @@ -102,14 +111,6 @@ func (t TenantRoleConstraint) Validate(value any) error { if !ok { return fmt.Errorf("%w: %T", validation.ErrWrongType, value) } - if validTenantRoles == nil { - validTenantRoles = make(map[string]struct{}, len(tenantgrpc.Role_name)-1) - for _, v := range tenantgrpc.Role_name { - if v != tenantgrpc.Role_ROLE_UNSPECIFIED.String() { - validTenantRoles[v] = struct{}{} - } - } - } if _, ok := validTenantRoles[roleValue]; !ok { return validation.ErrValueNotAllowed } diff --git a/internal/repository/sql/resource_repository_test.go b/internal/repository/sql/resource_repository_test.go new file mode 100644 index 0000000..7400433 --- /dev/null +++ b/internal/repository/sql/resource_repository_test.go @@ -0,0 +1,129 @@ +package sql_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "github.com/openkcm/registry/internal/repository" + sqlrepo "github.com/openkcm/registry/internal/repository/sql" +) + +// noopDialector is a minimal gorm.Dialector for unit testing without a real database. +type noopDialector struct{} + +func (noopDialector) Name() string { return "noop" } +func (d noopDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{}) + return nil +} +func (noopDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } +func (noopDialector) DataTypeOf(*schema.Field) string { return "text" } +func (noopDialector) DefaultValueOf(*schema.Field) clause.Expression { return clause.Expr{SQL: "NULL"} } +func (noopDialector) BindVarTo(w clause.Writer, _ *gorm.Statement, _ any) { _ = w.WriteByte('?') } +func (noopDialector) QuoteTo(w clause.Writer, s string) { _, _ = w.WriteString(s) } +func (noopDialector) Explain(s string, _ ...any) string { return s } + +type testRecord struct{ ID string } + +func (testRecord) TableName() string { return "records" } + +func newTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(noopDialector{}, &gorm.Config{}) + require.NoError(t, err) + return db +} + +func TestHandleQueryField(t *testing.T) { + t.Run("slice generates IN clause", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + result := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := sqlrepo.HandleQueryField(tx, "status", []string{"active", "pending"}) + require.NoError(t, err) + return tx.Find(&[]testRecord{}) + }) + + // then + assert.Contains(t, result, "status IN") + }) + + t.Run("scalar generates equality clause", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + result := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := sqlrepo.HandleQueryField(tx, "id", "abc-123") + require.NoError(t, err) + return tx.Find(&[]testRecord{}) + }) + + // then + assert.Contains(t, result, "id = ") + }) + + t.Run("NotEmpty generates IS NOT NULL clause", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + result := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := sqlrepo.HandleQueryField(tx, "name", repository.NotEmpty) + require.NoError(t, err) + return tx.Find(&[]testRecord{}) + }) + + // then + assert.Contains(t, result, "name IS NOT NULL") + }) + + t.Run("Empty generates IS NULL clause", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + result := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := sqlrepo.HandleQueryField(tx, "name", repository.Empty) + require.NoError(t, err) + return tx.Find(&[]testRecord{}) + }) + + // then + assert.Contains(t, result, "name IS NULL") + }) + + t.Run("map generates JSONB operator clause", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + result := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + tx, err := sqlrepo.HandleQueryField(tx, "labels", map[string]any{"env": "prod"}) + require.NoError(t, err) + return tx.Find(&[]testRecord{}) + }) + + // then + assert.Contains(t, result, "labels ->>") + }) + + t.Run("invalid map type returns error", func(t *testing.T) { + // given + db := newTestDB(t) + + // when + _, err := sqlrepo.HandleQueryField(db, "labels", map[string]string{"key": "val"}) + + // then + assert.ErrorIs(t, err, sqlrepo.ErrUnknownTypeForJSONBField) + }) +} diff --git a/internal/service/orbital.go b/internal/service/orbital.go index 56ed355..47eff35 100644 --- a/internal/service/orbital.go +++ b/internal/service/orbital.go @@ -149,7 +149,7 @@ func createAMQPClient(ctx context.Context, cfgTarget config.Target) (*amqp.Clien } connInfo := amqp.ConnectionInfo{ - URL: cfgTarget.Connection.AMQP.Url, + URL: cfgTarget.Connection.AMQP.URL, Target: cfgTarget.Connection.AMQP.Target, Source: cfgTarget.Connection.AMQP.Source, }