diff --git a/internal/flavors/assetinventory/strategy.go b/internal/flavors/assetinventory/strategy.go index 3878b20c0a..ef688967ba 100644 --- a/internal/flavors/assetinventory/strategy.go +++ b/internal/flavors/assetinventory/strategy.go @@ -73,7 +73,7 @@ func (s *strategy) NewAssetInventory(ctx context.Context, client beat.Client) (i s.logger.Infof("Creating %s AssetInventory", strings.ToUpper(s.cfg.AssetInventoryProvider)) now := func() time.Time { return time.Now() } //nolint:gocritic - return inventory.NewAssetInventory(s.logger, fetchers, client, now), nil + return inventory.NewAssetInventory(s.logger, fetchers, client, now, s.cfg.Period), nil } func (s *strategy) initAzureFetchers(_ context.Context) ([]inventory.AssetFetcher, error) { diff --git a/internal/inventory/inventory.go b/internal/inventory/inventory.go index 3cc7438127..3cb6d88c12 100644 --- a/internal/inventory/inventory.go +++ b/internal/inventory/inventory.go @@ -29,13 +29,17 @@ import ( "github.com/samber/lo" ) -const indexTemplate = "logs-cloud_asset_inventory.asset_inventory-%s_%s_%s_%s-default" +const ( + indexTemplate = "logs-cloud_asset_inventory.asset_inventory-%s_%s_%s_%s-default" + minimalPeriod = 30 * time.Second +) type AssetInventory struct { fetchers []AssetFetcher publisher AssetPublisher bufferFlushInterval time.Duration bufferMaxSize int + period time.Duration logger *logp.Logger assetCh chan AssetEvent now func() time.Time @@ -49,8 +53,11 @@ type AssetPublisher interface { PublishAll([]beat.Event) } -func NewAssetInventory(logger *logp.Logger, fetchers []AssetFetcher, publisher AssetPublisher, now func() time.Time) AssetInventory { - logger.Info("Initializing Asset Inventory POC") +func NewAssetInventory(logger *logp.Logger, fetchers []AssetFetcher, publisher AssetPublisher, now func() time.Time, period time.Duration) AssetInventory { + if period < minimalPeriod { + period = minimalPeriod + } + logger.Infof("Initializing Asset Inventory POC with period of %s", period) return AssetInventory{ logger: logger, fetchers: fetchers, @@ -58,20 +65,18 @@ func NewAssetInventory(logger *logp.Logger, fetchers []AssetFetcher, publisher A // move to a configuration parameter bufferFlushInterval: 10 * time.Second, bufferMaxSize: 1600, + period: period, assetCh: make(chan AssetEvent), now: now, } } func (a *AssetInventory) Run(ctx context.Context) { - for _, fetcher := range a.fetchers { - go func(fetcher AssetFetcher) { - fetcher.Fetch(ctx, a.assetCh) - }(fetcher) - } + a.runAllFetchersOnce(ctx) assetsBuffer := make([]AssetEvent, 0, a.bufferMaxSize) flushTicker := time.NewTicker(a.bufferFlushInterval) + fetcherPeriod := time.NewTicker(a.period) for { select { case <-ctx.Done(): @@ -79,6 +84,9 @@ func (a *AssetInventory) Run(ctx context.Context) { a.publish(assetsBuffer) return + case <-fetcherPeriod.C: + a.runAllFetchersOnce(ctx) + case <-flushTicker.C: if len(assetsBuffer) == 0 { a.logger.Debugf("Interval reached without events") @@ -101,6 +109,17 @@ func (a *AssetInventory) Run(ctx context.Context) { } } +// runAllFetchersOnce runs every fetcher to collect assets to assetCh ONCE. It +// should be called every cycle, once every `a.period`. +func (a *AssetInventory) runAllFetchersOnce(ctx context.Context) { + a.logger.Debug("Running all fetchers once") + for _, fetcher := range a.fetchers { + go func(fetcher AssetFetcher) { + fetcher.Fetch(ctx, a.assetCh) + }(fetcher) + } +} + func (a *AssetInventory) publish(assets []AssetEvent) { events := lo.Map(assets, func(e AssetEvent, _ int) beat.Event { var relatedEntity []string diff --git a/internal/inventory/inventory_test.go b/internal/inventory/inventory_test.go index 685d951634..30e4bb3deb 100644 --- a/internal/inventory/inventory_test.go +++ b/internal/inventory/inventory_test.go @@ -19,6 +19,7 @@ package inventory import ( "context" + "sync/atomic" "testing" "time" @@ -31,6 +32,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/elastic/cloudbeat/internal/resources/utils/pointers" + "github.com/elastic/cloudbeat/internal/resources/utils/testhelper" ) func TestAssetInventory_Run(t *testing.T) { @@ -150,6 +152,7 @@ func TestAssetInventory_Run(t *testing.T) { publisher: publisher, bufferFlushInterval: 10 * time.Millisecond, bufferMaxSize: 1, + period: 24 * time.Hour, assetCh: make(chan AssetEvent), now: now, } @@ -169,3 +172,84 @@ func TestAssetInventory_Run(t *testing.T) { assert.ElementsMatch(t, received, expected) } } + +func TestAssetInventory_Period(t *testing.T) { + testhelper.SkipLong(t) + now := func() time.Time { return time.Date(2024, 1, 1, 1, 1, 1, 0, time.Local) } + + var cycleCounter int64 + + publisher := NewMockAssetPublisher(t) + publisher.EXPECT().PublishAll(mock.Anything).Maybe() + + fetcher := NewMockAssetFetcher(t) + fetcher.EXPECT().Fetch(mock.Anything, mock.Anything).Run(func(_ context.Context, _ chan<- AssetEvent) { + atomic.AddInt64(&cycleCounter, 1) + }) + + logger := logp.NewLogger("test_run") + inventory := AssetInventory{ + logger: logger, + fetchers: []AssetFetcher{fetcher}, + publisher: publisher, + bufferFlushInterval: 10 * time.Millisecond, + bufferMaxSize: 1, + period: 500 * time.Millisecond, + assetCh: make(chan AssetEvent), + now: now, + } + + // Run it enough for 2 cycles to finish; one starts immediately, the other after 500 milliseconds + ctx, cancel := context.WithTimeout(context.Background(), 600*time.Millisecond) + defer cancel() + + go func() { + inventory.Run(ctx) + }() + + <-ctx.Done() + val := atomic.LoadInt64(&cycleCounter) + assert.Equal(t, int64(2), val, "Expected to run 2 cycles, got %d", val) +} + +func TestAssetInventory_RunAllFetchersOnce(t *testing.T) { + now := func() time.Time { return time.Date(2024, 1, 1, 1, 1, 1, 0, time.Local) } + publisher := NewMockAssetPublisher(t) + publisher.EXPECT().PublishAll(mock.Anything).Maybe() + + fetchers := []AssetFetcher{} + fetcherCounters := [](*int64){} + for i := 0; i < 5; i++ { + fetcher := NewMockAssetFetcher(t) + counter := int64(0) + fetcher.EXPECT().Fetch(mock.Anything, mock.Anything).Run(func(_ context.Context, _ chan<- AssetEvent) { + atomic.AddInt64(&counter, 1) + }) + fetchers = append(fetchers, fetcher) + fetcherCounters = append(fetcherCounters, &counter) + } + + logger := logp.NewLogger("test_run") + inventory := AssetInventory{ + logger: logger, + fetchers: fetchers, + publisher: publisher, + bufferFlushInterval: 10 * time.Millisecond, + bufferMaxSize: 1, + period: 24 * time.Hour, + assetCh: make(chan AssetEvent), + now: now, + } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + inventory.runAllFetchersOnce(ctx) + <-ctx.Done() + + // Check that EVERY fetcher has been called EXACTLY ONCE + for _, counter := range fetcherCounters { + val := atomic.LoadInt64(counter) + assert.Equal(t, int64(1), val, "Expected to run once, got %d", val) + } +}