From 37d10585e701af1144075f91c85362953ee42397 Mon Sep 17 00:00:00 2001 From: Alex Angelini Date: Fri, 12 Jul 2024 19:58:10 +0200 Subject: [PATCH] Cached UDS only --- Makefile | 21 ++- cmd/cached-client/main.go | 7 + cmd/cached/main.go | 4 +- internal/key/key.go | 1 + pkg/cached/cached.go | 40 +---- pkg/cachedcli/client.go | 145 ++++++++++++++++++ .../getcached.go => cachedcli/populate.go} | 6 +- pkg/cachedcli/probe.go | 28 ++++ pkg/{cli/cached.go => cachedcli/server.go} | 64 ++------ pkg/cli/client.go | 17 +- pkg/client/client.go | 53 ++++++- pkg/client/context.go | 6 +- test/cached_csi_test.go | 8 +- test/shared_test.go | 4 +- 14 files changed, 276 insertions(+), 128 deletions(-) create mode 100644 cmd/cached-client/main.go create mode 100644 pkg/cachedcli/client.go rename pkg/{cli/getcached.go => cachedcli/populate.go} (88%) create mode 100644 pkg/cachedcli/probe.go rename pkg/{cli/cached.go => cachedcli/server.go} (74%) diff --git a/Makefile b/Makefile index 17bf233f..03de2a87 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,8 @@ DB_URI := postgres://$(DB_USER):$(DB_PASS)@$(DB_HOST):5432/dl GRPC_HOST ?= localhost GRPC_PORT ?= 5051 -GRPC_CACHED_PORT ?= 5053 + +CACHED_SOCKET ?= unix:///tmp/csi.sock DEV_TOKEN_ADMIN ?= v2.public.eyJzdWIiOiJhZG1pbiJ9yt40HNkcyOUtDeFa_WPS6vi0WiE4zWngDGJLh17TuYvssTudCbOdQEkVDRD-mSNTXLgSRDXUkO-AaEr4ZLO4BQ DEV_TOKEN_PROJECT_1 ?= v2.public.eyJzdWIiOiIxIn2jV7FOdEXafKDtAnVyDgI4fmIbqU7C1iuhKiL0lDnG1Z5-j6_ObNDd75sZvLZ159-X98_mP4qvwzui0w8pjt8F @@ -125,12 +126,7 @@ server-profile: internal/pb/fs.pb.go internal/pb/fs_grpc.pb.go cached: export DL_ENV=dev cached: export DL_TOKEN=$(DEV_SHARED_READER_TOKEN) cached: internal/pb/cache.pb.go internal/pb/cache_grpc.pb.go - go run cmd/cached/main.go --upstream-host $(GRPC_HOST) --upstream-port $(GRPC_PORT) --port $(GRPC_CACHED_PORT) --staging-path tmp/cache-stage - -cached-csi: export DL_ENV=dev -cached-csi: export DL_TOKEN=$(DEV_SHARED_READER_TOKEN) -cached-csi: internal/pb/cache.pb.go internal/pb/cache_grpc.pb.go - go run cmd/cached/main.go --upstream-host $(GRPC_HOST) --upstream-port $(GRPC_PORT) --staging-path tmp/cache-stage --csi-socket unix://tmp/csi.sock + go run cmd/cached/main.go --upstream-host $(GRPC_HOST) --upstream-port $(GRPC_PORT) --csi-socket $(CACHED_SOCKET) --staging-path tmp/cache-stage client-update: export DL_TOKEN=$(DEV_TOKEN_PROJECT_1) client-update: export DL_SKIP_SSL_VERIFICATION=1 @@ -180,11 +176,6 @@ client-getcache: export DL_SKIP_SSL_VERIFICATION=1 client-getcache: go run cmd/client/main.go getcache --host $(GRPC_HOST) --path input/cache -client-getcached: export DL_TOKEN=$(DEV_TOKEN_ADMIN) -client-getcached: export DL_SKIP_SSL_VERIFICATION=1 -client-getcached: - go run cmd/client/main.go getcached --host $(GRPC_HOST) --port $(GRPC_CACHED_PORT) --path input/cache - client-gc-contents: export DL_TOKEN=$(DEV_TOKEN_ADMIN) client-gc-contents: export DL_SKIP_SSL_VERIFICATION=1 client-gc-contents: @@ -200,6 +191,12 @@ client-gc-random-projects: export DL_SKIP_SSL_VERIFICATION=1 client-gc-random-projects: go run cmd/client/main.go gc --host $(GRPC_HOST) --mode random-projects --sample 25 --keep 1 +cachedclient-probe: + go run cmd/cached-client/main.go probe --socket $(CACHED_SOCKET) + +cachedclient-populate: + go run cmd/cached-client/main.go populate --socket $(CACHED_SOCKET) + health: grpc-health-probe -addr $(GRPC_SERVER) grpc-health-probe -addr $(GRPC_SERVER) -service $(SERVICE) diff --git a/cmd/cached-client/main.go b/cmd/cached-client/main.go new file mode 100644 index 00000000..4017f2f4 --- /dev/null +++ b/cmd/cached-client/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/gadget-inc/dateilager/pkg/cachedcli" + +func main() { + cachedcli.ClientExecute() +} diff --git a/cmd/cached/main.go b/cmd/cached/main.go index 20a7b13e..2ea602d4 100644 --- a/cmd/cached/main.go +++ b/cmd/cached/main.go @@ -1,7 +1,7 @@ package main -import "github.com/gadget-inc/dateilager/pkg/cli" +import "github.com/gadget-inc/dateilager/pkg/cachedcli" func main() { - cli.CacheDaemonExecute() + cachedcli.CacheDaemonExecute() } diff --git a/internal/key/key.go b/internal/key/key.go index f0513c27..4d90e847 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -30,6 +30,7 @@ const ( QueryPath = StringKey("dl.query.path") SampleRate = Float32Key("dl.sample_rate") Server = StringKey("dl.server") + Socket = StringKey("dl.socket") State = StringKey("dl.state") Template = Int64pKey("dl.template") ToVersion = Int64pKey("dl.to_version") diff --git a/pkg/cached/cached.go b/pkg/cached/cached.go index fac26086..8d5e3698 100644 --- a/pkg/cached/cached.go +++ b/pkg/cached/cached.go @@ -2,8 +2,6 @@ package cached import ( "context" - "crypto/ed25519" - "crypto/tls" "fmt" "net" "net/url" @@ -12,54 +10,32 @@ import ( "path/filepath" "github.com/container-storage-interface/spec/lib/go/csi" - "github.com/gadget-inc/dateilager/internal/auth" "github.com/gadget-inc/dateilager/internal/logger" "github.com/gadget-inc/dateilager/internal/pb" "github.com/gadget-inc/dateilager/pkg/api" - "github.com/gadget-inc/dateilager/pkg/server" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/health" - healthpb "google.golang.org/grpc/health/grpc_health_v1" ) type CachedServer struct { - Grpc *grpc.Server - Health *health.Server + Grpc *grpc.Server } -func NewServer(ctx context.Context, cert *tls.Certificate, pasetoKey ed25519.PublicKey) *CachedServer { - creds := credentials.NewServerTLSFromCert(cert) - validator := auth.NewAuthValidator(pasetoKey) - +func NewServer(ctx context.Context) *CachedServer { grpcServer := grpc.NewServer( grpc.UnaryInterceptor( grpc_middleware.ChainUnaryServer( grpc_recovery.UnaryServerInterceptor(), otelgrpc.UnaryServerInterceptor(), logger.UnaryServerInterceptor(), - server.ValidateTokenUnary(validator), ), ), - grpc.ReadBufferSize(server.BUFFER_SIZE), - grpc.WriteBufferSize(server.BUFFER_SIZE), - grpc.InitialConnWindowSize(server.INITIAL_CONN_WINDOW_SIZE), - grpc.InitialWindowSize(server.INITIAL_WINDOW_SIZE), - grpc.MaxRecvMsgSize(server.MAX_MESSAGE_SIZE), - grpc.MaxSendMsgSize(server.MAX_MESSAGE_SIZE), - grpc.Creds(creds), ) - logger.Info(ctx, "register HealthServer") - healthServer := health.NewServer() - healthpb.RegisterHealthServer(grpcServer, healthServer) - server := &CachedServer{ - Grpc: grpcServer, - Health: healthServer, + Grpc: grpcServer, } return server @@ -74,14 +50,10 @@ func (s *CachedServer) RegisterCSI(cached *api.Cached) { csi.RegisterNodeServer(s.Grpc, cached) } -func (s *CachedServer) Serve(lis net.Listener) error { - return s.Grpc.Serve(lis) -} - -func (s *CachedServer) ServeCSI(listenSocketPath string) error { - u, err := url.Parse(listenSocketPath) +func (s *CachedServer) Serve(socketPath string) error { + u, err := url.Parse(socketPath) if err != nil { - return fmt.Errorf("unable to parse address: %q", err) + return fmt.Errorf("unable to parse socket address: %q", err) } addr := path.Join(u.Host, filepath.FromSlash(u.Path)) diff --git a/pkg/cachedcli/client.go b/pkg/cachedcli/client.go new file mode 100644 index 00000000..5d6a0b9f --- /dev/null +++ b/pkg/cachedcli/client.go @@ -0,0 +1,145 @@ +package cachedcli + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "strings" + "time" + + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/internal/telemetry" + "github.com/gadget-inc/dateilager/pkg/client" + "github.com/gadget-inc/dateilager/pkg/version" + "github.com/spf13/cobra" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var ( + shutdownTelemetry func() + span trace.Span +) + +func NewCachedClientCommand() *cobra.Command { + var ( + level *zapcore.Level + encoding string + tracing bool + otelContext string + socket string + timeout uint + ) + + var cancel context.CancelFunc + + cmd := &cobra.Command{ + Use: "cachedclient", + Short: "DateiLager cached client", + DisableAutoGenTag: true, + Version: version.Version, + SilenceErrors: true, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + cmd.SilenceUsage = true // silence usage when an error occurs after flags have been parsed + + config := zap.NewProductionConfig() + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + config.Level = zap.NewAtomicLevelAt(*level) + config.Encoding = encoding + + err := logger.Init(config) + if err != nil { + return fmt.Errorf("could not initialize logger: %w", err) + } + + ctx := cmd.Context() + + if timeout != 0 { + ctx, cancel = context.WithTimeout(cmd.Context(), time.Duration(timeout)*time.Millisecond) + } + + if tracing { + shutdownTelemetry = telemetry.Init(ctx, telemetry.Client) + } + + if otelContext != "" { + var mapCarrier propagation.MapCarrier + err := json.NewDecoder(strings.NewReader(otelContext)).Decode(&mapCarrier) + if err != nil { + return fmt.Errorf("failed to decode otel-context: %w", err) + } + + ctx = otel.GetTextMapPropagator().Extract(ctx, mapCarrier) + } + + ctx, span = telemetry.Start(ctx, "cached-cmd.main") + + if socket == "" { + return fmt.Errorf("required flag(s) \"socket\" not set") + } + + cl, err := client.NewCachedUnixClient(ctx, socket) + if err != nil { + return err + } + ctx = client.CachedIntoContext(ctx, cl) + + cmd.SetContext(ctx) + + return nil + }, + PersistentPostRunE: func(cmd *cobra.Command, _ []string) error { + if cancel != nil { + cancel() + } + return nil + }, + } + + flags := cmd.PersistentFlags() + + level = zap.LevelFlag("log-level", zap.DebugLevel, "Log level") + flags.AddGoFlag(flag.CommandLine.Lookup("log-level")) + flags.StringVar(&encoding, "log-encoding", "console", "Log encoding (console | json)") + flags.BoolVar(&tracing, "tracing", false, "Whether tracing is enabled") + flags.StringVar(&otelContext, "otel-context", "", "Open Telemetry context") + + flags.StringVar(&socket, "socket", "", "Unix domain socket path") + flags.UintVar(&timeout, "timeout", 0, "GRPC client timeout (ms)") + + _ = cmd.MarkFlagRequired("socket") + + cmd.AddCommand(NewCmdPopulate()) + cmd.AddCommand(NewCmdProbe()) + + return cmd +} + +func ClientExecute() { + ctx := context.Background() + cmd := NewCachedClientCommand() + err := cmd.ExecuteContext(ctx) + + client := client.FromContext(cmd.Context()) + if client != nil { + client.Close() + } + + if span != nil { + span.End() + } + + if shutdownTelemetry != nil { + shutdownTelemetry() + } + + _ = logger.Sync() + + if err != nil { + logger.Fatal(ctx, "command failed", zap.Error(err)) + } +} diff --git a/pkg/cli/getcached.go b/pkg/cachedcli/populate.go similarity index 88% rename from pkg/cli/getcached.go rename to pkg/cachedcli/populate.go index 2e810284..271a8121 100644 --- a/pkg/cli/getcached.go +++ b/pkg/cachedcli/populate.go @@ -1,4 +1,4 @@ -package cli +package cachedcli import ( "github.com/gadget-inc/dateilager/internal/key" @@ -7,13 +7,13 @@ import ( "github.com/spf13/cobra" ) -func NewCmdGetCacheFromDaemon() *cobra.Command { +func NewCmdPopulate() *cobra.Command { var ( path string ) cmd := &cobra.Command{ - Use: "getcached", + Use: "populate", RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() c := client.CachedFromContext(ctx) diff --git a/pkg/cachedcli/probe.go b/pkg/cachedcli/probe.go new file mode 100644 index 00000000..e3e9e009 --- /dev/null +++ b/pkg/cachedcli/probe.go @@ -0,0 +1,28 @@ +package cachedcli + +import ( + "github.com/gadget-inc/dateilager/internal/logger" + "github.com/gadget-inc/dateilager/pkg/client" + "github.com/spf13/cobra" + "go.uber.org/zap" +) + +func NewCmdProbe() *cobra.Command { + cmd := &cobra.Command{ + Use: "probe", + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + c := client.CachedFromContext(ctx) + + ready, err := c.Probe(ctx) + if err != nil { + return err + } + + logger.Info(ctx, "server probe", zap.Bool("ready", ready)) + return nil + }, + } + + return cmd +} diff --git a/pkg/cli/cached.go b/pkg/cachedcli/server.go similarity index 74% rename from pkg/cli/cached.go rename to pkg/cachedcli/server.go index cf950501..1530673e 100644 --- a/pkg/cli/cached.go +++ b/pkg/cachedcli/server.go @@ -1,11 +1,9 @@ -package cli +package cachedcli import ( "context" - "crypto/tls" "flag" "fmt" - "net" "os" "os/signal" "runtime/pprof" @@ -22,7 +20,6 @@ import ( "github.com/spf13/cobra" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "golang.org/x/sync/errgroup" ) func NewCacheDaemonCommand() *cobra.Command { @@ -38,10 +35,6 @@ func NewCacheDaemonCommand() *cobra.Command { profilePath string upstreamHost string upstreamPort uint16 - certFile string - keyFile string - pasetoFile string - port int timeout uint headlessHost string stagingPath string @@ -97,63 +90,34 @@ func NewCacheDaemonCommand() *cobra.Command { return err } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return fmt.Errorf("cannot open TLS cert and key files (%s, %s): %w", certFile, keyFile, err) - } - - pasetoKey, err := parsePublicKey(pasetoFile) - if err != nil { - return fmt.Errorf("cannot parse Paseto public key %s: %w", pasetoFile, err) - } + s := cached.NewServer(ctx) - s := cached.NewServer(ctx, &cert, pasetoKey) - - logger.Info(ctx, "register Cached") cached := &api.Cached{ Env: env, Client: cl, StagingPath: stagingPath, } + + logger.Info(ctx, "register Cached") s.RegisterCached(cached) + logger.Info(ctx, "register CSI") + s.RegisterCSI(cached) + err = cached.Prepare(ctx) if err != nil { return fmt.Errorf("failed to prepare cache daemon in %s: %w", stagingPath, err) } - group, ctx := errgroup.WithContext(ctx) - - if csiSocket != "" { - logger.Info(ctx, "register CSI") - s.RegisterCSI(cached) - - group.Go(func() error { - logger.Info(ctx, "start CSI server") - return s.ServeCSI(csiSocket) - }) - } - - group.Go(func() error { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - if err != nil { - return fmt.Errorf("failed to listen on TCP port %d: %w", port, err) - } - - logger.Info(ctx, "start cached server", key.Port.Field(port), key.Environment.Field(env.String())) - return s.Serve(listen) - }) - osSignals := make(chan os.Signal, 1) signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM) - - group.Go(func() error { + go func() { <-osSignals s.Grpc.GracefulStop() - return nil - }) + }() - return group.Wait() + logger.Info(ctx, "start cached server", key.Socket.Field(csiSocket)) + return s.Serve(csiSocket) }, PostRunE: func(cmd *cobra.Command, _ []string) error { if shutdownTelemetry != nil { @@ -176,17 +140,15 @@ func NewCacheDaemonCommand() *cobra.Command { flags.BoolVar(&tracing, "tracing", false, "Whether tracing is enabled") flags.StringVar(&profilePath, "profile", "", "CPU profile output path (profiling enabled if set)") - flags.IntVar(&port, "port", 5053, "cache API port") flags.StringVar(&upstreamHost, "upstream-host", "localhost", "GRPC server hostname") flags.Uint16Var(&upstreamPort, "upstream-port", 5051, "GRPC server port") flags.StringVar(&headlessHost, "headless-host", "", "Alternative headless hostname to use for round robin connections") - flags.StringVar(&certFile, "cert", "development/server.crt", "TLS cert file") - flags.StringVar(&keyFile, "key", "development/server.key", "TLS key file") - flags.StringVar(&pasetoFile, "paseto", "development/paseto.pub", "Paseto public key file") flags.UintVar(&timeout, "timeout", 0, "GRPC client timeout (ms)") flags.StringVar(&csiSocket, "csi-socket", "", "path for running the Kubernetes CSI Driver interface") flags.StringVar(&stagingPath, "staging-path", "", "path for staging downloaded caches") + + _ = cmd.MarkPersistentFlagRequired("csi-socket") _ = cmd.MarkPersistentFlagRequired("staging-path") return cmd diff --git a/pkg/cli/client.go b/pkg/cli/client.go index eb1bfe4e..7f187ca6 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -5,7 +5,6 @@ import ( "encoding/json" "flag" "fmt" - "slices" "strings" "time" @@ -22,9 +21,8 @@ import ( ) var ( - shutdownTelemetry func() - span trace.Span - requiresCachedClient = []string{"getcached"} + shutdownTelemetry func() + span trace.Span ) func NewClientCommand() *cobra.Command { @@ -92,14 +90,6 @@ func NewClientCommand() *cobra.Command { } ctx = client.IntoContext(ctx, cl) - if slices.Contains(requiresCachedClient, cmd.CalledAs()) { - cachedClient, err := client.NewCachedClient(ctx, host, port, client.WithheadlessHost(headlessHost)) - if err != nil { - return err - } - ctx = client.CachedIntoContext(ctx, cachedClient) - } - cmd.SetContext(ctx) return nil @@ -118,8 +108,8 @@ func NewClientCommand() *cobra.Command { flags.AddGoFlag(flag.CommandLine.Lookup("log-level")) flags.StringVar(&encoding, "log-encoding", "console", "Log encoding (console | json)") flags.BoolVar(&tracing, "tracing", false, "Whether tracing is enabled") - flags.StringVar(&otelContext, "otel-context", "", "Open Telemetry context") + flags.StringVar(&host, "host", "", "GRPC server hostname") flags.Uint16Var(&port, "port", 5051, "GRPC server port") flags.StringVar(&headlessHost, "headless-host", "", "Alternative headless hostname to use for round robin connections") @@ -136,7 +126,6 @@ func NewClientCommand() *cobra.Command { cmd.AddCommand(NewCmdUpdate()) cmd.AddCommand(NewCmdGc()) cmd.AddCommand(NewCmdGetCache()) - cmd.AddCommand(NewCmdGetCacheFromDaemon()) return cmd } diff --git a/pkg/client/client.go b/pkg/client/client.go index f2587add..78bc10f5 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "fmt" "io" + "net" "os" "path/filepath" "runtime" @@ -15,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/container-storage-interface/spec/lib/go/csi" "github.com/gadget-inc/dateilager/internal/db" "github.com/gadget-inc/dateilager/internal/files" "github.com/gadget-inc/dateilager/internal/key" @@ -26,7 +28,9 @@ import ( "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/oauth" "google.golang.org/grpc/keepalive" ) @@ -51,8 +55,9 @@ type Client struct { } type CachedClient struct { - conn *grpc.ClientConn - cached pb.CachedClient + conn *grpc.ClientConn + cached pb.CachedClient + identity csi.IdentityClient } func NewClientConn(conn *grpc.ClientConn) *Client { @@ -60,7 +65,7 @@ func NewClientConn(conn *grpc.ClientConn) *Client { } func NewCachedClientConn(conn *grpc.ClientConn) *CachedClient { - return &CachedClient{conn: conn, cached: pb.NewCachedClient(conn)} + return &CachedClient{conn: conn, cached: pb.NewCachedClient(conn), identity: csi.NewIdentityClient(conn)} } type options struct { @@ -977,6 +982,37 @@ func NewCachedClient(ctx context.Context, host string, port uint16, opts ...func return NewCachedClientConn(conn), nil } +func NewCachedUnixClient(ctx context.Context, socket string) (*CachedClient, error) { + ctx, span := telemetry.Start(ctx, "cached-unix-client.new", trace.WithAttributes( + key.Server.Attribute(socket), + )) + defer span.End() + + bc := backoff.DefaultConfig + bc.MaxDelay = time.Second + dialOptions := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithConnectParams(grpc.ConnectParams{Backoff: bc}), + grpc.WithBlock(), + grpc.WithIdleTimeout(time.Duration(0)), + grpc.WithContextDialer(func(ctx context.Context, path string) (net.Conn, error) { + var timeout time.Duration + deadline, ok := ctx.Deadline() + if ok { + timeout = time.Until(deadline) + } + return net.DialTimeout("unix", path[len("unix://"):], timeout) + }), + } + + conn, err := grpc.DialContext(ctx, socket, dialOptions...) + if err != nil { + return nil, err + } + + return NewCachedClientConn(conn), nil +} + func (c *CachedClient) Close() { // Give a chance for the upstream socket to finish writing it's response // https://github.com/grpc/grpc-go/issues/2869#issuecomment-503310136 @@ -1002,6 +1038,17 @@ func (c *CachedClient) PopulateDiskCache(ctx context.Context, destination string return response.Version, nil } +func (c *CachedClient) Probe(ctx context.Context) (bool, error) { + request := &csi.ProbeRequest{} + + response, err := c.identity.Probe(ctx, request) + if err != nil { + return false, fmt.Errorf("failed to probe server: %w", err) + } + + return response.Ready.Value, nil +} + func parallelWorkerCount() int { envCount := os.Getenv("DL_WRITE_WORKERS") if envCount != "" { diff --git a/pkg/client/context.go b/pkg/client/context.go index 37d56f37..d9067db9 100644 --- a/pkg/client/context.go +++ b/pkg/client/context.go @@ -5,7 +5,7 @@ import ( ) type clientCtxKey struct{} -type cachedCtxKey struct{} +type cachedClientCtxKey struct{} func FromContext(ctx context.Context) *Client { client, ok := ctx.Value(clientCtxKey{}).(*Client) @@ -20,7 +20,7 @@ func IntoContext(ctx context.Context, client *Client) context.Context { } func CachedFromContext(ctx context.Context) *CachedClient { - client, ok := ctx.Value(cachedCtxKey{}).(*CachedClient) + client, ok := ctx.Value(cachedClientCtxKey{}).(*CachedClient) if !ok { return nil } @@ -28,5 +28,5 @@ func CachedFromContext(ctx context.Context) *CachedClient { } func CachedIntoContext(ctx context.Context, client *CachedClient) context.Context { - return context.WithValue(ctx, cachedCtxKey{}, client) + return context.WithValue(ctx, cachedClientCtxKey{}, client) } diff --git a/test/cached_csi_test.go b/test/cached_csi_test.go index b20d4d4e..69d885ad 100644 --- a/test/cached_csi_test.go +++ b/test/cached_csi_test.go @@ -29,7 +29,7 @@ func TestCachedCSIDriver(t *testing.T) { tmpDir := emptyTmpDir(t) defer os.RemoveAll(tmpDir) - cached, endpoint, close := createTestCachedCSIServer(tc, tmpDir) + cached, endpoint, close := createTestCachedServer(tc, tmpDir) defer close() err = cached.Prepare(tc.Context()) @@ -61,7 +61,7 @@ func TestCachedCSIDriverMountsCache(t *testing.T) { tmpDir := emptyTmpDir(t) defer os.RemoveAll(tmpDir) - cached, _, close := createTestCachedCSIServer(tc, tmpDir) + cached, _, close := createTestCachedServer(tc, tmpDir) defer close() require.NoError(t, cached.Prepare(tc.Context()), "cached.Prepare must succeed") @@ -110,7 +110,7 @@ func TestCachedCSIDriverMountsCacheAtSuffix(t *testing.T) { tmpDir := emptyTmpDir(t) defer os.RemoveAll(tmpDir) - cached, _, close := createTestCachedCSIServer(tc, tmpDir) + cached, _, close := createTestCachedServer(tc, tmpDir) defer close() err = cached.Prepare(tc.Context()) @@ -165,7 +165,7 @@ func TestCachedCSIDriverProbeFailsUntilPrepared(t *testing.T) { tmpDir := emptyTmpDir(t) defer os.RemoveAll(tmpDir) - cached, _, close := createTestCachedCSIServer(tc, tmpDir) + cached, _, close := createTestCachedServer(tc, tmpDir) defer close() response, err := cached.Probe(tc.Context(), &csi.ProbeRequest{}) diff --git a/test/shared_test.go b/test/shared_test.go index 785d8d41..97ce77e2 100644 --- a/test/shared_test.go +++ b/test/shared_test.go @@ -411,7 +411,7 @@ func createTestGRPCServer(tc util.TestCtx) (*bufconn.Listener, *grpc.Server, fun return lis, s, getConn } -func createTestCachedCSIServer(tc util.TestCtx, tmpDir string) (*api.Cached, string, func()) { +func createTestCachedServer(tc util.TestCtx, tmpDir string) (*api.Cached, string, func()) { cl, _, closeClient := createTestClient(tc) _, grpcServer, _ := createTestGRPCServer(tc) @@ -427,7 +427,7 @@ func createTestCachedCSIServer(tc util.TestCtx, tmpDir string) (*api.Cached, str endpoint := "unix://" + socket go func() { - err := s.ServeCSI(endpoint) + err := s.Serve(endpoint) require.NoError(tc.T(), err, "CSI Server exited") }()