diff --git a/events-processor/config/kafka/consumer.go b/events-processor/config/kafka/consumer.go
index 2c440e994..aa055f1f0 100644
--- a/events-processor/config/kafka/consumer.go
+++ b/events-processor/config/kafka/consumer.go
@@ -17,7 +17,7 @@ import (
 type ConsumerGroupConfig struct {
 	Topic          string
 	ConsumerGroup  string
-	ProcessRecords func([]*kgo.Record) []*kgo.Record
+	ProcessRecords func(context.Context, []*kgo.Record) []*kgo.Record
 }
 
 type TopicPartition struct {
@@ -34,17 +34,17 @@ type PartitionConsumer struct {
 	quit           chan struct{}
 	done           chan struct{}
 	records        chan []*kgo.Record
-	processRecords func([]*kgo.Record) []*kgo.Record
+	processRecords func(context.Context, []*kgo.Record) []*kgo.Record
 }
 
 type ConsumerGroup struct {
 	consumers      map[TopicPartition]*PartitionConsumer
 	client         *kgo.Client
-	processRecords func([]*kgo.Record) []*kgo.Record
+	processRecords func(context.Context, []*kgo.Record) []*kgo.Record
 	logger         *slog.Logger
 }
 
-func (pc *PartitionConsumer) consume() {
+func (pc *PartitionConsumer) consume(ctx context.Context) {
 	defer close(pc.done)
 	pc.logger.Info(fmt.Sprintf("Starting consume for topic %s partition %d\n", pc.topic, pc.partition))
 
@@ -56,25 +56,32 @@ func (pc *PartitionConsumer) consume() {
 			pc.logger.Info("partition consumer quit")
 			return
 
+		case <-ctx.Done():
+			pc.logger.Info("partition consumer context canceled")
+			return
+
 		case records := <-pc.records:
-			ctx := context.Background()
 			span := tracer.GetTracerSpan(ctx, "post_process", "Consumer.Consume")
 			recordsAttr := attribute.Int("records.length", len(records))
 			span.SetAttributes(recordsAttr)
 			defer span.End()
 
-			processedRecords := pc.processRecords(records)
+			processedRecords := pc.processRecords(ctx, records)
 
 			commitableRecords := records
 			if len(processedRecords) != len(records) {
 				// Ensure we are not committing records that were not processed and can be re-consumed
 				record := findMaxCommitableRecord(processedRecords, records)
 				commitableRecords = []*kgo.Record{record}
-				return
 			}
 
 			err := pc.client.CommitRecords(ctx, commitableRecords...)
 			if err != nil {
+				if ctx.Err() != nil {
+					pc.logger.Info("Commit canceled due to shutdown")
+					return
+				}
+
 				pc.logger.Error(fmt.Sprintf("Error when committing offets to kafka. Error: %v topic: %s partition: %d offset: %d\n", err, pc.topic, pc.partition, records[len(records)-1].Offset+1))
 				utils.CaptureError(err)
 			}
@@ -82,7 +89,7 @@ func (pc *PartitionConsumer) consume() {
 	}
 }
 
-func (cg *ConsumerGroup) assigned(_ context.Context, cl *kgo.Client, assigned map[string][]int32) {
+func (cg *ConsumerGroup) assigned(ctx context.Context, cl *kgo.Client, assigned map[string][]int32) {
 	for topic, partitions := range assigned {
 		for _, partition := range partitions {
 			pc := &PartitionConsumer{
@@ -97,7 +104,7 @@ func (cg *ConsumerGroup) assigned(_ context.Context, cl *kgo.Client, assigned ma
 				processRecords: cg.processRecords,
 			}
 			cg.consumers[TopicPartition{topic: topic, partition: partition}] = pc
-			go pc.consume()
+			go pc.consume(ctx)
 		}
 	}
 }
@@ -120,25 +127,73 @@ func (cg *ConsumerGroup) lost(_ context.Context, _ *kgo.Client, lost map[string]
 	}
 }
 
-func (cg *ConsumerGroup) poll() {
+func (cg *ConsumerGroup) poll(ctx context.Context, done chan<- error) {
+	defer func() {
+		if r := recover(); r != nil {
+			cg.logger.Error("Consumer group poll panic", slog.Any("panic", r))
+			done <- fmt.Errorf("consumer group poll panic: %v", r)
+		}
+	}()
+
 	for {
-		fetches := cg.client.PollRecords(context.Background(), 10000)
-		if fetches.IsClientClosed() {
-			cg.logger.Info("client closed")
+		select {
+		case <-ctx.Done():
+			cg.logger.Info("Consumer group stopped")
 			return
+
+		default:
+			fetches := cg.client.PollRecords(ctx, 10000)
+			if fetches.IsClientClosed() {
+				cg.logger.Info("client closed")
+				return
+			}
+
+			if ctx.Err() != nil {
+				return
+			}
+
+			fetches.EachError(func(_ string, _ int32, err error) {
+				cg.logger.Error("Fetch error", slog.String("error", err.Error()))
+				done <- err
+			})
+
+			fetches.EachPartition(func(p kgo.FetchTopicPartition) {
+				tp := TopicPartition{p.Topic, p.Partition}
+				if consumer, exists := cg.consumers[tp]; exists {
+					select {
+					case consumer.records <- p.Records:
+					case <-ctx.Done():
+						return
+					}
+				}
+			})
+
+			cg.client.AllowRebalance()
 		}
+	}
+}
+
+func (cg *ConsumerGroup) gracefulShutdown() {
+	var wg sync.WaitGroup
 
-		fetches.EachError(func(_ string, _ int32, err error) {
-			panic(err)
-		})
+	for tp, pc := range cg.consumers {
+		wg.Add(1)
 
-		fetches.EachPartition(func(p kgo.FetchTopicPartition) {
-			tp := TopicPartition{p.Topic, p.Partition}
-			cg.consumers[tp].records <- p.Records
-		})
+		go func(tp TopicPartition, pc *PartitionConsumer) {
+			defer wg.Done()
 
-		cg.client.AllowRebalance()
+			cg.logger.Info("Shutting down partition consumer",
+				slog.String("topic", tp.topic),
+				slog.Int("partition", int(tp.partition)),
+			)
+
+			close(pc.quit)
+			<-pc.done
+		}(tp, pc)
 	}
+
+	wg.Wait()
+	cg.client.Close()
 }
 
 func NewConsumerGroup(serverConfig ServerConfig, cfg *ConsumerGroupConfig) (*ConsumerGroup, error) {
@@ -175,8 +230,32 @@ func NewConsumerGroup(serverConfig ServerConfig, cfg *ConsumerGroupConfig) (*Con
 	return cg, nil
 }
 
-func (cg *ConsumerGroup) Start() {
-	cg.poll()
+func (cg *ConsumerGroup) Start(ctx context.Context) error {
+	pollCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		defer close(done)
+		cg.poll(pollCtx, done)
+	}()
+
+	select {
+	case <-ctx.Done():
+		cg.logger.Info("Gracefully shutting down consumer group")
+		cancel()
+
+		cg.gracefulShutdown()
+
+		cg.logger.Info("Consumer group shutdown is complete")
+		return ctx.Err()
+
+	case err := <-done:
+		if err != nil {
+			cg.logger.Error("Consumer group stopped with error", slog.String("error", err.Error()))
+		}
+		return err
+	}
 }
 
 func findMaxCommitableRecord(processedRecords []*kgo.Record, records []*kgo.Record) *kgo.Record {
diff --git a/events-processor/processors/events_processor/processor.go b/events-processor/processors/events_processor/processor.go
index 7d094ed0c..90101f170 100644
--- a/events-processor/processors/events_processor/processor.go
+++ b/events-processor/processors/events_processor/processor.go
@@ -33,8 +33,15 @@ func NewEventProcessor(logger *slog.Logger, enrichmentService *EventEnrichmentSe
 	}
 }
 
-func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Record {
-	ctx := context.Background()
+func (processor *EventProcessor) ProcessEvents(ctx context.Context, records []*kgo.Record) []*kgo.Record {
+	// Handle graceful shutdown
+	select {
+	case <-ctx.Done():
+		processor.logger.Info("Ongoing shutdown. Stop processing new events")
+		return nil
+	default:
+	}
+
 	span := tracer.GetTracerSpan(ctx, "post_process", "PostProcess.ProcessEvents")
 	recordsAttr := attribute.Int("records.length", len(records))
 	span.SetAttributes(recordsAttr)
@@ -43,6 +50,8 @@ func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Rec
 	wg := sync.WaitGroup{}
 	wg.Add(len(records))
 
+	producersWg := sync.WaitGroup{}
+
 	var mu sync.Mutex
 	processedRecords := make([]*kgo.Record, 0)
 
@@ -50,6 +59,14 @@ func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Rec
 		go func(record *kgo.Record) {
 			defer wg.Done()
 
+			// Check if a shutdown process is ongoing
+			select {
+			case <-ctx.Done():
+				processor.logger.Info("Ongoing shutdown. Stop processing event")
+				return
+			default:
+			}
+
 			sp := tracer.GetTracerSpan(ctx, "post_process", "PostProcess.ProcessOneEvent")
 			defer sp.End()
 
@@ -66,7 +83,7 @@ func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Rec
 				return
 			}
 
-			result := processor.processEvent(ctx, &event)
+			result := processor.processEvent(ctx, &event, &producersWg)
 			if result.Failure() {
 				processor.logger.Error(
 					result.ErrorMessage(),
@@ -86,7 +103,11 @@ func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Rec
 				}
 
 				// Push failed records to the dead letter queue
-				go processor.ProducerService.ProduceToDeadLetterQueue(ctx, event, result)
+				producersWg.Add(1)
+				go func() {
+					defer producersWg.Done()
+					processor.ProducerService.ProduceToDeadLetterQueue(ctx, event, result)
+				}()
 			}
 
 			// Track processed records
@@ -98,10 +119,13 @@ func (processor *EventProcessor) ProcessEvents(records []*kgo.Record) []*kgo.Rec
 
 	wg.Wait()
 
+	// Wait for all producer goroutines to complete.
+	processor.waitForProducers(ctx, &producersWg)
+
 	return processedRecords
 }
 
-func (processor *EventProcessor) processEvent(ctx context.Context, event *models.Event) utils.Result[*models.EnrichedEvent] {
+func (processor *EventProcessor) processEvent(ctx context.Context, event *models.Event, producersWg *sync.WaitGroup) utils.Result[*models.EnrichedEvent] {
 	enrichedEventResult := processor.EnrichmentService.EnrichEvent(event)
 	if enrichedEventResult.Failure() {
 		return failedResult(enrichedEventResult, enrichedEventResult.ErrorCode(), enrichedEventResult.ErrorMessage())
@@ -110,11 +134,23 @@ func (processor *EventProcessor) processEvent(ctx context.Context, event *models
 	enrichedEvents := enrichedEventResult.Value()
 	enrichedEvent := enrichedEvents[0]
 
-	go processor.ProducerService.ProduceEnrichedEvent(ctx, enrichedEvent)
+	processor.trackProducer(
+		ctx,
+		producersWg,
+		func() {
+			processor.ProducerService.ProduceEnrichedEvent(ctx, enrichedEvent)
+		},
+	)
 
 	// TODO(pre-aggregation): Uncomment to enable the feature
 	// for _, ev := range enrichedEvents {
-	// 	go processor.ProducerService.ProduceEnrichedExpendedEvent(ctx, ev)
+	// 	processor.trackProducer(
+	// 		ctx,
+	// 		producersWg,
+	// 		func() {
+	// 			processor.ProducerService.ProduceEnrichedExpendedEvent(ctx, ev)
+	// 		},
+	// 	)
 	// }
 
 	if enrichedEvent.Subscription != nil && event.NotAPIPostProcessed() {
@@ -127,7 +163,13 @@ func (processor *EventProcessor) processEvent(ctx context.Context, event *models
 	}
 
 	if payInAdvance {
-		go processor.ProducerService.ProduceChargedInAdvanceEvent(ctx, enrichedEvent)
+		processor.trackProducer(
+			ctx,
+			producersWg,
+			func() {
+				processor.ProducerService.ProduceChargedInAdvanceEvent(ctx, enrichedEvent)
+			},
+		)
 	}
 
 	flagResult := processor.RefreshService.FlagSubscriptionRefresh(enrichedEvent)
@@ -142,6 +184,38 @@ func (processor *EventProcessor) processEvent(ctx context.Context, event *models
 	return utils.SuccessResult(enrichedEvent)
 }
 
+func (processor *EventProcessor) trackProducer(ctx context.Context, producerWg *sync.WaitGroup, routine func()) {
+	producerWg.Add(1)
+	go func() {
+		defer producerWg.Done()
+
+		select {
+		case <-ctx.Done():
+			processor.logger.Debug("Shutdown signal received, skipping producer")
+			return
+		default:
+			routine()
+		}
+	}()
+}
+
+func (processor *EventProcessor) waitForProducers(ctx context.Context, producersWg *sync.WaitGroup) {
+	done := make(chan struct{})
+	go func() {
+		producersWg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		processor.logger.Debug("All producer goroutines completed successfully")
+	case <-time.After(30 * time.Second): // TODO: make this timeout configurable
+		processor.logger.Warn("Timeout waiting for producer goroutines to complete")
+	case <-ctx.Done():
+		processor.logger.Info("Shutdown signal received while waiting for producer goroutines")
+	}
+}
+
 func failedResult(r utils.AnyResult, code string, message string) utils.Result[*models.EnrichedEvent] {
 	result := utils.FailedResult[*models.EnrichedEvent](r.Error()).AddErrorDetails(code, message)
 	result.Retryable = r.IsRetryable()
diff --git a/events-processor/processors/events_processor/processor_test.go b/events-processor/processors/events_processor/processor_test.go
index b1934420f..12dcd19a5 100644
--- a/events-processor/processors/events_processor/processor_test.go
+++ b/events-processor/processors/events_processor/processor_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"log/slog"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
@@ -135,6 +136,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, _, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		event := models.Event{
 			OrganizationID:         "1a901a90-1a90-1a90-1a90-1a901a901a90",
 			ExternalSubscriptionID: "sub_id",
@@ -144,7 +147,7 @@ func TestProcessEvent(t *testing.T) {
 
 		mockedStore.SQLMock.ExpectQuery(".*").WillReturnError(gorm.ErrRecordNotFound)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 		assert.False(t, result.Success())
 		assert.Equal(t, "record not found", result.ErrorMsg())
 		assert.Equal(t, "fetch_billable_metric", result.ErrorCode())
@@ -155,6 +158,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, testProducers, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		properties := map[string]any{
 			"api_requests": "12.0",
 		}
@@ -186,7 +191,7 @@ func TestProcessEvent(t *testing.T) {
 		sub := models.Subscription{ID: "sub123", PlanID: "plan123"}
 		mockSubscriptionLookup(mockedStore, &sub)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 
 		assert.True(t, result.Success())
 		assert.Equal(t, "12.0", *result.Value().Value)
@@ -205,6 +210,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, _, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		event := models.Event{
 			OrganizationID:         "1a901a90-1a90-1a90-1a90-1a901a901a90",
 			ExternalSubscriptionID: "sub_id",
@@ -225,7 +232,7 @@ func TestProcessEvent(t *testing.T) {
 		}
 		mockBmLookup(mockedStore, &bm)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 		assert.False(t, result.Success())
 		assert.Equal(t, "strconv.ParseFloat: parsing \"2025-03-06T12:00:00Z\": invalid syntax", result.ErrorMsg())
 		assert.Equal(t, "build_enriched_event", result.ErrorCode())
@@ -236,6 +243,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, _, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		event := models.Event{
 			OrganizationID:         "1a901a90-1a90-1a90-1a90-1a901a901a90",
 			ExternalSubscriptionID: "sub_id",
@@ -258,7 +267,7 @@ func TestProcessEvent(t *testing.T) {
 
 		mockedStore.SQLMock.ExpectQuery(".* FROM \"subscriptions\"").WillReturnError(gorm.ErrRecordNotFound)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 		assert.True(t, result.Success())
 	})
 
@@ -266,6 +275,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, _, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		event := models.Event{
 			OrganizationID:         "1a901a90-1a90-1a90-1a90-1a901a901a90",
 			ExternalSubscriptionID: "sub_id",
@@ -288,7 +299,7 @@ func TestProcessEvent(t *testing.T) {
 
 		mockedStore.SQLMock.ExpectQuery(".* FROM \"subscriptions\"").WillReturnError(gorm.ErrNotImplemented)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 		assert.False(t, result.Success())
 		assert.NotNil(t, result.ErrorMsg())
 		assert.Equal(t, "fetch_subscription", result.ErrorCode())
@@ -299,6 +310,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, _, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		// properties := map[string]any{
 		// 	"value": "12.12",
 		// }
@@ -327,7 +340,7 @@ func TestProcessEvent(t *testing.T) {
 		sub := models.Subscription{ID: "sub123"}
 		mockSubscriptionLookup(mockedStore, &sub)
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 		assert.False(t, result.Success())
 		assert.Contains(t, result.ErrorMsg(), "Failed to evaluate expr: round(event.properties.value)")
 		assert.Equal(t, "evaluate_expression", result.ErrorCode())
@@ -338,6 +351,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, testProducers, flagger, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		properties := map[string]any{
 			"value": "12.12",
 		}
@@ -379,7 +394,7 @@ func TestProcessEvent(t *testing.T) {
 			},
 		})
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 
 		assert.True(t, result.Success())
 		assert.Equal(t, "12", *result.Value().Value)
@@ -397,6 +412,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, testProducers, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		properties := map[string]any{
 			"api_requests": "12.0",
 		}
@@ -450,7 +467,7 @@ func TestProcessEvent(t *testing.T) {
 		}
 		mockFlatFiltersLookup(mockedStore, []*models.FlatFilter{flatFilter1, flatFilter2})
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 
 		assert.True(t, result.Success())
 		assert.Equal(t, "12.0", *result.Value().Value)
@@ -469,6 +486,8 @@ func TestProcessEvent(t *testing.T) {
 		processor, mockedStore, testProducers, _, delete := setupProcessorTestEnv(t)
 		defer delete()
 
+		wg := &sync.WaitGroup{}
+
 		properties := map[string]any{
 			"api_requests": "12.0",
 		}
@@ -498,7 +517,7 @@ func TestProcessEvent(t *testing.T) {
 		mockSubscriptionLookup(mockedStore, &sub)
 		mockFlatFiltersLookup(mockedStore, []*models.FlatFilter{})
 
-		result := processor.processEvent(context.Background(), &event)
+		result := processor.processEvent(context.Background(), &event, wg)
 
 		assert.True(t, result.Success())
 		assert.Equal(t, "12.0", *result.Value().Value)
diff --git a/events-processor/processors/main_processor.go b/events-processor/processors/main_processor.go
index 6cb2bd848..45e8e4110 100644
--- a/events-processor/processors/main_processor.go
+++ b/events-processor/processors/main_processor.go
@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
+	"os/signal"
+	"syscall"
 
 	"github.com/twmb/franz-go/pkg/kgo"
 
@@ -18,7 +20,6 @@ import (
 )
 
 var (
-	ctx       context.Context
 	logger    *slog.Logger
 	processor *events_processor.EventProcessor
 	apiStore  *models.ApiStore
@@ -51,7 +52,7 @@ const (
 	envOtelServiceName = "OTEL_SERVICE_NAME"
 )
 
-func initProducer(context context.Context, topicEnv string) utils.Result[*kafka.Producer] {
+func initProducer(ctx context.Context, topicEnv string) utils.Result[*kafka.Producer] {
 	if os.Getenv(topicEnv) == "" {
 		return utils.FailedResult[*kafka.Producer](fmt.Errorf("%s variable is required", topicEnv))
 	}
@@ -67,7 +68,7 @@ func initProducer(context context.Context, topicEnv string) utils.Result[*kafka.
 		return utils.FailedResult[*kafka.Producer](err)
 	}
 
-	err = producer.Ping(context)
+	err = producer.Ping(ctx)
 	if err != nil {
 		return utils.FailedResult[*kafka.Producer](err)
 	}
@@ -75,7 +76,7 @@
 	return utils.SuccessResult(producer)
 }
 
-func initFlagStore(name string) (*models.FlagStore, error) {
+func initFlagStore(ctx context.Context, name string) (*models.FlagStore, error) {
 	redisDb, err := utils.GetEnvAsInt(envLagoRedisStoreDB, 0)
 	if err != nil {
 		return nil, err
 	}
@@ -96,7 +97,7 @@ func initFlagStore(name string) (*models.FlagStore, error) {
 
 	return models.NewFlagStore(ctx, db, name), nil
 }
 
-func initChargeCacheStore() (*models.ChargeCache, error) {
+func initChargeCacheStore(ctx context.Context) (*models.ChargeCache, error) {
 	redisDb, err := utils.GetEnvAsInt(envLagoRedisCacheDB, 0)
 	if err != nil {
 		return nil, err
 	}
@@ -122,7 +123,10 @@
 }
 
 func StartProcessingEvents() {
-	ctx = context.Background()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
 	logger = slog.New(slog.NewJSONHandler(os.Stdout, nil)).
 		With("service", "post_process")
+
+	setupGracefulShutdown(cancel)
@@ -196,7 +200,7 @@ func StartProcessingEvents() {
 	apiStore = models.NewApiStore(db)
 	defer db.Close()
 
-	flagger, err := initFlagStore("subscription_refreshed")
+	flagger, err := initFlagStore(ctx, "subscription_refreshed")
 	if err != nil {
 		logger.Error("Error connecting to the flag store", slog.String("error", err.Error()))
 		utils.CaptureError(err)
@@ -204,7 +208,7 @@ func StartProcessingEvents() {
 	}
 	defer flagger.Close()
 
-	cacher, err := initChargeCacheStore()
+	cacher, err := initChargeCacheStore(ctx)
 	if err != nil {
 		logger.Error("Error connecting to the charge cache store", slog.String("error", err.Error()))
 		utils.CaptureError(err)
@@ -232,8 +236,8 @@ func StartProcessingEvents() {
 		&kafka.ConsumerGroupConfig{
 			Topic:         os.Getenv(envLagoKafkaRawEventsTopic),
 			ConsumerGroup: os.Getenv(envLagoKafkaConsumerGroup),
-			ProcessRecords: func(records []*kgo.Record) []*kgo.Record {
-				return processor.ProcessEvents(records)
+			ProcessRecords: func(ctx context.Context, records []*kgo.Record) []*kgo.Record {
+				return processor.ProcessEvents(ctx, records)
 			},
 		})
 	if err != nil {
@@ -242,5 +246,22 @@ func StartProcessingEvents() {
 		panic(err.Error())
 	}
 
-	cg.Start()
+	logger.Info("Starting event consumer")
+	if err := cg.Start(ctx); err != nil && err != context.Canceled {
+		logger.Error("Consumer stopped with error", slog.String("error", err.Error()))
+		utils.CaptureError(err)
+	}
+
+	logger.Info("Event processor stopped")
+}
+
+func setupGracefulShutdown(cancel context.CancelFunc) {
+	sigChan := make(chan os.Signal, 1)
+	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+
+	go func() {
+		sig := <-sigChan
+		logger.Info("Received shutdown signal", slog.String("signal", sig.String()))
+		cancel()
+	}()
+}
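// ---
// Note: the diff calls findMaxCommitableRecord but does not include its body.
// The sketch below is one plausible reading of its contract, not the actual
// implementation: given the processed subset of a single-partition batch
// (records arrive in offset order), return the last record up to which every
// offset was processed, so the commit can never skip past an unprocessed
// record. The name keeps the diff's "Commitable" spelling on purpose.
package main

import (
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
)

func findMaxCommitableRecord(processed []*kgo.Record, records []*kgo.Record) *kgo.Record {
	processedOffsets := make(map[int64]struct{}, len(processed))
	for _, r := range processed {
		processedOffsets[r.Offset] = struct{}{}
	}

	var max *kgo.Record
	for _, r := range records {
		if _, ok := processedOffsets[r.Offset]; !ok {
			break // first unprocessed offset: nothing beyond it is safe to commit
		}
		max = r
	}
	return max // nil when even the first record failed; callers must handle that
}

func main() {
	records := []*kgo.Record{{Offset: 10}, {Offset: 11}, {Offset: 12}}
	processed := records[:2] // offset 12 failed and must be re-consumed

	fmt.Println(findMaxCommitableRecord(processed, records).Offset) // 11
}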
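// ---
// A self-contained sketch of the bounded-wait pattern that waitForProducers
// introduces above: drain the WaitGroup in a helper goroutine so a select can
// race completion against a timeout and context cancellation. All names here
// are illustrative; only the pattern mirrors the diff.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// waitWithTimeout reports how the wait ended: "completed", "timeout", or
// "canceled". If it returns early, the draining goroutine lingers until the
// remaining workers finish, which matches the diff's best-effort semantics.
func waitWithTimeout(ctx context.Context, wg *sync.WaitGroup, timeout time.Duration) string {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return "completed"
	case <-time.After(timeout):
		return "timeout"
	case <-ctx.Done():
		return "canceled"
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup
	for i := 1; i <= 3; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			time.Sleep(time.Duration(100*n) * time.Millisecond) // simulated producer work
		}(i)
	}

	fmt.Println(waitWithTimeout(ctx, &wg, time.Second)) // completed
}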
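// ---
// The manual setupGracefulShutdown above (signal.Notify plus a goroutine that
// cancels the root context) could also be expressed with signal.NotifyContext
// from the standard library (Go 1.16+). A minimal sketch, with consumerGroup
// standing in for the real kafka.ConsumerGroup:
package main

import (
	"context"
	"errors"
	"log/slog"
	"os"
	"os/signal"
	"syscall"
)

type consumerGroup struct{}

// Start blocks until the context is canceled, as the real poll loop would.
func (cg *consumerGroup) Start(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

func main() {
	logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

	// The returned context is canceled on SIGINT/SIGTERM; stop releases the
	// signal registration.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	cg := &consumerGroup{}
	if err := cg.Start(ctx); err != nil && !errors.Is(err, context.Canceled) {
		logger.Error("Consumer stopped with error", slog.String("error", err.Error()))
	}
	logger.Info("Event processor stopped")
}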