diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/channel.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/channel.go index 5e01150e44..9664eba1f3 100644 --- a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/channel.go +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/channel.go @@ -3,6 +3,7 @@ package snowpipestreaming import ( "context" "fmt" + "time" "github.com/rudderlabs/rudder-go-kit/logger" @@ -12,6 +13,14 @@ import ( whutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) +type snowpipeAuthzError struct { + err error +} + +func (sae *snowpipeAuthzError) Error() string { + return sae.err.Error() +} + // initializeChannelWithSchema creates a new channel for the given table if it doesn't exist. // If the channel already exists, it checks for new columns and adds them to the table. // It returns the channel response after creating or recreating the channel. @@ -66,7 +75,8 @@ func (m *Manager) addColumns(ctx context.Context, namespace, tableName string, c snowflakeManager.Cleanup(ctx) }() if err = snowflakeManager.AddColumns(ctx, tableName, columns); err != nil { - return fmt.Errorf("adding column: %w", err) + m.setAuthzErrorTime() + return &snowpipeAuthzError{fmt.Errorf("adding column: %w", err)} } return nil } @@ -157,10 +167,12 @@ func (m *Manager) handleSchemaError( snowflakeManager.Cleanup(ctx) }() if err := snowflakeManager.CreateSchema(ctx); err != nil { - return nil, fmt.Errorf("creating schema: %w", err) + m.setAuthzErrorTime() + return nil, &snowpipeAuthzError{fmt.Errorf("creating schema: %w", err)} } if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil { - return nil, fmt.Errorf("creating table: %w", err) + m.setAuthzErrorTime() + return nil, &snowpipeAuthzError{fmt.Errorf("creating table: %w", err)} } return m.api.CreateChannel(ctx, channelReq) } @@ -185,7 +197,8 @@ func (m *Manager) handleTableError( 
snowflakeManager.Cleanup(ctx) }() if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil { - return nil, fmt.Errorf("creating table: %w", err) + m.setAuthzErrorTime() + return nil, &snowpipeAuthzError{fmt.Errorf("creating table: %w", err)} } return m.api.CreateChannel(ctx, channelReq) } @@ -225,6 +238,10 @@ func (m *Manager) deleteChannel(ctx context.Context, tableName, channelID string } func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string) (manager.Manager, error) { + nextBackoffTime := m.config.lastestAuthzErrorTime.Add(m.config.backoffDuration) + if m.now().Before(nextBackoffTime) { + return nil, &snowpipeAuthzError{fmt.Errorf("skipping snowflake manager creation due to backoff")} + } modelWarehouse := whutils.ModelWarehouse{ WorkspaceID: m.destination.WorkspaceID, Destination: *m.destination, @@ -234,13 +251,19 @@ func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string) } modelWarehouse.Destination.Config["useKeyPairAuth"] = true // Since we are currently only supporting key pair auth - sf, err := manager.New(whutils.SnowpipeStreaming, m.appConfig, m.logger, m.statsFactory) - if err != nil { - return nil, fmt.Errorf("creating snowflake manager: %w", err) - } - err = sf.Setup(ctx, modelWarehouse, whutils.NewNoOpUploader()) - if err != nil { - return nil, fmt.Errorf("setting up snowflake manager: %w", err) + return m.managerCreator(ctx, modelWarehouse, m.appConfig, m.logger, m.statsFactory) +} + +func (m *Manager) setAuthzErrorTime() { + m.config.lastestAuthzErrorTime = m.now() + if m.config.backoffDuration == 0 { + m.config.backoffDuration = m.config.initialBackoffDuration + } else { + m.config.backoffDuration = m.config.backoffDuration * 2 } - return sf, nil +} + +func (m *Manager) resetAuthzErrorTime() { + m.config.lastestAuthzErrorTime = time.Time{} + m.config.backoffDuration = 0 } diff --git 
a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go index fef146dfdb..11d2c3359b 100644 --- a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming.go @@ -4,6 +4,7 @@ import ( "bufio" "context" stdjson "encoding/json" + "errors" "fmt" "net/http" "os" @@ -31,6 +32,7 @@ import ( "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" + "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -38,13 +40,13 @@ var json = jsoniter.ConfigCompatibleWithStandardLibrary func New( conf *config.Config, - logger logger.Logger, + mLogger logger.Logger, statsFactory stats.Stats, destination *backendconfig.DestinationT, ) *Manager { m := &Manager{ appConfig: conf, - logger: logger.Child("snowpipestreaming").Withn( + logger: mLogger.Child("snowpipestreaming").Withn( obskit.WorkspaceID(destination.WorkspaceID), obskit.DestinationID(destination.ID), obskit.DestinationType(destination.DestinationDefinition.Name), @@ -67,6 +69,7 @@ func New( m.config.client.retryMax = conf.GetInt("SnowpipeStreaming.Client.retryMax", 5) m.config.instanceID = conf.GetString("INSTANCE_ID", "1") m.config.maxBufferCapacity = conf.GetReloadableInt64Var(512*bytesize.KB, bytesize.B, "SnowpipeStreaming.maxBufferCapacity") + m.config.initialBackoffDuration = conf.GetDuration("SnowpipeStreaming.backoffDuration", 1, time.Second) tags := stats.Tags{ "module": "batch_router", @@ -100,6 +103,17 @@ func New( snowpipeapi.New(m.appConfig, m.statsFactory, m.config.client.url, m.requestDoer), destination, ) + m.managerCreator = func(mCtx context.Context, 
modelWarehouse whutils.ModelWarehouse, conf *config.Config, logger logger.Logger, stats stats.Stats) (manager.Manager, error) { + sf, err := manager.New(whutils.SnowpipeStreaming, conf, logger, stats) + if err != nil { + return nil, fmt.Errorf("creating snowflake manager: %w", err) + } + err = sf.Setup(mCtx, modelWarehouse, whutils.NewNoOpUploader()) + if err != nil { + return nil, fmt.Errorf("setting up snowflake manager: %w", err) + } + return sf, nil + } return m } @@ -149,10 +163,23 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // backoff should be reset if authz error is not encountered for any of the tables + shouldResetBackoff := true discardsChannel, err := m.initializeChannelWithSchema(ctx, asyncDest.Destination.ID, &destConf, discardsTable(), discardsSchema()) if err != nil { - return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error()) + var authzErr *snowpipeAuthzError + if errors.As(err, &authzErr) { + // Ignoring this error so that the jobs are marked as failed and not aborted since + // we want these jobs to be retried the next time. 
+ m.logger.Warnn("Failed to initialize channel with schema", + logger.NewStringField("table", discardsTable()), + obskit.Error(err), + ) + shouldResetBackoff = false + } else { + return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error()) + } } m.logger.Infon("Prepared discards channel") @@ -187,6 +214,12 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU for _, info := range uploadInfos { imInfo, discardImInfo, err := m.sendEventsToSnowpipe(ctx, asyncDest.Destination.ID, &destConf, info) if err != nil { + var authzErr *snowpipeAuthzError + if errors.As(err, &authzErr) { + if shouldResetBackoff { + shouldResetBackoff = false + } + } m.logger.Warnn("Failed to send events to Snowpipe", logger.NewStringField("table", info.tableName), obskit.Error(err), @@ -206,6 +239,9 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU discardImportInfo.Offset = discardImInfo.Offset } } + if shouldResetBackoff { + m.resetAuthzErrorTime() + } if discardImportInfo != nil { importInfos = append(importInfos, discardImportInfo) } diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go index 912a546cb1..d11d35949b 100644 --- a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/snowpipestreaming_test.go @@ -2,8 +2,10 @@ package snowpipestreaming import ( "context" + "fmt" "net/http" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,7 +18,10 @@ import ( backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/jobsdb" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" + internalapi 
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/api" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" + "github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -43,6 +48,29 @@ func (m *mockAPI) GetStatus(_ context.Context, channelID string) (*model.StatusR return m.getStatusOutputMap[channelID]() } +type mockManager struct { + manager.Manager + throwSchemaErr bool +} + +func newMockManager(m manager.Manager, throwSchemaErr bool) *mockManager { + return &mockManager{ + Manager: m, + throwSchemaErr: throwSchemaErr, + } +} + +func (m *mockManager) CreateSchema(ctx context.Context) (err error) { + if m.throwSchemaErr { + return fmt.Errorf("failed to create schema") + } + return nil +} + +func (m *mockManager) CreateTable(ctx context.Context, tableName string, columnMap manager.ModelTableSchema) (err error) { + return nil +} + var ( usersChannelResponse = &model.ChannelResponse{ ChannelID: "test-users-channel", @@ -77,6 +105,7 @@ func TestSnowpipeStreaming(t *testing.T) { DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: "SNOWPIPE_STREAMING", }, + Config: make(map[string]interface{}), } t.Run("Upload with invalid file path", func(t *testing.T) { @@ -100,6 +129,7 @@ func TestSnowpipeStreaming(t *testing.T) { "status": "aborted", }).LastValue()) }) + t.Run("Upload with invalid record in file", func(t *testing.T) { statsStore, err := memstats.New() require.NoError(t, err) @@ -310,6 +340,96 @@ func TestSnowpipeStreaming(t *testing.T) { "status": "failed", }).LastValue()) }) + + t.Run("Upload with unauthorized schema error should add backoff", func(t *testing.T) { + statsStore, err := memstats.New() + require.NoError(t, err) + + sm := New(config.New(), logger.NOP, statsStore, 
destination) + sm.channelCache.Store("RUDDER_DISCARDS", rudderDiscardsChannelResponse) + sm.api = &mockAPI{ + createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){ + "USERS": func() (*model.ChannelResponse, error) { + return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil + }, + }, + } + managerCreatorCallCount := 0 + sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) { + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + managerCreatorCallCount++ + return newMockManager(sm, true), nil + } + sm.config.initialBackoffDuration = time.Second * 10 + asyncDestStruct := &common.AsyncDestinationStruct{ + Destination: destination, + FileName: "testdata/successful_user_records.txt", + } + output1 := sm.Upload(asyncDestStruct) + require.Equal(t, 2, output1.FailedCount) + require.Equal(t, 0, output1.AbortCount) + require.Equal(t, 1, managerCreatorCallCount) + require.Equal(t, time.Second*10, sm.config.backoffDuration) + require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero()) + + sm.Upload(asyncDestStruct) + // client is not created again due to backoff error + require.Equal(t, 1, managerCreatorCallCount) + require.Equal(t, time.Second*10, sm.config.backoffDuration) + require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero()) + + sm.now = func() time.Time { + return time.Now().UTC().Add(time.Second * 100) + } + + sm.Upload(asyncDestStruct) + // client created again since backoff duration has been exceeded + require.Equal(t, 2, managerCreatorCallCount) + require.Equal(t, time.Second*20, sm.config.backoffDuration) + require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero()) + + sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) { + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + 
managerCreatorCallCount++ + return newMockManager(sm, false), nil + } + sm.now = func() time.Time { + return time.Now().UTC().Add(time.Second * 200) + } + sm.Upload(asyncDestStruct) + require.Equal(t, 3, managerCreatorCallCount) + // no error should reset the backoff config + require.Equal(t, time.Duration(0), sm.config.backoffDuration) + require.Equal(t, true, sm.config.lastestAuthzErrorTime.IsZero()) + }) + + t.Run("Upload with discards table authorization error should not abort the job", func(t *testing.T) { + statsStore, err := memstats.New() + require.NoError(t, err) + + sm := New(config.New(), logger.NOP, statsStore, destination) + sm.api = &mockAPI{ + createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){ + "RUDDER_DISCARDS": func() (*model.ChannelResponse, error) { + return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil + }, + "USERS": func() (*model.ChannelResponse, error) { + return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil + }, + }, + } + sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) { + sm := snowflake.New(config.New(), logger.NOP, stats.NOP) + return newMockManager(sm, true), nil + } + output := sm.Upload(&common.AsyncDestinationStruct{ + Destination: destination, + FileName: "testdata/successful_user_records.txt", + }) + require.Equal(t, 2, output.FailedCount) + require.Equal(t, 0, output.AbortCount) + }) + t.Run("Upload insert error for all events", func(t *testing.T) { statsStore, err := memstats.New() require.NoError(t, err) diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/successful_user_records.txt b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/successful_user_records.txt new file mode 100644 index 0000000000..007ae8bafe --- /dev/null +++ 
b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/testdata/successful_user_records.txt @@ -0,0 +1,2 @@ +{"message":{"metadata":{"table":"USERS","columns":{"ID":"int","NAME":"string","AGE":"int","RECEIVED_AT":"datetime"}},"data":{"ID":1,"NAME":"Alice","AGE":30,"RECEIVED_AT":"2023-05-12T04:36:50.199Z"}},"metadata":{"job_id":1001}} +{"message":{"metadata":{"table":"USERS","columns":{"ID":"int","NAME":"string","AGE":"int","RECEIVED_AT":"datetime"}},"data":{"ID":1,"NAME":"Alice","AGE":30,"RECEIVED_AT":"2023-05-12T04:36:50.199Z"}},"metadata":{"job_id":1003}} diff --git a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/types.go b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/types.go index 3376f6104a..09ef7f15b5 100644 --- a/router/batchrouter/asyncdestinationmanager/snowpipestreaming/types.go +++ b/router/batchrouter/asyncdestinationmanager/snowpipestreaming/types.go @@ -15,6 +15,7 @@ import ( "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model" backendconfig "github.com/rudderlabs/rudder-server/backend-config" + "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" whutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -25,6 +26,7 @@ type ( statsFactory stats.Stats destination *backendconfig.DestinationT requestDoer requestDoer + managerCreator func(mCtx context.Context, modelWarehouse whutils.ModelWarehouse, conf *config.Config, mLogger logger.Logger, stats stats.Stats) (manager.Manager, error) now func() time.Time api api channelCache sync.Map @@ -44,6 +46,12 @@ } instanceID string maxBufferCapacity config.ValueLoader[int64] + // time at which an attempt was made to create a resource but it failed, likely due to permission issues. lastestAuthzErrorTime time.Time + // If lastestAuthzErrorTime is not zero, then the next attempt to create a SF connection will be made after backoffDuration.
+ // This approach prevents repeatedly activating the warehouse even though the permission issue remains unresolved. + backoffDuration time.Duration + initialBackoffDuration time.Duration } stats struct { diff --git a/warehouse/integrations/manager/manager.go b/warehouse/integrations/manager/manager.go index 59aeeaa34d..b91006aae2 100644 --- a/warehouse/integrations/manager/manager.go +++ b/warehouse/integrations/manager/manager.go @@ -26,6 +26,9 @@ import ( warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) +// Creating an alias since "model.TableSchema" is defined in an internal module +type ModelTableSchema = model.TableSchema + type Manager interface { Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) error FetchSchema(ctx context.Context) (model.Schema, error)