Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client/amqp/amqp.go
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this directory evolved quite a bit, we could think of renaming it from client to transport or something else (probable better in a separate PR if at all)

Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ type (
)

var (
_ orbital.Initiator = &Client{}
_ orbital.Responder = &Client{}
_ orbital.Initiator = &Client{}
_ orbital.AsyncResponder = &Client{}
)

// WithBasicAuth tells the client to use SASL PLAIN with user/password.
Expand Down
2 changes: 1 addition & 1 deletion client/grpc/grpc.go → client/rpc/rpc.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package rpc

import (
"context"
Expand Down
44 changes: 22 additions & 22 deletions client/grpc/grpc_test.go → client/rpc/rpc_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc_test
package rpc_test

import (
"context"
Expand All @@ -17,7 +17,7 @@ import (
"google.golang.org/grpc/test/bufconn"

"github.com/openkcm/orbital"
grpcclient "github.com/openkcm/orbital/client/grpc"
"github.com/openkcm/orbital/client/rpc"
"github.com/openkcm/orbital/codec"
orbitalv1 "github.com/openkcm/orbital/proto/orbital/v1"
)
Expand Down Expand Up @@ -71,42 +71,42 @@ func TestNewClient(t *testing.T) {
tests := []struct {
name string
conn *grpc.ClientConn
opts []grpcclient.ClientOption
opts []rpc.ClientOption
wantErr error
}{
{
name: "NilConn",
conn: nil,
wantErr: grpcclient.ErrNilConn,
wantErr: rpc.ErrNilConn,
},
{
name: "InvalidBufferSize",
conn: conn,
opts: []grpcclient.ClientOption{grpcclient.WithBufferSize(-1)},
wantErr: grpcclient.ErrInvalidBufferSize,
opts: []rpc.ClientOption{rpc.WithBufferSize(-1)},
wantErr: rpc.ErrInvalidBufferSize,
},
{
name: "InvalidCallTimeout/zero",
conn: conn,
opts: []grpcclient.ClientOption{grpcclient.WithCallTimeout(0)},
wantErr: grpcclient.ErrInvalidCallTimeout,
opts: []rpc.ClientOption{rpc.WithCallTimeout(0)},
wantErr: rpc.ErrInvalidCallTimeout,
},
{
name: "InvalidCallTimeout/negative",
conn: conn,
opts: []grpcclient.ClientOption{grpcclient.WithCallTimeout(-1 * time.Second)},
wantErr: grpcclient.ErrInvalidCallTimeout,
opts: []rpc.ClientOption{rpc.WithCallTimeout(-1 * time.Second)},
wantErr: rpc.ErrInvalidCallTimeout,
},
{
name: "ValidWithZeroBuffer",
conn: conn,
opts: []grpcclient.ClientOption{grpcclient.WithBufferSize(0)},
opts: []rpc.ClientOption{rpc.WithBufferSize(0)},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := grpcclient.NewClient(tt.conn, tt.opts...)
c, err := rpc.NewClient(tt.conn, tt.opts...)
if tt.wantErr != nil {
assert.Nil(t, c)
assert.ErrorIs(t, err, tt.wantErr)
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestSendTaskRequest(t *testing.T) {
}, nil
})

client, err := grpcclient.NewClient(conn, grpcclient.WithCallTimeout(5*time.Second))
client, err := rpc.NewClient(conn, rpc.WithCallTimeout(5*time.Second))
require.NoError(t, err)
defer client.Close(t.Context())

Expand Down Expand Up @@ -174,7 +174,7 @@ func TestSendTaskRequest(t *testing.T) {
return nil, status.Error(codes.Internal, "something broke")
})

client, err := grpcclient.NewClient(conn, grpcclient.WithCallTimeout(5*time.Second))
client, err := rpc.NewClient(conn, rpc.WithCallTimeout(5*time.Second))
require.NoError(t, err)
defer client.Close(t.Context())

Expand All @@ -200,7 +200,7 @@ func TestSendTaskRequest(t *testing.T) {
}, nil
})

client, err := grpcclient.NewClient(conn, grpcclient.WithCallTimeout(5*time.Second))
client, err := rpc.NewClient(conn, rpc.WithCallTimeout(5*time.Second))
require.NoError(t, err)
defer client.Close(t.Context())

Expand All @@ -220,7 +220,7 @@ func TestSendTaskRequest(t *testing.T) {
t.Run("ContextCancelled", func(t *testing.T) {
conn := startServer(t, noopHandler)

client, err := grpcclient.NewClient(conn)
client, err := rpc.NewClient(conn)
require.NoError(t, err)
defer client.Close(t.Context())

Expand All @@ -241,7 +241,7 @@ func TestSendTaskRequest(t *testing.T) {
}, nil
})

client, err := grpcclient.NewClient(conn, grpcclient.WithCallTimeout(5*time.Second))
client, err := rpc.NewClient(conn, rpc.WithCallTimeout(5*time.Second))
require.NoError(t, err)
defer client.Close(t.Context())

Expand Down Expand Up @@ -274,7 +274,7 @@ func TestReceiveTaskResponse(t *testing.T) {
t.Run("ContextCancelled", func(t *testing.T) {
conn := startServer(t, noopHandler)

client, err := grpcclient.NewClient(conn)
client, err := rpc.NewClient(conn)
require.NoError(t, err)
defer client.Close(t.Context())

Expand All @@ -291,24 +291,24 @@ func TestClose(t *testing.T) {
t.Run("ThenSendAndReceive", func(t *testing.T) {
conn := startServer(t, noopHandler)

client, err := grpcclient.NewClient(conn)
client, err := rpc.NewClient(conn)
require.NoError(t, err)

err = client.Close(t.Context())
require.NoError(t, err)

err = client.SendTaskRequest(t.Context(), orbital.TaskRequest{})
assert.ErrorIs(t, err, grpcclient.ErrClientClosed)
assert.ErrorIs(t, err, rpc.ErrClientClosed)

resp, err := client.ReceiveTaskResponse(t.Context())
assert.ErrorIs(t, err, grpcclient.ErrClientClosed)
assert.ErrorIs(t, err, rpc.ErrClientClosed)
assert.Equal(t, orbital.TaskResponse{}, resp)
})

t.Run("Idempotent", func(t *testing.T) {
conn := startServer(t, noopHandler)

client, err := grpcclient.NewClient(conn)
client, err := rpc.NewClient(conn)
require.NoError(t, err)

assert.NotPanics(t, func() {
Expand Down
155 changes: 155 additions & 0 deletions client/rpc/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package rpc

import (
"context"
"errors"
"net"
"sync"
"sync/atomic"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

slogctx "github.com/veqryn/slog-context"

"github.com/openkcm/orbital"
"github.com/openkcm/orbital/codec"
orbitalv1 "github.com/openkcm/orbital/proto/orbital/v1"
)

var _ orbital.SyncResponder = (*Server)(nil)

var (
// ErrNilListener is returned by NewServer when given a nil net.Listener.
ErrNilListener = errors.New("grpc server: listener cannot be nil")
// ErrServerAlreadyRan is returned by Run if the server has already been started.
ErrServerAlreadyRan = errors.New("grpc server: Run already called")
)

type (
// ServerOption configures a Server.
ServerOption func(*serverConfig) error

serverConfig struct {
grpcServerOpts []grpc.ServerOption
}

// Server implements the orbital TaskService gRPC server and the
// SyncResponder interface. It receives task requests over gRPC,
// processes them via the TaskRequestHandler provided to Run,
// and returns responses synchronously.
Server struct {
orbitalv1.UnimplementedTaskServiceServer

config serverConfig
lis net.Listener
grpcServer *grpc.Server
handler orbital.TaskRequestHandler
ran atomic.Bool
stopOnce sync.Once
}
)

// NewServer creates a gRPC Server bound to the given listener.
func NewServer(lis net.Listener, opts ...ServerOption) (*Server, error) {
if lis == nil {
return nil, ErrNilListener
}

cfg := serverConfig{}
for _, opt := range opts {
if err := opt(&cfg); err != nil {
return nil, err
}
}

return &Server{
config: cfg,
lis: lis,
}, nil
}

// WithServerOptions appends gRPC server options (interceptors, TLS
// credentials, etc.) used when the underlying grpc.Server is created.
func WithServerOptions(opts ...grpc.ServerOption) ServerOption {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grpc.WithServerOptions(opts ...GRPC.ServerOption) grpc.ServerOption

I think the pkg naming isn't optimal because we now have 2 grpc pkgs using the same name. So using this function requires one to make aliased imports.

What if we call this pkg rpc? Its a similar case like the AMQP and even if it weren't, I think its still cleaner.

return func(cfg *serverConfig) error {
cfg.grpcServerOpts = append(cfg.grpcServerOpts, opts...)
return nil
}
}

// Run registers the TaskService, starts serving, and blocks until
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Run can be called multiple times but stop only once, so the second server cannot be gracefully stopped. This might not happen when used but then again there is Murphy's Law

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed it so that Run and stop can only be called once on a grpc client.

// Close is called or ctx is cancelled. Run may only be called once.
func (s *Server) Run(ctx context.Context, handler orbital.TaskRequestHandler) error {
if !s.ran.CompareAndSwap(false, true) {
return ErrServerAlreadyRan
}

s.handler = handler
s.grpcServer = grpc.NewServer(s.config.grpcServerOpts...)
orbitalv1.RegisterTaskServiceServer(s.grpcServer, s)

go func() {
<-ctx.Done()
s.stop()
}()

slogctx.Info(ctx, "grpc server starting", "address", s.lis.Addr().String())
return s.grpcServer.Serve(s.lis)
}

// Close gracefully stops the gRPC server. If ctx expires before graceful
// shutdown completes, it force-stops.
func (s *Server) Close(ctx context.Context) error {
if s.grpcServer == nil {
return nil
}

stopped := make(chan struct{})
go func() {
s.stop()
close(stopped)
}()

select {
case <-stopped:
case <-ctx.Done():
s.grpcServer.Stop()
}
return nil
}

// SendTaskRequest implements orbitalv1.TaskServiceServer.
func (s *Server) SendTaskRequest(ctx context.Context, pReq *orbitalv1.TaskRequest) (*orbitalv1.TaskResponse, error) {
req, err := codec.FromProtoToTaskRequest(pReq)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid task request: %v", err)
}

resp, err := s.handler(ctx, req)
if err != nil {
return nil, mapProcessError(err)
}

return codec.FromTaskResponseToProto(resp), nil
}

func mapProcessError(err error) error {
switch {
case errors.Is(err, orbital.ErrSignatureInvalid):
return status.Error(codes.Unauthenticated, err.Error())
case errors.Is(err, orbital.ErrResponseSigning):
return status.Error(codes.Internal, err.Error())
default:
return status.Error(codes.Internal, err.Error())
}
}

func (s *Server) stop() {
s.stopOnce.Do(func() {
if s.grpcServer != nil {
s.grpcServer.GracefulStop()
}
})
}
Loading
Loading