From f90d4fe7bff838cfed3001920965f33c57105f3d Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Fri, 25 Aug 2023 16:31:22 -0700 Subject: [PATCH] feat: Add side input sdkclient and grpc (#953) Signed-off-by: Sidhant Kohli --- go.mod | 2 +- go.sum | 4 +- pkg/sdkclient/sideinput/client.go | 95 +++++++++++++++++++ pkg/sdkclient/sideinput/client_test.go | 86 +++++++++++++++++ pkg/sdkclient/sideinput/interface.go | 15 +++ pkg/sdkclient/sideinput/options.go | 49 ++++++++++ pkg/sideinputs/initializer/initializer.go | 17 ++-- .../initializer/initializer_test.go | 4 +- pkg/sideinputs/manager/manager.go | 59 +++++++++--- pkg/sideinputs/synchronizer/synchronizer.go | 19 ++-- .../synchronizer/synchronizer_test.go | 4 +- pkg/sideinputs/utils/utils.go | 22 ++++- pkg/sideinputs/utils/utils_test.go | 53 +++++++++-- 13 files changed, 381 insertions(+), 48 deletions(-) create mode 100644 pkg/sdkclient/sideinput/client.go create mode 100644 pkg/sdkclient/sideinput/client_test.go create mode 100644 pkg/sdkclient/sideinput/interface.go create mode 100644 pkg/sdkclient/sideinput/options.go diff --git a/go.mod b/go.mod index 976531e412..0c4c32f7db 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe github.com/nats-io/nats-server/v2 v2.9.19 github.com/nats-io/nats.go v1.27.1 - github.com/numaproj/numaflow-go v0.4.6-0.20230822054239-88190e94e727 + github.com/numaproj/numaflow-go v0.4.6-0.20230824220200-630a5eba1f54 github.com/prometheus/client_golang v1.14.0 github.com/prometheus/common v0.37.0 github.com/redis/go-redis/v9 v9.0.3 diff --git a/go.sum b/go.sum index 6bff37c334..8ef5168ab1 100644 --- a/go.sum +++ b/go.sum @@ -670,8 +670,8 @@ github.com/nats-io/nkeys v0.4.4/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5s github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/numaproj/numaflow-go v0.4.6-0.20230822054239-88190e94e727 h1:m+2sl0pbBvhiiLEXyyslBv0GeWXm/1wpR4PUg0C2xY8= -github.com/numaproj/numaflow-go v0.4.6-0.20230822054239-88190e94e727/go.mod h1:5zwvvREIbqaCPCKsNE1MVjVToD0kvkCh2Z90Izlhw5U= +github.com/numaproj/numaflow-go v0.4.6-0.20230824220200-630a5eba1f54 h1:nx77VKeseDKPFHhY4AMecvzhJw8oSEVeisAROufT5dU= +github.com/numaproj/numaflow-go v0.4.6-0.20230824220200-630a5eba1f54/go.mod h1:5zwvvREIbqaCPCKsNE1MVjVToD0kvkCh2Z90Izlhw5U= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/pkg/sdkclient/sideinput/client.go b/pkg/sdkclient/sideinput/client.go new file mode 100644 index 0000000000..466ccef79e --- /dev/null +++ b/pkg/sdkclient/sideinput/client.go @@ -0,0 +1,95 @@ +package sideinput + +import ( + "context" + "fmt" + "time" + + sideinputpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sideinput/v1" + "github.com/numaproj/numaflow-go/pkg/shared" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/emptypb" +) + +// client contains the grpc connection and the grpc client. +type client struct { + conn *grpc.ClientConn + grpcClt sideinputpb.SideInputClient +} + +var _ Client = (*client)(nil) + +// New creates a new client object. +func New(inputOptions ...Option) (*client, error) { + var opts = &options{ + sockAddr: shared.SideInputAddr, + maxMessageSize: 1024 * 1024 * 64, // 64 MB + } + for _, inputOption := range inputOptions { + inputOption(opts) + } + _, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + c := new(client) + sockAddr := fmt.Sprintf("%s:%s", shared.UDS, opts.sockAddr) + conn, err := grpc.Dial(sockAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(opts.maxMessageSize), grpc.MaxCallSendMsgSize(opts.maxMessageSize))) + if err != nil { + return nil, fmt.Errorf("failed to execute grpc.Dial(%q): %w", sockAddr, err) + } + c.conn = conn + c.grpcClt = sideinputpb.NewSideInputClient(conn) + return c, nil +} + +// NewFromClient creates a new client object from a grpc client. This is used for testing. +func NewFromClient(c sideinputpb.SideInputClient) (Client, error) { + return &client{ + grpcClt: c, + }, nil +} + +// CloseConn closes the grpc connection. +func (c client) CloseConn(ctx context.Context) error { + return c.conn.Close() +} + +// IsReady checks if the grpc connection is ready to use. +func (c client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { + resp, err := c.grpcClt.IsReady(ctx, in) + if err != nil { + return false, err + } + return resp.GetReady(), nil +} + +// RetrieveSideInput retrieves the side input value and returns the updated payload. +func (c client) RetrieveSideInput(ctx context.Context, in *emptypb.Empty) (*sideinputpb.SideInputResponse, error) { + retrieveResponse, err := c.grpcClt.RetrieveSideInput(ctx, in) + // TODO check which error to use + if err != nil { + return nil, fmt.Errorf("failed to execute c.grpcClt.RetrieveSideInput(): %w", err) + } + return retrieveResponse, nil +} + +// IsHealthy checks if the client is healthy. +func (c client) IsHealthy(ctx context.Context) error { + return c.WaitUntilReady(ctx) +} + +// WaitUntilReady waits until the client is connected. +func (c client) WaitUntilReady(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return fmt.Errorf("failed on readiness check: %w", ctx.Err()) + default: + if _, err := c.IsReady(ctx, &emptypb.Empty{}); err == nil { + return nil + } + time.Sleep(1 * time.Second) + } + } +} diff --git a/pkg/sdkclient/sideinput/client_test.go b/pkg/sdkclient/sideinput/client_test.go new file mode 100644 index 0000000000..7b7152736b --- /dev/null +++ b/pkg/sdkclient/sideinput/client_test.go @@ -0,0 +1,86 @@ +package sideinput + +import ( + "bytes" + "context" + "fmt" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + sideinputpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sideinput/v1" + "github.com/numaproj/numaflow-go/pkg/apis/proto/sideinput/v1/sideinputmock" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" +) + +type rpcMsg struct { + msg proto.Message +} + +func (r *rpcMsg) Matches(msg interface{}) bool { + m, ok := msg.(proto.Message) + if !ok { + return false + } + return proto.Equal(m, r.msg) +} + +func (r *rpcMsg) String() string { + return fmt.Sprintf("is %s", r.msg) +} + +func TestIsReady(t *testing.T) { + var ctx = context.Background() + LintCleanCall() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := sideinputmock.NewMockSideInputClient(ctrl) + mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&sideinputpb.ReadyResponse{Ready: true}, nil) + mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&sideinputpb.ReadyResponse{Ready: false}, fmt.Errorf("mock connection refused")) + + testClient, err := NewFromClient(mockClient) + assert.NoError(t, err) + reflect.DeepEqual(testClient, &client{ + grpcClt: mockClient, + }) + + ready, err := testClient.IsReady(ctx, &emptypb.Empty{}) + assert.True(t, ready) + assert.NoError(t, err) + + ready, err = testClient.IsReady(ctx, &emptypb.Empty{}) + assert.False(t, ready) + assert.EqualError(t, err, "mock connection refused") +} + +func TestRetrieveFn(t *testing.T) { + var ctx = context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSideInputClient := sideinputmock.NewMockSideInputClient(ctrl) + response := sideinputpb.SideInputResponse{Value: []byte("mock side input message")} + mockSideInputClient.EXPECT().RetrieveSideInput(gomock.Any(), gomock.Any()).Return(&sideinputpb.SideInputResponse{Value: []byte("mock side input message")}, nil) + + testClient, err := NewFromClient(mockSideInputClient) + assert.NoError(t, err) + reflect.DeepEqual(testClient, &client{ + grpcClt: mockSideInputClient, + }) + + got, err := testClient.RetrieveSideInput(ctx, &emptypb.Empty{}) + assert.True(t, bytes.Equal(got.Value, response.Value)) + assert.NoError(t, err) +} + +// Check if there is a better way to resolve +func LintCleanCall() { + var m = rpcMsg{} + fmt.Println(m.Matches(m)) + fmt.Println(m) +} diff --git a/pkg/sdkclient/sideinput/interface.go b/pkg/sdkclient/sideinput/interface.go new file mode 100644 index 0000000000..cba396dad0 --- /dev/null +++ b/pkg/sdkclient/sideinput/interface.go @@ -0,0 +1,15 @@ +package sideinput + +import ( + "context" + + sideinputpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sideinput/v1" + "google.golang.org/protobuf/types/known/emptypb" +) + +// Client contains methods to call a gRPC client for side input. +type Client interface { + CloseConn(ctx context.Context) error + IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) + RetrieveSideInput(ctx context.Context, in *emptypb.Empty) (*sideinputpb.SideInputResponse, error) +} diff --git a/pkg/sdkclient/sideinput/options.go b/pkg/sdkclient/sideinput/options.go new file mode 100644 index 0000000000..425f7dbe68 --- /dev/null +++ b/pkg/sdkclient/sideinput/options.go @@ -0,0 +1,49 @@ +/* +Copyright 2022 The Numaproj Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sideinput + +import "time" + +type options struct { + sockAddr string + maxMessageSize int + sideInputTimeout time.Duration +} + +// Option is the interface to apply options. +type Option func(*options) + +// WithSockAddr start the client with the given sock addr. This is mainly used for testing purpose. +func WithSockAddr(addr string) Option { + return func(opts *options) { + opts.sockAddr = addr + } +} + +// WithMaxMessageSize sets the max message size to the given size. +func WithMaxMessageSize(size int) Option { + return func(o *options) { + o.maxMessageSize = size + } +} + +// WithSideInputTimeout sets the side input timeout to the given timeout. +func WithSideInputTimeout(t time.Duration) Option { + return func(o *options) { + o.sideInputTimeout = t + } +} diff --git a/pkg/sideinputs/initializer/initializer.go b/pkg/sideinputs/initializer/initializer.go index 1b3e5dc440..07e193f8fc 100644 --- a/pkg/sideinputs/initializer/initializer.go +++ b/pkg/sideinputs/initializer/initializer.go @@ -56,8 +56,9 @@ func NewSideInputsInitializer(isbSvcType dfv1.ISBSvcType, pipelineName, sideInpu // and update the values on the disk. This would exit once all the side inputs are initialized. func (sii *sideInputsInitializer) Run(ctx context.Context) error { var ( - natsClient *jsclient.NATSClient - err error + natsClient *jsclient.NATSClient + err error + sideInputWatcher kvs.KVWatcher ) log := logging.FromContext(ctx) @@ -75,15 +76,15 @@ func (sii *sideInputsInitializer) Run(ctx context.Context) error { return err } defer natsClient.Close() + // Load the required KV bucket and create a sideInputWatcher for it + kvName := isbsvc.JetStreamSideInputsStoreKVName(sii.sideInputsStore) + sideInputWatcher, err = jetstream.NewKVJetStreamKVWatch(ctx, kvName, natsClient) + if err != nil { + return fmt.Errorf("failed to create a sideInputWatcher, %w", err) + } default: return fmt.Errorf("unrecognized isbsvc type %q", sii.isbSvcType) } - // Load the required KV bucket and create a sideInputWatcher for it - kvName := isbsvc.JetStreamSideInputsStoreKVName(sii.sideInputsStore) - sideInputWatcher, err := jetstream.NewKVJetStreamKVWatch(ctx, kvName, natsClient) - if err != nil { - return fmt.Errorf("failed to create a sideInputWatcher, %w", err) - } return startSideInputInitializer(ctx, sideInputWatcher, dfv1.PathSideInputsMount, sii.sideInputs) } diff --git a/pkg/sideinputs/initializer/initializer_test.go b/pkg/sideinputs/initializer/initializer_test.go index 1b1cae926d..2707aa03a4 100644 --- a/pkg/sideinputs/initializer/initializer_test.go +++ b/pkg/sideinputs/initializer/initializer_test.go @@ -93,14 +93,14 @@ func TestSideInputsInitializer_Success(t *testing.T) { for x, sideInput := range sideInputs { p := path.Join(mountPath, sideInput) - fileData, err := utils.FetchSideInputFile(p) + fileData, err := utils.FetchSideInputFileValue(p) for err != nil { select { case <-ctx.Done(): t.Fatalf("Context timeout") default: time.Sleep(10 * time.Millisecond) - fileData, err = utils.FetchSideInputFile(p) + fileData, err = utils.FetchSideInputFileValue(p) } } assert.Equal(t, dataTest[x], string(fileData)) diff --git a/pkg/sideinputs/manager/manager.go b/pkg/sideinputs/manager/manager.go index aeda138911..12f4cf5f89 100644 --- a/pkg/sideinputs/manager/manager.go +++ b/pkg/sideinputs/manager/manager.go @@ -23,10 +23,14 @@ import ( cronlib "github.com/robfig/cron/v3" "go.uber.org/zap" + "google.golang.org/protobuf/types/known/emptypb" dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" "github.com/numaproj/numaflow/pkg/isbsvc" + "github.com/numaproj/numaflow/pkg/sdkclient/sideinput" jsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" + "github.com/numaproj/numaflow/pkg/shared/kvs" + "github.com/numaproj/numaflow/pkg/shared/kvs/jetstream" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -52,31 +56,51 @@ func (sim *sideInputsManager) Start(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - var isbSvcClient isbsvc.ISBService + var natsClient *jsclient.NATSClient var err error + var siStore kvs.KVStorer switch sim.isbSvcType { case dfv1.ISBSvcTypeRedis: return fmt.Errorf("unsupported isbsvc type %q", sim.isbSvcType) case dfv1.ISBSvcTypeJetStream: - natsClient, err := jsclient.NewNATSClient(ctx) + natsClient, err = jsclient.NewNATSClient(ctx) if err != nil { log.Errorw("Failed to get a NATS client.", zap.Error(err)) return err } - isbSvcClient, err = isbsvc.NewISBJetStreamSvc(sim.pipelineName, isbsvc.WithJetStreamClient(natsClient)) + defer natsClient.Close() + // Load the required KV bucket and create a sideInputWatcher for it + sideInputBucketName := isbsvc.JetStreamSideInputsStoreKVName(sim.sideInputsStore) + siStore, err = jetstream.NewKVJetStreamKVStore(ctx, sideInputBucketName, natsClient) if err != nil { - log.Errorw("Failed to get an ISB Service client.", zap.Error(err)) - return err + return fmt.Errorf("failed to create a new KVStore: %w", err) } + default: return fmt.Errorf("unrecognized isbsvc type %q", sim.isbSvcType) } - // TODO(SI): remove it. - fmt.Printf("ISB Svc Client nil: %v\n", isbSvcClient == nil) + // Create a new gRPC client for Side Input + sideInputClient, err := sideinput.New() + if err != nil { + return fmt.Errorf("failed to create a new gRPC client: %w", err) + } + + // close the connection when the function exits + defer func() { + err = sideInputClient.CloseConn(ctx) + if err != nil { + log.Warnw("Failed to close gRPC client conn", zap.Error(err)) + } + }() + + // Readiness check + if err = sideInputClient.WaitUntilReady(ctx); err != nil { + return fmt.Errorf("failed on SideInput readiness check, %w", err) + } f := func() { - if err := sim.execute(ctx); err != nil { + if err := sim.execute(ctx, sideInputClient, siStore); err != nil { log.Errorw("Failed to execute the call to fetch Side Inputs.", zap.Error(err)) } } @@ -96,10 +120,23 @@ func (sim *sideInputsManager) Start(ctx context.Context) error { return nil } -func (sim *sideInputsManager) execute(ctx context.Context) error { +func (sim *sideInputsManager) execute(ctx context.Context, sideInputClient sideinput.Client, siStore kvs.KVStorer) error { log := logging.FromContext(ctx) - // TODO(SI): call ud container to fetch data and write to store. - log.Info("Executing ...") + log.Info("Executing Side Inputs manager cron ...") + resp, err := sideInputClient.RetrieveSideInput(ctx, &emptypb.Empty{}) + if err != nil { + return fmt.Errorf("failed to retrieve side input: %w", err) + } + // If the NoBroadcast flag is True, skip writing to the store. + if resp.NoBroadcast { + log.Info("Side input is not broadcasted, skipping ...") + return nil + } + // Write the side input value to the store. + err = siStore.PutKV(ctx, sim.sideInput.Name, resp.Value) + if err != nil { + return fmt.Errorf("failed to write side input %q to store: %w", sim.sideInput.Name, err) + } return nil } diff --git a/pkg/sideinputs/synchronizer/synchronizer.go b/pkg/sideinputs/synchronizer/synchronizer.go index cd1b27ccf1..494d4af5de 100644 --- a/pkg/sideinputs/synchronizer/synchronizer.go +++ b/pkg/sideinputs/synchronizer/synchronizer.go @@ -56,8 +56,9 @@ func NewSideInputsSynchronizer(isbSvcType dfv1.ISBSvcType, pipelineName, sideInp // and keeps on watching for updates for all the side inputs while writing the new values to the disk. func (sis *sideInputsSynchronizer) Start(ctx context.Context) error { var ( - natsClient *jsclient.NATSClient - err error + natsClient *jsclient.NATSClient + err error + sideInputWatcher kvs.KVWatcher ) log := logging.FromContext(ctx) @@ -75,15 +76,15 @@ func (sis *sideInputsSynchronizer) Start(ctx context.Context) error { log.Errorw("Failed to get a NATS client.", zap.Error(err)) return err } + // Create a new watcher for the side input KV store + kvName := isbsvc.JetStreamSideInputsStoreKVName(sis.sideInputsStore) + sideInputWatcher, err = jetstream.NewKVJetStreamKVWatch(ctx, kvName, natsClient) + if err != nil { + return fmt.Errorf("failed to create a sideInputWatcher, %w", err) + } default: return fmt.Errorf("unrecognized isbsvc type %q", sis.isbSvcType) } - // Create a new watcher for the side input KV store - kvName := isbsvc.JetStreamSideInputsStoreKVName(sis.sideInputsStore) - sideInputWatcher, err := jetstream.NewKVJetStreamKVWatch(ctx, kvName, natsClient) - if err != nil { - return fmt.Errorf("failed to create a sideInputWatcher, %w", err) - } go startSideInputSynchronizer(ctx, sideInputWatcher, dfv1.PathSideInputsMount) <-ctx.Done() return nil @@ -104,7 +105,7 @@ func startSideInputSynchronizer(ctx context.Context, watch kvs.KVWatcher, mountP log.Warnw("nil value received from Side Input watcher") continue } - log.Debug("Side Input value received ", + log.Infow("Side Input value received ", zap.String("key", value.Key()), zap.String("value", string(value.Value()))) p := path.Join(mountPath, value.Key()) // Write changes to disk diff --git a/pkg/sideinputs/synchronizer/synchronizer_test.go b/pkg/sideinputs/synchronizer/synchronizer_test.go index 884a7a34c4..08bb5976ee 100644 --- a/pkg/sideinputs/synchronizer/synchronizer_test.go +++ b/pkg/sideinputs/synchronizer/synchronizer_test.go @@ -103,14 +103,14 @@ func TestSideInputsValueUpdates(t *testing.T) { for x, sideInput := range sideInputs { p := path.Join(mountPath, sideInput) - fileData, err := utils.FetchSideInputFile(p) + fileData, err := utils.FetchSideInputFileValue(p) for err != nil { select { case <-ctx.Done(): t.Fatalf("Context timeout") default: time.Sleep(10 * time.Millisecond) - fileData, err = utils.FetchSideInputFile(p) + fileData, err = utils.FetchSideInputFileValue(p) } } assert.Equal(t, dataTest[x], string(fileData)) diff --git a/pkg/sideinputs/utils/utils.go b/pkg/sideinputs/utils/utils.go index a9a3d205a6..807dd4bb8c 100644 --- a/pkg/sideinputs/utils/utils.go +++ b/pkg/sideinputs/utils/utils.go @@ -17,11 +17,14 @@ limitations under the License. package utils import ( + "bytes" "context" "fmt" "os" "time" + "go.uber.org/zap" + "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -40,8 +43,20 @@ func UpdateSideInputFile(ctx context.Context, fileSymLink string, value []byte) timestamp := time.Now().UnixNano() newFileName := fmt.Sprintf("%s_%d", fileSymLink, timestamp) + // Fetch the current side input value from the file + currentValue, err := FetchSideInputFileValue(fileSymLink) + + // Check if the current value is same as the new value + // If true then don't update file again and return + if err == nil && bytes.Equal(currentValue, value) { + log.Debugw("Side Input value is same as current value, "+ + "skipping update", zap.String("side_input", fileSymLink)) + return nil + } + // Write the side input value to the new file - err := os.WriteFile(newFileName, value, 0666) + // A New file is created with the given name if it doesn't exist + err = os.WriteFile(newFileName, value, 0666) if err != nil { return fmt.Errorf("failed to write Side Input file %s : %w", newFileName, err) } @@ -74,9 +89,8 @@ func UpdateSideInputFile(ctx context.Context, fileSymLink string, value []byte) return nil } -// FetchSideInputFile reads a given file and returns the value in bytes -// Used as utility for unit tests -func FetchSideInputFile(filePath string) ([]byte, error) { +// FetchSideInputFileValue reads a given file and returns the value in bytes +func FetchSideInputFileValue(filePath string) ([]byte, error) { b, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read Side Input %s file: %w", filePath, err) diff --git a/pkg/sideinputs/utils/utils_test.go b/pkg/sideinputs/utils/utils_test.go index 18d420d756..c6100f9d14 100644 --- a/pkg/sideinputs/utils/utils_test.go +++ b/pkg/sideinputs/utils/utils_test.go @@ -17,6 +17,7 @@ limitations under the License. package utils import ( + "bytes" "context" "os" "testing" @@ -35,10 +36,6 @@ func cleanup(mountPath string) { // TestSymLinkUpdate tests that the symlink is updated with a new file // whenever data is written to the symlink. func TestSymLinkUpdate(t *testing.T) { - var ( - size = int64(10 * 1024 * 1024) // 100 MB - byteArray = make([]byte, size) - ) mountPath, err := os.MkdirTemp("", "side-input") assert.NoError(t, err) // Clean up @@ -54,13 +51,15 @@ func TestSymLinkUpdate(t *testing.T) { err := os.Mkdir(mountPath, 0777) assert.NoError(t, err) } + byteArray := []byte("test") // Write data to the link err = UpdateSideInputFile(ctx, filePath, byteArray) assert.NoError(t, err) // Get the target file from the symlink file1, err := os.Readlink(filePath) assert.NoError(t, err) - // Write data to the link again + byteArray = []byte("test-new") + // Write data to the link again with new value err = UpdateSideInputFile(ctx, filePath, byteArray) assert.NoError(t, err) // Get the new target file from the symlink @@ -75,10 +74,6 @@ func TestSymLinkUpdate(t *testing.T) { func TestSymLinkFileDelete(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - var ( - size = int64(10 * 1024 * 1024) // 100 MB - byteArray = make([]byte, size) - ) mountPath, err := os.MkdirTemp("", "side-input") assert.NoError(t, err) // Clean up @@ -89,14 +84,54 @@ func TestSymLinkFileDelete(t *testing.T) { fileName := filePath.Name() // Write data to the link + byteArray := []byte("test") err = UpdateSideInputFile(ctx, fileName, byteArray) assert.NoError(t, err) // Get the target file from the symlink file1, err := os.Readlink(fileName) assert.NoError(t, err) // Write data to the link again + byteArray = []byte("test-new") err = UpdateSideInputFile(ctx, fileName, byteArray) assert.NoError(t, err) // The older file should have been deleted assert.False(t, CheckFileExists(file1)) } + +// TestUpdateSideInputFileNoUpdate tests if the new value is same as the current +// value then new file isn't created and file is not updated. +func TestUpdateSideInputFileNoUpdate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + mountPath, err := os.MkdirTemp("", "side-input") + assert.NoError(t, err) + // Clean up + defer cleanup(mountPath) + + filePath, err := os.CreateTemp(mountPath, "unit-test") + assert.NoError(t, err) + fileName := filePath.Name() + + byteArray := []byte("test") + // Write data to the link + err = UpdateSideInputFile(ctx, fileName, byteArray) + assert.NoError(t, err) + // Get the target file from the symlink + file1, err := os.Readlink(fileName) + assert.NoError(t, err) + data1, err := FetchSideInputFileValue(fileName) + assert.NoError(t, err) + // Write data to the link again with same value + err = UpdateSideInputFile(ctx, fileName, byteArray) + assert.NoError(t, err) + // Get the new target file from the symlink + file2, err := os.Readlink(fileName) + assert.NoError(t, err) + data2, err := FetchSideInputFileValue(fileName) + assert.NoError(t, err) + // We expect the target to be same file + assert.Equal(t, file1, file2) + // We expect the target to have the same data + assert.True(t, bytes.Equal(data1, data2)) + +}