Skip to content

Commit

Permalink
feat: implement backoff for snowpipe streaming authorization errors
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhar-rudder committed Dec 27, 2024
1 parent 8a186e5 commit 476cdb4
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package snowpipestreaming
import (
"context"
"fmt"
"time"

"github.com/rudderlabs/rudder-go-kit/logger"

Expand All @@ -12,6 +13,14 @@ import (
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

type snowpipeAuthzError struct {
err error
}

func (sae *snowpipeAuthzError) Error() string {
return sae.err.Error()
}

// initializeChannelWithSchema creates a new channel for the given table if it doesn't exist.
// If the channel already exists, it checks for new columns and adds them to the table.
// It returns the channel response after creating or recreating the channel.
Expand Down Expand Up @@ -66,7 +75,8 @@ func (m *Manager) addColumns(ctx context.Context, namespace, tableName string, c
snowflakeManager.Cleanup(ctx)
}()
if err = snowflakeManager.AddColumns(ctx, tableName, columns); err != nil {
return fmt.Errorf("adding column: %w", err)
m.setAuthzErrorTime()
return &snowpipeAuthzError{fmt.Errorf("adding column: %w", err)}
}
return nil
}
Expand Down Expand Up @@ -157,10 +167,12 @@ func (m *Manager) handleSchemaError(
snowflakeManager.Cleanup(ctx)
}()
if err := snowflakeManager.CreateSchema(ctx); err != nil {
return nil, fmt.Errorf("creating schema: %w", err)
m.setAuthzErrorTime()
return nil, &snowpipeAuthzError{fmt.Errorf("creating schema: %w", err)}
}
if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil {
return nil, fmt.Errorf("creating table: %w", err)
m.setAuthzErrorTime()
return nil, &snowpipeAuthzError{fmt.Errorf("creating table: %w", err)}
}
return m.api.CreateChannel(ctx, channelReq)
}
Expand All @@ -185,7 +197,8 @@ func (m *Manager) handleTableError(
snowflakeManager.Cleanup(ctx)
}()
if err := snowflakeManager.CreateTable(ctx, channelReq.TableConfig.Table, eventSchema); err != nil {
return nil, fmt.Errorf("creating table: %w", err)
m.setAuthzErrorTime()
return nil, &snowpipeAuthzError{fmt.Errorf("creating table: %w", err)}
}
return m.api.CreateChannel(ctx, channelReq)
}
Expand Down Expand Up @@ -225,6 +238,10 @@ func (m *Manager) deleteChannel(ctx context.Context, tableName, channelID string
}

func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string) (manager.Manager, error) {
nextBackoffTime := m.config.lastestAuthzErrorTime.Add(m.config.backoffDuration)
if m.now().Before(nextBackoffTime) {
return nil, &snowpipeAuthzError{fmt.Errorf("skipping snowflake manager creation due to backoff")}
}
modelWarehouse := whutils.ModelWarehouse{
WorkspaceID: m.destination.WorkspaceID,
Destination: *m.destination,
Expand All @@ -234,13 +251,19 @@ func (m *Manager) createSnowflakeManager(ctx context.Context, namespace string)
}
modelWarehouse.Destination.Config["useKeyPairAuth"] = true // Since we are currently only supporting key pair auth

sf, err := manager.New(whutils.SnowpipeStreaming, m.appConfig, m.logger, m.statsFactory)
if err != nil {
return nil, fmt.Errorf("creating snowflake manager: %w", err)
}
err = sf.Setup(ctx, modelWarehouse, whutils.NewNoOpUploader())
if err != nil {
return nil, fmt.Errorf("setting up snowflake manager: %w", err)
return m.managerCreator(ctx, modelWarehouse, m.appConfig, m.logger, m.statsFactory)
}

func (m *Manager) setAuthzErrorTime() {
m.config.lastestAuthzErrorTime = m.now()
if m.config.backoffDuration == 0 {
m.config.backoffDuration = m.config.initialBackoffDuration
} else {
m.config.backoffDuration = m.config.backoffDuration * 2
}
return sf, nil
}

// resetAuthzErrorTime clears the authorization-error backoff state once an
// upload cycle completes without hitting any authz errors, so the next call
// to createSnowflakeManager is not delayed by a stale backoff window.
func (m *Manager) resetAuthzErrorTime() {
	m.config.lastestAuthzErrorTime = time.Time{}
	m.config.backoffDuration = 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"context"
stdjson "encoding/json"
"errors"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -31,20 +32,21 @@ import (
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/timeutil"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

var json = jsoniter.ConfigCompatibleWithStandardLibrary

func New(
conf *config.Config,
logger logger.Logger,
mLogger logger.Logger,
statsFactory stats.Stats,
destination *backendconfig.DestinationT,
) *Manager {
m := &Manager{
appConfig: conf,
logger: logger.Child("snowpipestreaming").Withn(
logger: mLogger.Child("snowpipestreaming").Withn(
obskit.WorkspaceID(destination.WorkspaceID),
obskit.DestinationID(destination.ID),
obskit.DestinationType(destination.DestinationDefinition.Name),
Expand All @@ -67,6 +69,7 @@ func New(
m.config.client.retryMax = conf.GetInt("SnowpipeStreaming.Client.retryMax", 5)
m.config.instanceID = conf.GetString("INSTANCE_ID", "1")
m.config.maxBufferCapacity = conf.GetReloadableInt64Var(512*bytesize.KB, bytesize.B, "SnowpipeStreaming.maxBufferCapacity")
m.config.initialBackoffDuration = conf.GetDuration("SnowpipeStreaming.backoffDuration", 1, time.Second)

tags := stats.Tags{
"module": "batch_router",
Expand Down Expand Up @@ -100,6 +103,17 @@ func New(
snowpipeapi.New(m.appConfig, m.statsFactory, m.config.client.url, m.requestDoer),
destination,
)
m.managerCreator = func(mCtx context.Context, modelWarehouse whutils.ModelWarehouse, conf *config.Config, logger logger.Logger, stats stats.Stats) (manager.Manager, error) {
sf, err := manager.New(whutils.SnowpipeStreaming, conf, logger, stats)
if err != nil {
return nil, fmt.Errorf("creating snowflake manager: %w", err)
}
err = sf.Setup(mCtx, modelWarehouse, whutils.NewNoOpUploader())
if err != nil {
return nil, fmt.Errorf("setting up snowflake manager: %w", err)
}
return sf, nil
}
return m
}

Expand Down Expand Up @@ -149,10 +163,23 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// backoff should be reset if authz error is not encountered for any of the tables
shouldResetBackoff := true

discardsChannel, err := m.initializeChannelWithSchema(ctx, asyncDest.Destination.ID, &destConf, discardsTable(), discardsSchema())
if err != nil {
return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error())
var authzErr *snowpipeAuthzError
if errors.As(err, &authzErr) {
// Ignoring this error so that the jobs are marked as failed and not aborted since
// we want these jobs to be retried the next time.
m.logger.Warnn("Failed to initialize channel with schema",
logger.NewStringField("table", discardsTable()),
obskit.Error(err),
)
shouldResetBackoff = false
} else {
return m.abortJobs(asyncDest, fmt.Errorf("failed to prepare discards channel: %w", err).Error())
}
}
m.logger.Infon("Prepared discards channel")

Expand Down Expand Up @@ -187,6 +214,12 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
for _, info := range uploadInfos {
imInfo, discardImInfo, err := m.sendEventsToSnowpipe(ctx, asyncDest.Destination.ID, &destConf, info)
if err != nil {
var authzErr *snowpipeAuthzError
if errors.As(err, &authzErr) {
if shouldResetBackoff {
shouldResetBackoff = false
}
}
m.logger.Warnn("Failed to send events to Snowpipe",
logger.NewStringField("table", info.tableName),
obskit.Error(err),
Expand All @@ -206,6 +239,9 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
discardImportInfo.Offset = discardImInfo.Offset
}
}
if shouldResetBackoff {
m.resetAuthzErrorTime()
}
if discardImportInfo != nil {
importInfos = append(importInfos, discardImportInfo)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package snowpipestreaming

import (
"context"
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -16,7 +18,10 @@ import (
backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/jobsdb"
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common"
internalapi "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/api"
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
"github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

Expand All @@ -43,6 +48,29 @@ func (m *mockAPI) GetStatus(_ context.Context, channelID string) (*model.StatusR
return m.getStatusOutputMap[channelID]()
}

// mockManager embeds a real warehouse manager and lets tests force
// CreateSchema to fail, simulating a Snowflake authorization error.
type mockManager struct {
	manager.Manager
	// throwSchemaErr makes CreateSchema return an error when true.
	throwSchemaErr bool
}

// newMockManager wraps m in a mockManager; when throwSchemaErr is true the
// returned manager's CreateSchema fails.
func newMockManager(m manager.Manager, throwSchemaErr bool) *mockManager {
	return &mockManager{
		Manager: m,
		throwSchemaErr: throwSchemaErr,
	}
}

// CreateSchema pretends to create the schema; it fails only when the mock
// was configured to throw a schema error.
func (m *mockManager) CreateSchema(ctx context.Context) (err error) {
	if !m.throwSchemaErr {
		return nil
	}
	return fmt.Errorf("failed to create schema")
}

// CreateTable is a no-op stub so channel creation can proceed in tests.
func (m *mockManager) CreateTable(ctx context.Context, tableName string, columnMap manager.ModelTableSchema) (err error) {
	return nil
}

var (
usersChannelResponse = &model.ChannelResponse{
ChannelID: "test-users-channel",
Expand Down Expand Up @@ -77,6 +105,7 @@ func TestSnowpipeStreaming(t *testing.T) {
DestinationDefinition: backendconfig.DestinationDefinitionT{
Name: "SNOWPIPE_STREAMING",
},
Config: make(map[string]interface{}),
}

t.Run("Upload with invalid file path", func(t *testing.T) {
Expand All @@ -100,6 +129,7 @@ func TestSnowpipeStreaming(t *testing.T) {
"status": "aborted",
}).LastValue())
})

t.Run("Upload with invalid record in file", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)
Expand Down Expand Up @@ -310,6 +340,96 @@ func TestSnowpipeStreaming(t *testing.T) {
"status": "failed",
}).LastValue())
})

t.Run("Upload with unauthorized schema error should add backoff", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)

sm := New(config.New(), logger.NOP, statsStore, destination)
sm.channelCache.Store("RUDDER_DISCARDS", rudderDiscardsChannelResponse)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"USERS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
managerCreatorCallCount := 0
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sm := snowflake.New(config.New(), logger.NOP, stats.NOP)
managerCreatorCallCount++
return newMockManager(sm, true), nil
}
sm.config.initialBackoffDuration = time.Second * 10
asyncDestStruct := &common.AsyncDestinationStruct{
Destination: destination,
FileName: "testdata/successful_user_records.txt",
}
output1 := sm.Upload(asyncDestStruct)
require.Equal(t, 2, output1.FailedCount)
require.Equal(t, 0, output1.AbortCount)
require.Equal(t, 1, managerCreatorCallCount)
require.Equal(t, time.Second*10, sm.config.backoffDuration)
require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero())

sm.Upload(asyncDestStruct)
// client is not created again due to backoff error
require.Equal(t, 1, managerCreatorCallCount)
require.Equal(t, time.Second*10, sm.config.backoffDuration)
require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero())

sm.now = func() time.Time {
return time.Now().UTC().Add(time.Second * 100)
}

sm.Upload(asyncDestStruct)
// client created again since backoff duration has been exceeded
require.Equal(t, 2, managerCreatorCallCount)
require.Equal(t, time.Second*20, sm.config.backoffDuration)
require.Equal(t, false, sm.config.lastestAuthzErrorTime.IsZero())

sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sm := snowflake.New(config.New(), logger.NOP, stats.NOP)
managerCreatorCallCount++
return newMockManager(sm, false), nil
}
sm.now = func() time.Time {
return time.Now().UTC().Add(time.Second * 200)
}
sm.Upload(asyncDestStruct)
require.Equal(t, 3, managerCreatorCallCount)
// no error should reset the backoff config
require.Equal(t, time.Duration(0), sm.config.backoffDuration)
require.Equal(t, true, sm.config.lastestAuthzErrorTime.IsZero())
})

t.Run("Upload with discards table authorization error should not abort the job", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)

sm := New(config.New(), logger.NOP, statsStore, destination)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"RUDDER_DISCARDS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
"USERS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sm := snowflake.New(config.New(), logger.NOP, stats.NOP)
return newMockManager(sm, true), nil
}
output := sm.Upload(&common.AsyncDestinationStruct{
Destination: destination,
FileName: "testdata/successful_user_records.txt",
})
require.Equal(t, 2, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
})

t.Run("Upload insert error for all events", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"message":{"metadata":{"table":"USERS","columns":{"ID":"int","NAME":"string","AGE":"int","RECEIVED_AT":"datetime"}},"data":{"ID":1,"NAME":"Alice","AGE":30,"RECEIVED_AT":"2023-05-12T04:36:50.199Z"}},"metadata":{"job_id":1001}}
{"message":{"metadata":{"table":"USERS","columns":{"ID":"int","NAME":"string","AGE":"int","RECEIVED_AT":"datetime"}},"data":{"ID":1,"NAME":"Alice","AGE":30,"RECEIVED_AT":"2023-05-12T04:36:50.199Z"}},"metadata":{"job_id":1003}}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/snowpipestreaming/internal/model"

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

Expand All @@ -25,6 +26,7 @@ type (
statsFactory stats.Stats
destination *backendconfig.DestinationT
requestDoer requestDoer
managerCreator func(mCtx context.Context, modelWarehouse whutils.ModelWarehouse, conf *config.Config, mLogger logger.Logger, stats stats.Stats) (manager.Manager, error)
now func() time.Time
api api
channelCache sync.Map
Expand All @@ -44,6 +46,12 @@ type (
}
instanceID string
maxBufferCapacity config.ValueLoader[int64]
// Time at which an attempt was made to create a resource but failed, likely due to permission issues.
lastestAuthzErrorTime time.Time
// If lastestAuthzErrorTime is not zero, the next attempt to create a SF connection will be made only after backoffDuration has elapsed.
// This approach prevents repeatedly activating the warehouse even though the permission issue remains unresolved.
backoffDuration time.Duration
initialBackoffDuration time.Duration
}

stats struct {
Expand Down
Loading

0 comments on commit 476cdb4

Please sign in to comment.