diff --git a/apollo/egress/agent/backend/backend_client.py b/apollo/egress/agent/backend/backend_client.py index 27c1827..62c45ea 100644 --- a/apollo/egress/agent/backend/backend_client.py +++ b/apollo/egress/agent/backend/backend_client.py @@ -1,5 +1,6 @@ import json import logging +import uuid from typing import Dict, Any, Optional import requests from retry import retry @@ -11,6 +12,8 @@ logger = logging.getLogger(__name__) +INSTANCE_ID_HEADER = "x-mcd-agent-instance-id" + class BackendClient: """ @@ -24,6 +27,18 @@ def __init__( ) -> None: self._backend_service_url = backend_service_url self._login_token_provider = login_token_provider + self._instance_id = str(uuid.uuid4()) + + @property + def instance_id(self) -> str: + return self._instance_id + + def _headers(self, **extra: str) -> Dict[str, str]: + return { + **self._login_token_provider.get_token(), + INSTANCE_ID_HEADER: self._instance_id, + **extra, + } def push_results( self, operation_id: str, result: Dict[str, Any] @@ -55,10 +70,7 @@ def _push_results_with_retries( response = requests.put( results_url, data=result_str, - headers={ - "Content-Type": "application/json", - **self._login_token_provider.get_token(), - }, + headers=self._headers(**{"Content-Type": "application/json"}), timeout=60, ) response.raise_for_status() @@ -81,7 +93,7 @@ def _execute_operation_with_retries( """ try: url = build_url(self._backend_service_url, path) - headers = self._login_token_provider.get_token() + headers = self._headers() if body: headers["Content-Type"] = "application/json" response = requests.request( @@ -116,6 +128,26 @@ def download_operation(self, operation_id: str) -> Dict: ) return operation + def send_heartbeat(self): + """Send a liveness heartbeat to the orchestrator.""" + url = build_url(self._backend_service_url, "/api/v1/agent/heartbeat") + response = requests.post( + url, + headers=self._headers(), + timeout=10, + ) + response.raise_for_status() + + def notify_shutdown(self): + """Notify orchestrator that this agent is shutting down. Best-effort.""" + url = build_url(self._backend_service_url, "/api/v1/agent/shutdown") + response = requests.post( + url, + headers=self._headers(), + timeout=15, + ) + response.raise_for_status() + def get_next_operation(self) -> Optional[Dict[str, Any]]: """ Fetch next operation from orchestrator queue. @@ -125,7 +157,7 @@ def get_next_operation(self) -> Optional[Dict[str, Any]]: url = build_url(self._backend_service_url, "/api/v1/agent/operation") response = requests.get( url, - headers=self._login_token_provider.get_token(), + headers=self._headers(), timeout=30, ) if response.status_code == 204: diff --git a/apollo/egress/agent/events/events_client.py b/apollo/egress/agent/events/events_client.py index e746b60..0e4bef0 100644 --- a/apollo/egress/agent/events/events_client.py +++ b/apollo/egress/agent/events/events_client.py @@ -10,6 +10,7 @@ _EVENT_TYPE_HEARTBEAT = "heartbeat" _EVENT_TYPE_WELCOME = "welcome" _EVENT_TYPE_WORK_AVAILABLE = "work_available" +_EVENT_TYPE_GOODBYE = "goodbye" logger = logging.getLogger(__name__) @@ -35,12 +36,15 @@ def __init__( self._stopped = True self._heartbeat_checker = heartbeat_checker or HeartbeatChecker(self._reconnect) self._work_available_handler: Optional[Callable[[], None]] = None + self._goodbye_handler: Optional[Callable[[str], None]] = None def start( self, work_available_handler: Callable[[], None], + goodbye_handler: Optional[Callable[[str], None]] = None, ): self._work_available_handler = work_available_handler + self._goodbye_handler = goodbye_handler self._stopped = False self._receiver.start( handler=self._event_received, @@ -51,6 +55,7 @@ def start( def stop(self): self._stopped = True self._work_available_handler = None + self._goodbye_handler = None self._receiver.stop() def _reconnect(self): @@ -69,6 +74,11 @@ def _event_received(self, event: Dict): self._work_available_handler() else: logger.warning("work_available received but no handler registered") + elif event_type == _EVENT_TYPE_GOODBYE: + reason = event.get("reason", "unknown") + logger.warning(f"Received goodbye from orchestrator: {reason}") + if self._goodbye_handler: + self._goodbye_handler(reason) else: logger.info(f"Ignoring unexpected event type: {event_type}") diff --git a/apollo/egress/agent/events/sse_client_receiver.py b/apollo/egress/agent/events/sse_client_receiver.py index ce6d1f0..6f2ff75 100644 --- a/apollo/egress/agent/events/sse_client_receiver.py +++ b/apollo/egress/agent/events/sse_client_receiver.py @@ -28,10 +28,12 @@ def __init__( self, base_url: str, login_token_provider: LoginTokenProvider, + extra_headers: Optional[Dict[str, str]] = None, ): self._current_loop_id: Optional[str] = None self._base_url = base_url self._login_token_provider = login_token_provider + self._extra_headers = extra_headers or {} self._sse_client: Optional[sseclient.SSEClient] = None self._event_handler: Optional[Callable[[Dict], None]] = None self._connected_handler: Optional[Callable[[], None]] = None @@ -55,7 +57,9 @@ def _start_receiver_thread(self): # enough loop_id = str(uuid4()) self._current_loop_id = loop_id - th = Thread(target=self._run_receiver, args=(loop_id,)) + # Daemon thread so it doesn't block process exit — the SSE network read + # can't be interrupted cleanly, and all other cleanup is done before exit. + th = Thread(target=self._run_receiver, args=(loop_id,), daemon=True) th.start() def stop(self): @@ -86,6 +90,7 @@ def _connect_and_consume_events(self, loop_id: str): headers = { "Accept": "text/event-stream", **mc_login_token, + **self._extra_headers, } self._sse_client = sseclient.SSEClient(url, headers=headers) if self._connected_handler: diff --git a/apollo/egress/agent/service/base_egress_service.py b/apollo/egress/agent/service/base_egress_service.py index ea0ee50..d709fdb 100644 --- a/apollo/egress/agent/service/base_egress_service.py +++ b/apollo/egress/agent/service/base_egress_service.py @@ -1,11 +1,15 @@ import logging +import os +import signal +import sys +import threading import uuid from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from typing import Dict, Optional, Any, Callable, List -from apollo.egress.agent.backend.backend_client import BackendClient +from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER from apollo.egress.agent.events.ack_sender import ( AckSender, DEFAULT_ACK_INTERVAL_SECONDS, @@ -177,6 +181,7 @@ def __init__( receiver=SSEClientReceiver( base_url=backend_service_url, login_token_provider=self._login_token_provider, + extra_headers={INSTANCE_ID_HEADER: self._backend_client.instance_id}, ), ) self._operations_poller = operations_poller or OperationsPoller( @@ -191,6 +196,7 @@ def __init__( config_manager=config_manager, push_metrics_handler=self._push_metrics, ) + self._shutdown_lock = threading.Lock() self._operations_mapping = [ OperationMapping( path="/api/v1/agent/execute/", @@ -234,13 +240,15 @@ def start(self): if self._sse_enabled: self._events_client.start( work_available_handler=self._operations_poller.notify_work_available, + goodbye_handler=self._handle_goodbye, ) self._ack_sender.start(handler=self._send_ack) if self._logs_sender: self._logs_sender.start(handler=self._push_logs) logger.info( - f"{self._service_name} Service Started: v{self._get_version()} (build #{self._get_build_number()})" + f"{self._service_name} Service Started: v{self._get_version()} (build #{self._get_build_number()}), " + f"instance_id={self._backend_client.instance_id}" ) def stop(self): @@ -255,6 +263,52 @@ def stop(self): if self._logs_sender: self._logs_sender.stop() + def _handle_goodbye(self, reason: str): + """Handle goodbye event from orchestrator — trigger graceful shutdown.""" + logger.warning(f"Orchestrator requested shutdown: {reason}") + self._trigger_graceful_shutdown() + + def _trigger_graceful_shutdown(self): + """Notify orchestrator, stop all threads, and terminate the process. + + Safe to call from any thread. Cleanup runs exactly once (guarded by + _shutdown_lock). In-flight operations are abandoned — the orchestrator + requeues them via the shutdown notification. + """ + if not self._shutdown_lock.acquire(blocking=False): + return + try: + self._backend_client.notify_shutdown() + logger.info("Notified orchestrator of shutdown") + except Exception: + logger.exception("Failed to notify orchestrator of shutdown") + self.stop() + logger.info("Shutdown complete, signaling exit") + # When running under gunicorn, signal the master process so all workers + # shut down — not just the one that received the goodbye event. When + # running standalone (local dev, single process), signal just ourselves. + if "gunicorn" in sys.modules: + os.kill(os.getppid(), signal.SIGTERM) + else: + os.kill(os.getpid(), signal.SIGTERM) + + def register_signal_handlers(self): + """Register SIGTERM and SIGINT handlers for graceful shutdown. + + Should be called from the main thread after start(). Signal handlers + can only be registered from the main thread. + """ + + def _signal_handler(signum: int, frame: Any): + sig_name = signal.Signals(signum).name + logger.info(f"Received {sig_name}, shutting down") + self._trigger_graceful_shutdown() + sys.exit(0) + + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + logger.info("Registered SIGTERM and SIGINT signal handlers") + def health_information(self, trace_id: Optional[str] = None) -> Dict[str, Any]: health_info = utils.health_information( self._platform, diff --git a/apollo/egress/agent/service/operations_poller.py b/apollo/egress/agent/service/operations_poller.py index 7df55da..c562fd5 100644 --- a/apollo/egress/agent/service/operations_poller.py +++ b/apollo/egress/agent/service/operations_poller.py @@ -107,6 +107,7 @@ def _run_loop(self): # Check backpressure - wait if ops_runner queue is full if self._can_accept_work and not self._can_accept_work(): logger.debug("Backpressure: waiting for ops_runner capacity") + self._send_heartbeat() self._wait_for_work() continue @@ -140,6 +141,13 @@ def _submit_operation(self, operation: Dict): ) self._operation_handler(path, operation_id, operation) + def _send_heartbeat(self): + """Send a heartbeat to the orchestrator to signal liveness during backpressure.""" + try: + self._backend_client.send_heartbeat() + except Exception: + logger.exception("Failed to send heartbeat") + def _fetch_operation(self) -> Optional[Dict]: """Fetch next operation from orchestrator. Returns None if queue empty or on error.""" try: @@ -162,12 +170,12 @@ def _wait_for_work(self): notified = self._condition.wait(timeout=self._poll_interval) self._waiting = False wait_seconds = round(time.monotonic() - wait_start, 3) - if notified: + if notified and self._running: logger.info( "Woken by work_available notification", extra={"wait_seconds": wait_seconds}, ) - else: + elif not notified: logger.info( "Poll interval reached, checking for work", extra={ diff --git a/apollo/egress/agent/utils/queue_async_processor.py b/apollo/egress/agent/utils/queue_async_processor.py index 8f9e928..19d8eb3 100644 --- a/apollo/egress/agent/utils/queue_async_processor.py +++ b/apollo/egress/agent/utils/queue_async_processor.py @@ -26,7 +26,9 @@ def __init__(self, name: str, handler: Callable[[T], None], thread_count: int): def start(self): self._running = True for thread_number in range(self._thread_count): - th = Thread(target=self._run, args=(thread_number,)) + # Daemon threads so they don't block process exit during shutdown. + # In-flight operations are abandoned — the orchestrator requeues them. + th = Thread(target=self._run, args=(thread_number,), daemon=True) th.start() def stop(self): diff --git a/tests/test_backend_client.py b/tests/test_backend_client.py index ea859e8..9bfd48b 100644 --- a/tests/test_backend_client.py +++ b/tests/test_backend_client.py @@ -1,7 +1,78 @@ from unittest import TestCase from unittest.mock import Mock, patch -from apollo.egress.agent.backend.backend_client import BackendClient +from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER + + +class BackendClientTests(TestCase): + def setUp(self): + self._login_token_provider = Mock() + self._login_token_provider.get_token.return_value = { + "x-mcd-id": "test-id", + "x-mcd-token": "test-token", + } + self._client = BackendClient( + backend_service_url="https://orchestrator.test", + login_token_provider=self._login_token_provider, + ) + + def test_instance_id_is_generated(self): + """Instance ID should be a non-empty string, unique per client.""" + self.assertIsInstance(self._client.instance_id, str) + self.assertTrue(len(self._client.instance_id) > 0) + + other = BackendClient( + backend_service_url="https://orchestrator.test", + login_token_provider=self._login_token_provider, + ) + self.assertNotEqual(self._client.instance_id, other.instance_id) + + def test_headers_include_instance_id(self): + """All requests should include the instance ID header.""" + headers = self._client._headers() + self.assertEqual(headers[INSTANCE_ID_HEADER], self._client.instance_id) + self.assertEqual(headers["x-mcd-id"], "test-id") + self.assertEqual(headers["x-mcd-token"], "test-token") + + def test_headers_include_extra(self): + """Extra headers should be merged.""" + headers = self._client._headers(**{"Content-Type": "application/json"}) + self.assertEqual(headers["Content-Type"], "application/json") + self.assertIn(INSTANCE_ID_HEADER, headers) + + @patch("requests.post") + def test_send_heartbeat(self, mock_post): + """send_heartbeat should POST to /api/v1/agent/heartbeat with correct headers.""" + mock_post.return_value.status_code = 200 + + self._client.send_heartbeat() + + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/heartbeat") + self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"]) + self.assertEqual(kwargs["timeout"], 10) + + @patch("requests.post") + def test_notify_shutdown(self, mock_post): + """notify_shutdown should POST to /api/v1/agent/shutdown with correct headers.""" + mock_post.return_value.status_code = 200 + + self._client.notify_shutdown() + + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/shutdown") + self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"]) + self.assertEqual(kwargs["timeout"], 15) + + @patch("requests.post") + def test_notify_shutdown_raises_on_failure(self, mock_post): + """notify_shutdown should raise on HTTP errors.""" + mock_post.return_value.raise_for_status.side_effect = Exception("500") + + with self.assertRaises(Exception): + self._client.notify_shutdown() class BackendClientURLTests(TestCase): diff --git a/tests/test_base_egress_service.py b/tests/test_base_egress_service.py index 925d9fd..75beb63 100644 --- a/tests/test_base_egress_service.py +++ b/tests/test_base_egress_service.py @@ -1,3 +1,5 @@ +import signal + from unittest import TestCase from unittest.mock import Mock, patch, MagicMock @@ -264,3 +266,32 @@ def test_can_accept_work_returns_false_when_queue_over_capacity(self): result = self._service._can_accept_work() self.assertFalse(result) + + @patch("apollo.egress.agent.service.base_egress_service.os") + def test_handle_goodbye_triggers_graceful_shutdown(self, mock_os): + """Test that _handle_goodbye notifies orchestrator, stops threads, and signals exit.""" + self._service._handle_goodbye("activity_timeout") + + self._backend_client.notify_shutdown.assert_called_once() + self._operations_poller.stop.assert_called_once() + mock_os.kill.assert_called_once_with(mock_os.getpid(), signal.SIGTERM) + + @patch("apollo.egress.agent.service.base_egress_service.os") + def test_trigger_graceful_shutdown_only_runs_once(self, mock_os): + """Test that _trigger_graceful_shutdown is guarded against double execution.""" + self._service._trigger_graceful_shutdown() + self._service._trigger_graceful_shutdown() + + # notify and stop called only once + self._backend_client.notify_shutdown.assert_called_once() + self._operations_poller.stop.assert_called_once() + + def test_start_passes_goodbye_handler_to_events_client(self): + """Test that start() passes goodbye_handler to EventsClient.""" + self._service._sse_enabled = True + + self._service.start() + + call_kwargs = self._events_client.start.call_args.kwargs + self.assertIn("goodbye_handler", call_kwargs) + self.assertEqual(call_kwargs["goodbye_handler"], self._service._handle_goodbye) diff --git a/tests/test_events_client.py b/tests/test_events_client.py index e22dc48..32cda94 100644 --- a/tests/test_events_client.py +++ b/tests/test_events_client.py @@ -65,13 +65,49 @@ def test_operation_event_is_ignored(self): # Operation events should be ignored - pull model handles operations self._work_available_handler.assert_not_called() + def test_goodbye_event_calls_handler(self): + """Test that goodbye event calls the goodbye handler with the reason.""" + goodbye_handler = Mock() + self._client.start( + work_available_handler=self._work_available_handler, + goodbye_handler=goodbye_handler, + ) + + self._client._event_received({"type": "goodbye", "reason": "activity_timeout"}) + + goodbye_handler.assert_called_once_with("activity_timeout") + self._work_available_handler.assert_not_called() + + def test_goodbye_event_defaults_reason_to_unknown(self): + """Test that goodbye event without reason defaults to 'unknown'.""" + goodbye_handler = Mock() + self._client.start( + work_available_handler=self._work_available_handler, + goodbye_handler=goodbye_handler, + ) + + self._client._event_received({"type": "goodbye"}) + + goodbye_handler.assert_called_once_with("unknown") + + def test_goodbye_event_without_handler(self): + """Test that goodbye event without handler doesn't raise.""" + self._client.start( + work_available_handler=self._work_available_handler, + ) + + # Should not raise + self._client._event_received({"type": "goodbye", "reason": "activity_timeout"}) + def test_stop_clears_handlers(self): """Test that stop() clears the handlers.""" self._client.start( work_available_handler=self._work_available_handler, + goodbye_handler=Mock(), ) self._client.stop() self.assertIsNone(self._client._work_available_handler) + self.assertIsNone(self._client._goodbye_handler) self._receiver.stop.assert_called_once() diff --git a/tests/test_operations_poller.py b/tests/test_operations_poller.py index 6669b3b..3e5c186 100644 --- a/tests/test_operations_poller.py +++ b/tests/test_operations_poller.py @@ -220,6 +220,25 @@ def can_accept_work(): self.assertGreater(call_count[0], 1) poller.stop() + def test_backpressure_sends_heartbeat(self): + """Test that poller sends heartbeat to orchestrator during backpressure.""" + self._backend_client.get_next_operation.return_value = None + + self._config_manager.get_int_value.return_value = 0.1 + + poller = OperationsPoller( + backend_client=self._backend_client, + config_manager=self._config_manager, + operation_handler=self._operation_handler, + can_accept_work=lambda: False, # always backpressured + ) + + poller.start() + time.sleep(0.3) + + self._backend_client.send_heartbeat.assert_called() + poller.stop() + def test_no_backpressure_when_can_accept_work_is_none(self): """Test that poller fetches normally when can_accept_work is not provided.""" op1 = {"operation_id": "op-1", "path": "/test1", "operation": {}}