Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/registry/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func main() {
initOTLP(ctx, cfg)

// Status server initialization
go startStatusServer(cfg, ctx)
go startStatusServer(ctx, cfg)

db := initDB(ctx, cfg)

Expand Down Expand Up @@ -185,7 +185,7 @@ func loadConfig() *config.Config {
return cfg
}

func startStatusServer(cfg *config.Config, ctx context.Context) {
func startStatusServer(ctx context.Context, cfg *config.Config) {
liveness := status.WithLiveness(
health.NewHandler(
health.NewChecker(health.WithDisabledAutostart()),
Expand Down
2 changes: 1 addition & 1 deletion integration/operatortest/operatortest.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func New(ctx context.Context) (*orbital.Operator, error) {
}

client, err := amqp.NewClient(ctx, codec.Proto{}, amqp.ConnectionInfo{
URL: target.Connection.AMQP.Url,
URL: target.Connection.AMQP.URL,
Target: target.Connection.AMQP.Source,
Source: target.Connection.AMQP.Target,
}, option)
Expand Down
10 changes: 5 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ func (o *Orbital) Validate() error {
}

func (o *Orbital) GetWorker(workerName string) *Worker {
for _, worker := range o.Workers {
if worker.Name == workerName {
return &worker
for i := range o.Workers {
if o.Workers[i].Name == workerName {
return &o.Workers[i]
}
}

Expand Down Expand Up @@ -237,13 +237,13 @@ func (c *Connection) validate() error {
}

type AMQP struct {
Url string `yaml:"url" json:"url"`
URL string `yaml:"url" json:"url"`
Source string `yaml:"source" json:"source"`
Target string `yaml:"target" json:"target"`
}

func (a *AMQP) validate() error {
if a.Url == "" {
if a.URL == "" {
return ErrEmptyURL
}

Expand Down
6 changes: 3 additions & 3 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestValidateTarget(t *testing.T) {
Connection: &config.Connection{
Type: config.ConnectionTypeAMQP,
AMQP: &config.AMQP{
Url: "amqp://localhost:5672",
URL: "amqp://localhost:5672",
Source: "source",
Target: "target",
},
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestValidateTarget(t *testing.T) {
name: "missing AMQP URL",
patchTarget: func(t config.Target) config.Target {
t = deepCopyTarget(t)
t.Connection.AMQP.Url = ""
t.Connection.AMQP.URL = ""
return t
},
expErr: config.ErrEmptyURL,
Expand Down Expand Up @@ -376,7 +376,7 @@ func deepCopyTarget(t config.Target) config.Target {
Connection: &config.Connection{
Type: t.Connection.Type,
AMQP: &config.AMQP{
Url: t.Connection.AMQP.Url,
URL: t.Connection.AMQP.URL,
Source: t.Connection.AMQP.Source,
Target: t.Connection.AMQP.Target,
},
Expand Down
4 changes: 2 additions & 2 deletions internal/interceptor/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ func (m *Meters) StreamInterceptor(srv any, stream grpc.ServerStream, info *grpc
attribute.String("status", statusCode),
)...,
)
m.requestDurations.Record(context.Background(), elapsedTime, attrs)
m.requestCounts.Add(context.Background(), 1, attrs)
m.requestDurations.Record(stream.Context(), elapsedTime, attrs)
m.requestCounts.Add(stream.Context(), 1, attrs)

return err
}
15 changes: 14 additions & 1 deletion internal/interceptor/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,26 @@ import (
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

sdkmetric "go.opentelemetry.io/otel/sdk/metric"

"github.com/openkcm/registry/internal/interceptor"
)

// mockServerStream is a minimal grpc.ServerStream for testing.
type mockServerStream struct {
ctxFunc func() context.Context
}

func (m *mockServerStream) Context() context.Context { return m.ctxFunc() }
func (m *mockServerStream) SetHeader(metadata.MD) error { return nil }
func (m *mockServerStream) SendHeader(metadata.MD) error { return nil }
func (m *mockServerStream) SetTrailer(metadata.MD) {}
func (m *mockServerStream) SendMsg(any) error { return nil }
func (m *mockServerStream) RecvMsg(any) error { return nil }

func TestMetricsUnaryInterceptor(t *testing.T) {
ctx := t.Context()
app := &commoncfg.Application{}
Expand Down Expand Up @@ -103,7 +116,7 @@ func TestMetricsStreamInterceptor(t *testing.T) {

err = met.StreamInterceptor(
nil,
nil,
&mockServerStream{ctxFunc: t.Context},
&grpc.StreamServerInfo{FullMethod: "/test.method"},
handler,
)
Expand Down
28 changes: 8 additions & 20 deletions internal/interceptor/recover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,11 @@ func TestServerPanic(t *testing.T) {
)
// registering server
servicetest.RegisterTestServiceServer(srv, serviceTest)
t.Cleanup(srv.Stop)

go func(t *testing.T, srv *grpc.Server, ls *bufconn.Listener) {
t.Helper()

defer srv.Stop()

err := srv.Serve(ls)
if err != nil {
assert.NoError(t, err, "server could not be started")
}
}(t, srv, ls)
go func() {
_ = srv.Serve(ls)
}()

// creating client connection
conn, err := grpc.NewClient("passthrough://bufnet",
Expand Down Expand Up @@ -126,17 +120,11 @@ func TestServerPanic(t *testing.T) {

// registering server
servicetest.RegisterTestServiceServer(srv, serviceTest)
t.Cleanup(srv.Stop)

go func(t *testing.T, srv *grpc.Server, ls *bufconn.Listener) {
t.Helper()

defer srv.Stop()

err := srv.Serve(ls)
if err != nil {
assert.NoError(t, err, "server could not be started")
}
}(t, srv, ls)
go func() {
_ = srv.Serve(ls)
}()

// creating client connection
conn, err := grpc.NewClient("passthrough://bufnet",
Expand Down
18 changes: 9 additions & 9 deletions internal/model/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,22 @@ type AuthStatusConstraint struct{}

var validAuthStatuses map[string]struct{}

func init() {
validAuthStatuses = make(map[string]struct{}, len(pb.AuthStatus_name)-1)
for _, v := range pb.AuthStatus_name {
if v != pb.AuthStatus_AUTH_STATUS_UNSPECIFIED.String() {
validAuthStatuses[v] = struct{}{}
}
}
}

// Validate checks if the provided value is a valid Auth status.
// Auth status must be one of the defined enum values in pb.AuthStatus.
func (c AuthStatusConstraint) Validate(value any) error {
statusValue, ok := value.(string)
if !ok {
return fmt.Errorf("%w: %T", validation.ErrWrongType, value)
}
// lazy initialization of validAuthStatuses
if validAuthStatuses == nil {
validAuthStatuses = make(map[string]struct{}, len(pb.AuthStatus_name)-1)
for _, v := range pb.AuthStatus_name {
if v != pb.AuthStatus_AUTH_STATUS_UNSPECIFIED.String() {
validAuthStatuses[v] = struct{}{}
}
}
}

if _, ok := validAuthStatuses[statusValue]; !ok {
return validation.ErrValueNotAllowed
Expand Down
19 changes: 9 additions & 10 deletions internal/model/regional_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,23 +129,22 @@ type RegionalSystemStatusConstraint struct{}

var validSystemStatuses map[string]struct{}

func init() {
validSystemStatuses = make(map[string]struct{}, len(typespb.Status_name)-1)
for _, v := range typespb.Status_name {
if v != typespb.Status_STATUS_UNSPECIFIED.String() {
validSystemStatuses[v] = struct{}{}
}
}
}

// Validate checks if the provided system status is valid.
func (c RegionalSystemStatusConstraint) Validate(value any) error {
status, ok := value.(string)
if !ok {
return validation.ErrWrongType
}

// lazy initialization of valid system statuses
if validSystemStatuses == nil {
validSystemStatuses = make(map[string]struct{})
for _, v := range typespb.Status_name {
if v != typespb.Status_STATUS_UNSPECIFIED.String() {
validSystemStatuses[v] = struct{}{}
}
}
}

if _, exists := validSystemStatuses[status]; !exists {
return validation.ErrValueNotAllowed
}
Expand Down
17 changes: 9 additions & 8 deletions internal/model/tenant.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,22 @@ type TenantRoleConstraint struct{}

var validTenantRoles map[string]struct{}

func init() {
validTenantRoles = make(map[string]struct{}, len(tenantgrpc.Role_name)-1)
for _, v := range tenantgrpc.Role_name {
if v != tenantgrpc.Role_ROLE_UNSPECIFIED.String() {
validTenantRoles[v] = struct{}{}
}
}
}

// Validate checks if the provided value is a valid Tenant role.
// Tenant role must be one of the defined enum values in tenant proto Role.
func (t TenantRoleConstraint) Validate(value any) error {
roleValue, ok := value.(string)
if !ok {
return fmt.Errorf("%w: %T", validation.ErrWrongType, value)
}
if validTenantRoles == nil {
validTenantRoles = make(map[string]struct{}, len(tenantgrpc.Role_name)-1)
for _, v := range tenantgrpc.Role_name {
if v != tenantgrpc.Role_ROLE_UNSPECIFIED.String() {
validTenantRoles[v] = struct{}{}
}
}
}
if _, ok := validTenantRoles[roleValue]; !ok {
return validation.ErrValueNotAllowed
}
Expand Down
129 changes: 129 additions & 0 deletions internal/repository/sql/resource_repository_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package sql_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"

"github.com/openkcm/registry/internal/repository"
sqlrepo "github.com/openkcm/registry/internal/repository/sql"
)

// noopDialector is a minimal gorm.Dialector for unit testing without a real database.
type noopDialector struct{}

func (noopDialector) Name() string { return "noop" }
func (d noopDialector) Initialize(db *gorm.DB) error {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
return nil
}
func (noopDialector) Migrator(*gorm.DB) gorm.Migrator { return nil }
func (noopDialector) DataTypeOf(*schema.Field) string { return "text" }
func (noopDialector) DefaultValueOf(*schema.Field) clause.Expression { return clause.Expr{SQL: "NULL"} }
func (noopDialector) BindVarTo(w clause.Writer, _ *gorm.Statement, _ any) { _ = w.WriteByte('?') }
func (noopDialector) QuoteTo(w clause.Writer, s string) { _, _ = w.WriteString(s) }
func (noopDialector) Explain(s string, _ ...any) string { return s }

type testRecord struct{ ID string }

func (testRecord) TableName() string { return "records" }

func newTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(noopDialector{}, &gorm.Config{})
require.NoError(t, err)
return db
}

func TestHandleQueryField(t *testing.T) {
t.Run("slice generates IN clause", func(t *testing.T) {
// given
db := newTestDB(t)

// when
result := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := sqlrepo.HandleQueryField(tx, "status", []string{"active", "pending"})
require.NoError(t, err)
return tx.Find(&[]testRecord{})
})

// then
assert.Contains(t, result, "status IN")
})

t.Run("scalar generates equality clause", func(t *testing.T) {
// given
db := newTestDB(t)

// when
result := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := sqlrepo.HandleQueryField(tx, "id", "abc-123")
require.NoError(t, err)
return tx.Find(&[]testRecord{})
})

// then
assert.Contains(t, result, "id = ")
})

t.Run("NotEmpty generates IS NOT NULL clause", func(t *testing.T) {
// given
db := newTestDB(t)

// when
result := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := sqlrepo.HandleQueryField(tx, "name", repository.NotEmpty)
require.NoError(t, err)
return tx.Find(&[]testRecord{})
})

// then
assert.Contains(t, result, "name IS NOT NULL")
})

t.Run("Empty generates IS NULL clause", func(t *testing.T) {
// given
db := newTestDB(t)

// when
result := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := sqlrepo.HandleQueryField(tx, "name", repository.Empty)
require.NoError(t, err)
return tx.Find(&[]testRecord{})
})

// then
assert.Contains(t, result, "name IS NULL")
})

t.Run("map generates JSONB operator clause", func(t *testing.T) {
// given
db := newTestDB(t)

// when
result := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := sqlrepo.HandleQueryField(tx, "labels", map[string]any{"env": "prod"})
require.NoError(t, err)
return tx.Find(&[]testRecord{})
})

// then
assert.Contains(t, result, "labels ->>")
})

t.Run("invalid map type returns error", func(t *testing.T) {
// given
db := newTestDB(t)

// when
_, err := sqlrepo.HandleQueryField(db, "labels", map[string]string{"key": "val"})

// then
assert.ErrorIs(t, err, sqlrepo.ErrUnknownTypeForJSONBField)
})
}
Loading
Loading