diff --git a/SkyhighSecurity/gateway_cloud_services/trigger_skyhigh_security_swg.py b/SkyhighSecurity/gateway_cloud_services/trigger_skyhigh_security_swg.py index 15b83e4d4..bc17192b7 100644 --- a/SkyhighSecurity/gateway_cloud_services/trigger_skyhigh_security_swg.py +++ b/SkyhighSecurity/gateway_cloud_services/trigger_skyhigh_security_swg.py @@ -1,6 +1,7 @@ import csv import os import queue +import uuid from collections.abc import Generator from datetime import datetime, timedelta, timezone from functools import cached_property @@ -32,10 +33,18 @@ class SkyhighSWGConfig(DefaultConnectorConfiguration): class EventCollector(Thread): - def __init__(self, connector: "SkyhighSecuritySWGTrigger", events_queue: queue.Queue): + def __init__( + self, + connector: "SkyhighSecuritySWGTrigger", + events_queue: queue.Queue, + batch_status_queue: queue.Queue, + ): super().__init__() self.connector = connector self.events_queue = events_queue + self.batch_status_queue = ( + batch_status_queue # Queue to receive batch push confirmation + ) self.trigger_activation: datetime = datetime.now(timezone.utc) self.headers = {"Accept": "text/csv", "x-mwg-api-version": "8"} self.endpoint: str = "/mwg/api/reporting/forensic/" @@ -44,6 +53,7 @@ def __init__(self, connector: "SkyhighSecuritySWGTrigger", events_queue: queue.Q self.end_date: datetime self.start_date: datetime self.url: str + self.pending_batches: dict = {} # Track batch_id -> end_date def log(self, *args, **kwargs): self.connector.log(*args, **kwargs) @@ -169,24 +179,57 @@ def next_batch(self): Fetch the next batch of events and put them in the queue 1. Query the API - 2. If we have a response, put it in the queue - 3. Update the time range - 4. Sleep until the next batch + 2. If we have a response, tag it with batch ID and put it in the queue + 3. Wait for confirmation that the batch was pushed successfully + 4. Update the time range + 5. Sleep until the next batch """ try: # 1. Query the API response = self.query_api() if response: - # 2. If we have a response, put it in the queue - self.events_queue.put(response) + # 2. Tag with batch ID and queue it + batch_id = str(uuid.uuid4()) + self.pending_batches[batch_id] = self.end_date + self.events_queue.put((batch_id, response)) + + # 3. Wait for confirmation that this batch was pushed + self.log( + message=f"Waiting for batch {batch_id} to be pushed...", + level="debug", + ) + try: + confirmed_batch_id = self.batch_status_queue.get( + block=True, timeout=60 + ) # 60 second timeout + if confirmed_batch_id == batch_id: + self.log( + message=f"Batch {batch_id} confirmed pushed", level="debug" + ) + # Remove from pending + self.pending_batches.pop(batch_id, None) + else: + self.log( + message=f"Received confirmation for {confirmed_batch_id} but waiting for {batch_id}", + level="warning", + ) + # Put it back for next iteration + self.batch_status_queue.put(confirmed_batch_id) + except queue.Empty: + self.log( + message=f"Timeout waiting for batch {batch_id} confirmation. Batch may still be processing.", + level="warning", + ) + # Note: We don't remove from pending, checkpoint won't be saved + return else: self.log(message="No messages to forward", level="info") - # 3. Update the time range + # 4. Update the time range (safe now, events are pushed) self._update_time_range() - # 4. Sleep until the next batch + # 5. Sleep until the next batch self._sleep_until_next_batch() except Exception as ex: self.log_exception(ex, message="Failed to fetch events") @@ -236,7 +279,8 @@ def run(self): try: while self.is_running or self.queue.qsize() > 0: try: - response = self.queue.get(block=True, timeout=0.5) + # Get batch_id and response + batch_id, response = self.queue.get(block=True, timeout=0.5) # The transformation is done in batches to avoid filling the memory if we have a lot of events for messages in batched(self._transform(response), self.max_batch_size): @@ -244,7 +288,8 @@ def run(self): nb_events = len(messages) INCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(nb_events) logger.info("Transformed events", nb_events=nb_events) - self.output_queue.put(list(messages)) + # Pass batch_id along with messages + self.output_queue.put((batch_id, list(messages))) except queue.Empty: pass @@ -257,21 +302,40 @@ def run(self): class EventsForwarder(Worker): KIND = "forwarder" - def __init__(self, connector: "SkyhighSecuritySWGTrigger", queue: queue.Queue, max_batch_size: int = 20000): + def __init__( + self, + connector: "SkyhighSecuritySWGTrigger", + queue: queue.Queue, + batch_status_queue: queue.Queue, + max_batch_size: int = 20000, + ): super().__init__() self.connector = connector self.configuration = connector.configuration self.queue = queue + self.batch_status_queue = ( + batch_status_queue # Queue to send batch completion confirmation + ) self.max_batch_size = max_batch_size + self.processed_batches: set = ( + set() + ) # Track which batch_ids we've already confirmed - def next_batch(self, max_batch_size: int) -> list: + def next_batch(self, max_batch_size: int) -> tuple[set, list]: + """ + Returns tuple of (batch_ids, events) + batch_ids: set of batch IDs processed in this batch + events: list of events + """ events = [] + batch_ids = set() while self.is_running: try: - messages = self.queue.get(block=True, timeout=0.5) + batch_id, messages = self.queue.get(block=True, timeout=0.5) if len(messages) > 0: events.extend(messages) + batch_ids.add(batch_id) if len(events) >= max_batch_size: break @@ -279,22 +343,36 @@ def next_batch(self, max_batch_size: int) -> list: except queue.Empty: break - return events + return batch_ids, events def run(self): logger.info("Starting Events Forwarder worker thread.") try: while self.is_running or self.queue.qsize() > 0: - events = self.next_batch(self.max_batch_size) - OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(events)) + batch_ids, events = self.next_batch(self.max_batch_size) if len(events) > 0: + OUTCOMING_EVENTS.labels(intake_key=self.configuration.intake_key).inc(len(events)) self.connector.log( message=f"Forward {len(events)} events to the intake", level="info", ) self.connector.push_events_to_intakes(events=events) + + # Confirm batches after successful push + for batch_id in batch_ids: + if batch_id not in self.processed_batches: + try: + self.batch_status_queue.put(batch_id, block=False) + self.processed_batches.add(batch_id) + logger.debug( + f"Confirmed batch {batch_id} pushed successfully" + ) + except queue.Full: + logger.warning( + f"Failed to confirm batch {batch_id}, status queue full" + ) except Exception as ex: self.connector.log_exception(ex, message="Failed to forward events") @@ -318,6 +396,8 @@ def run(self): # pragma: no cover collect_queue: queue.Queue = queue.Queue(maxsize=collect_queue_size) forwarding_queue_size = int(os.environ.get("FORWARDING_QUEUE_SIZE", 10000)) forwarding_queue: queue.Queue = queue.Queue(maxsize=forwarding_queue_size) + # Queue for batch status confirmation (small size, only needs batch IDs) + batch_status_queue: queue.Queue = queue.Queue(maxsize=100) # start the event forwarder batch_size = int(os.environ.get("BATCH_SIZE", 10000)) @@ -333,7 +413,7 @@ def run(self): # pragma: no cover transformers.start() # start the event collector - collector = EventCollector(self, collect_queue) + collector = EventCollector(self, collect_queue, batch_status_queue) collector.start() try: @@ -349,7 +429,7 @@ def run(self): # pragma: no cover # if the collector is down, restart it if not collector.is_alive(): self.log(message="Event collector failed", level="error") - collector = EventCollector(self, collect_queue) + collector = EventCollector(self, collect_queue, batch_status_queue) collector.start() finally: diff --git a/SkyhighSecurity/tests/test_gateway_cloud_services_trigger.py b/SkyhighSecurity/tests/test_gateway_cloud_services_trigger.py index 64c06554b..5b59d4f8f 100644 --- a/SkyhighSecurity/tests/test_gateway_cloud_services_trigger.py +++ b/SkyhighSecurity/tests/test_gateway_cloud_services_trigger.py @@ -44,15 +44,19 @@ def trigger(symphony_storage): yield trigger +@pytest.fixture +def batch_status_queue(): + return queue.Queue() + @pytest.fixture -def event_collector(trigger, events_queue): - return EventCollector(trigger, events_queue) +def event_collector(trigger, events_queue, batch_status_queue): + return EventCollector(trigger, events_queue, batch_status_queue) @pytest.fixture -def forwarder(trigger, events_queue): - return EventsForwarder(trigger, events_queue, 500) +def forwarder(trigger, events_queue, batch_status_queue): + return EventsForwarder(trigger, events_queue, batch_status_queue, 500) def test_query_api_wrong_creds(trigger, event_collector, requests_mock): @@ -166,25 +170,31 @@ def test_next_batch_error_should_wait(event_collector, requests_mock): def test_tranformer_with_event(trigger, events_queue): input_queue = queue.Queue() - transformer = Transformer(trigger, input_queue, events_queue) + batch_status_queue = queue.Queue() + transformer = Transformer(trigger, input_queue, events_queue, batch_status_queue) - input_queue.put('"user_id","username"\r\n"-1","foo"') + input_queue.put(("batch-1", '"user_id","username"\r\n"-1","foo"')) transformer.start() time.sleep(0.5) transformer.stop() - events = events_queue.get(block=False) + batch_ids, events = events_queue.get(block=False) assert events == ["user_id=-1 username=foo"] + assert "batch-1" in batch_ids -def test_forwarder(trigger, forwarder, events_queue): - events_queue.put("message") +def test_forwarder(trigger, forwarder, events_queue, batch_status_queue): + batch_id = "batch-test" + events = ["user_id=-1 username=foo"] + events_queue.put((batch_id, events)) + forwarder.start() time.sleep(1) forwarder.stop() assert trigger.push_events_to_intakes.called - + confirmed_batch = batch_status_queue.get(block=False) + assert confirmed_batch == batch_id def test_sleep_until_next_batch(event_collector): end_date = datetime(2023, 3, 22, 11, 50, 46, tzinfo=timezone.utc)