Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions apollo/egress/agent/backend/backend_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import uuid
from typing import Dict, Any, Optional
import requests
from retry import retry
Expand All @@ -11,6 +12,8 @@

logger = logging.getLogger(__name__)

INSTANCE_ID_HEADER = "x-mcd-agent-instance-id"


class BackendClient:
"""
Expand All @@ -24,6 +27,18 @@ def __init__(
) -> None:
self._backend_service_url = backend_service_url
self._login_token_provider = login_token_provider
self._instance_id = str(uuid.uuid4())

@property
def instance_id(self) -> str:
return self._instance_id

def _headers(self, **extra: str) -> Dict[str, str]:
return {
**self._login_token_provider.get_token(),
INSTANCE_ID_HEADER: self._instance_id,
**extra,
}

def push_results(
self, operation_id: str, result: Dict[str, Any]
Expand Down Expand Up @@ -55,10 +70,7 @@ def _push_results_with_retries(
response = requests.put(
results_url,
data=result_str,
headers={
"Content-Type": "application/json",
**self._login_token_provider.get_token(),
},
headers=self._headers(**{"Content-Type": "application/json"}),
timeout=60,
)
response.raise_for_status()
Expand All @@ -81,7 +93,7 @@ def _execute_operation_with_retries(
"""
try:
url = build_url(self._backend_service_url, path)
headers = self._login_token_provider.get_token()
headers = self._headers()
if body:
headers["Content-Type"] = "application/json"
response = requests.request(
Expand Down Expand Up @@ -116,6 +128,26 @@ def download_operation(self, operation_id: str) -> Dict:
)
return operation

def send_heartbeat(self):
"""Send a liveness heartbeat to the orchestrator."""
url = build_url(self._backend_service_url, "/api/v1/agent/heartbeat")
response = requests.post(
url,
headers=self._headers(),
timeout=10,
)
response.raise_for_status()

def notify_shutdown(self):
"""Notify orchestrator that this agent is shutting down. Best-effort."""
url = build_url(self._backend_service_url, "/api/v1/agent/shutdown")
response = requests.post(
url,
headers=self._headers(),
timeout=15,
)
response.raise_for_status()

def get_next_operation(self) -> Optional[Dict[str, Any]]:
"""
Fetch next operation from orchestrator queue.
Expand All @@ -125,7 +157,7 @@ def get_next_operation(self) -> Optional[Dict[str, Any]]:
url = build_url(self._backend_service_url, "/api/v1/agent/operation")
response = requests.get(
url,
headers=self._login_token_provider.get_token(),
headers=self._headers(),
timeout=30,
)
if response.status_code == 204:
Expand Down
10 changes: 10 additions & 0 deletions apollo/egress/agent/events/events_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_EVENT_TYPE_HEARTBEAT = "heartbeat"
_EVENT_TYPE_WELCOME = "welcome"
_EVENT_TYPE_WORK_AVAILABLE = "work_available"
_EVENT_TYPE_GOODBYE = "goodbye"

logger = logging.getLogger(__name__)

Expand All @@ -35,12 +36,15 @@ def __init__(
self._stopped = True
self._heartbeat_checker = heartbeat_checker or HeartbeatChecker(self._reconnect)
self._work_available_handler: Optional[Callable[[], None]] = None
self._goodbye_handler: Optional[Callable[[str], None]] = None

def start(
self,
work_available_handler: Callable[[], None],
goodbye_handler: Optional[Callable[[str], None]] = None,
):
self._work_available_handler = work_available_handler
self._goodbye_handler = goodbye_handler
self._stopped = False
self._receiver.start(
handler=self._event_received,
Expand All @@ -51,6 +55,7 @@ def start(
def stop(self):
self._stopped = True
self._work_available_handler = None
self._goodbye_handler = None
self._receiver.stop()

def _reconnect(self):
Expand All @@ -69,6 +74,11 @@ def _event_received(self, event: Dict):
self._work_available_handler()
else:
logger.warning("work_available received but no handler registered")
elif event_type == _EVENT_TYPE_GOODBYE:
reason = event.get("reason", "unknown")
logger.warning(f"Received goodbye from orchestrator: {reason}")
if self._goodbye_handler:
self._goodbye_handler(reason)
else:
logger.info(f"Ignoring unexpected event type: {event_type}")

Expand Down
7 changes: 6 additions & 1 deletion apollo/egress/agent/events/sse_client_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ def __init__(
self,
base_url: str,
login_token_provider: LoginTokenProvider,
extra_headers: Optional[Dict[str, str]] = None,
):
self._current_loop_id: Optional[str] = None
self._base_url = base_url
self._login_token_provider = login_token_provider
self._extra_headers = extra_headers or {}
self._sse_client: Optional[sseclient.SSEClient] = None
self._event_handler: Optional[Callable[[Dict], None]] = None
self._connected_handler: Optional[Callable[[], None]] = None
Expand All @@ -55,7 +57,9 @@ def _start_receiver_thread(self):
# enough
loop_id = str(uuid4())
self._current_loop_id = loop_id
th = Thread(target=self._run_receiver, args=(loop_id,))
# Daemon thread so it doesn't block process exit — the SSE network read
# can't be interrupted cleanly, and all other cleanup is done before exit.
th = Thread(target=self._run_receiver, args=(loop_id,), daemon=True)
th.start()

def stop(self):
Expand Down Expand Up @@ -86,6 +90,7 @@ def _connect_and_consume_events(self, loop_id: str):
headers = {
"Accept": "text/event-stream",
**mc_login_token,
**self._extra_headers,
}
self._sse_client = sseclient.SSEClient(url, headers=headers)
if self._connected_handler:
Expand Down
58 changes: 56 additions & 2 deletions apollo/egress/agent/service/base_egress_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
import os
import signal
import sys
import threading
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Any, Callable, List

from apollo.egress.agent.backend.backend_client import BackendClient
from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER
from apollo.egress.agent.events.ack_sender import (
AckSender,
DEFAULT_ACK_INTERVAL_SECONDS,
Expand Down Expand Up @@ -177,6 +181,7 @@ def __init__(
receiver=SSEClientReceiver(
base_url=backend_service_url,
login_token_provider=self._login_token_provider,
extra_headers={INSTANCE_ID_HEADER: self._backend_client.instance_id},
),
)
self._operations_poller = operations_poller or OperationsPoller(
Expand All @@ -191,6 +196,7 @@ def __init__(
config_manager=config_manager,
push_metrics_handler=self._push_metrics,
)
self._shutdown_lock = threading.Lock()
self._operations_mapping = [
OperationMapping(
path="/api/v1/agent/execute/",
Expand Down Expand Up @@ -234,13 +240,15 @@ def start(self):
if self._sse_enabled:
self._events_client.start(
work_available_handler=self._operations_poller.notify_work_available,
goodbye_handler=self._handle_goodbye,
)
self._ack_sender.start(handler=self._send_ack)
if self._logs_sender:
self._logs_sender.start(handler=self._push_logs)

logger.info(
f"{self._service_name} Service Started: v{self._get_version()} (build #{self._get_build_number()})"
f"{self._service_name} Service Started: v{self._get_version()} (build #{self._get_build_number()}), "
f"instance_id={self._backend_client.instance_id}"
)

def stop(self):
Expand All @@ -255,6 +263,52 @@ def stop(self):
if self._logs_sender:
self._logs_sender.stop()

def _handle_goodbye(self, reason: str):
"""Handle goodbye event from orchestrator — trigger graceful shutdown."""
logger.warning(f"Orchestrator requested shutdown: {reason}")
self._trigger_graceful_shutdown()

def _trigger_graceful_shutdown(self):
"""Notify orchestrator, stop all threads, and terminate the process.

Safe to call from any thread. Cleanup runs exactly once (guarded by
_shutdown_lock). In-flight operations are abandoned — the orchestrator
requeues them via the shutdown notification.
"""
if not self._shutdown_lock.acquire(blocking=False):
return
try:
self._backend_client.notify_shutdown()
logger.info("Notified orchestrator of shutdown")
except Exception:
logger.exception("Failed to notify orchestrator of shutdown")
self.stop()
logger.info("Shutdown complete, signaling exit")
# When running under gunicorn, signal the master process so all workers
# shut down — not just the one that received the goodbye event. When
# running standalone (local dev, single process), signal just ourselves.
if "gunicorn" in sys.modules:
os.kill(os.getppid(), signal.SIGTERM)
else:
os.kill(os.getpid(), signal.SIGTERM)

def register_signal_handlers(self):
"""Register SIGTERM and SIGINT handlers for graceful shutdown.

Should be called from the main thread after start(). Signal handlers
can only be registered from the main thread.
"""

def _signal_handler(signum: int, frame: Any):
sig_name = signal.Signals(signum).name
logger.info(f"Received {sig_name}, shutting down")
self._trigger_graceful_shutdown()
sys.exit(0)

signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
logger.info("Registered SIGTERM and SIGINT signal handlers")

def health_information(self, trace_id: Optional[str] = None) -> Dict[str, Any]:
health_info = utils.health_information(
self._platform,
Expand Down
12 changes: 10 additions & 2 deletions apollo/egress/agent/service/operations_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _run_loop(self):
# Check backpressure - wait if ops_runner queue is full
if self._can_accept_work and not self._can_accept_work():
logger.debug("Backpressure: waiting for ops_runner capacity")
self._send_heartbeat()
self._wait_for_work()
continue

Expand Down Expand Up @@ -140,6 +141,13 @@ def _submit_operation(self, operation: Dict):
)
self._operation_handler(path, operation_id, operation)

def _send_heartbeat(self):
"""Send a heartbeat to the orchestrator to signal liveness during backpressure."""
try:
self._backend_client.send_heartbeat()
except Exception:
logger.exception("Failed to send heartbeat")

def _fetch_operation(self) -> Optional[Dict]:
"""Fetch next operation from orchestrator. Returns None if queue empty or on error."""
try:
Expand All @@ -162,12 +170,12 @@ def _wait_for_work(self):
notified = self._condition.wait(timeout=self._poll_interval)
self._waiting = False
wait_seconds = round(time.monotonic() - wait_start, 3)
if notified:
if notified and self._running:
logger.info(
"Woken by work_available notification",
extra={"wait_seconds": wait_seconds},
)
else:
elif not notified:
logger.info(
"Poll interval reached, checking for work",
extra={
Expand Down
4 changes: 3 additions & 1 deletion apollo/egress/agent/utils/queue_async_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def __init__(self, name: str, handler: Callable[[T], None], thread_count: int):
def start(self):
self._running = True
for thread_number in range(self._thread_count):
th = Thread(target=self._run, args=(thread_number,))
# Daemon threads so they don't block process exit during shutdown.
# In-flight operations are abandoned — the orchestrator requeues them.
th = Thread(target=self._run, args=(thread_number,), daemon=True)
th.start()

def stop(self):
Expand Down
73 changes: 72 additions & 1 deletion tests/test_backend_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,78 @@
from unittest import TestCase
from unittest.mock import Mock, patch

from apollo.egress.agent.backend.backend_client import BackendClient
from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER


class BackendClientTests(TestCase):
def setUp(self):
self._login_token_provider = Mock()
self._login_token_provider.get_token.return_value = {
"x-mcd-id": "test-id",
"x-mcd-token": "test-token",
}
self._client = BackendClient(
backend_service_url="https://orchestrator.test",
login_token_provider=self._login_token_provider,
)

def test_instance_id_is_generated(self):
"""Instance ID should be a non-empty string, unique per client."""
self.assertIsInstance(self._client.instance_id, str)
self.assertTrue(len(self._client.instance_id) > 0)

other = BackendClient(
backend_service_url="https://orchestrator.test",
login_token_provider=self._login_token_provider,
)
self.assertNotEqual(self._client.instance_id, other.instance_id)

def test_headers_include_instance_id(self):
"""All requests should include the instance ID header."""
headers = self._client._headers()
self.assertEqual(headers[INSTANCE_ID_HEADER], self._client.instance_id)
self.assertEqual(headers["x-mcd-id"], "test-id")
self.assertEqual(headers["x-mcd-token"], "test-token")

def test_headers_include_extra(self):
"""Extra headers should be merged."""
headers = self._client._headers(**{"Content-Type": "application/json"})
self.assertEqual(headers["Content-Type"], "application/json")
self.assertIn(INSTANCE_ID_HEADER, headers)

@patch("requests.post")
def test_send_heartbeat(self, mock_post):
"""send_heartbeat should POST to /api/v1/agent/heartbeat with correct headers."""
mock_post.return_value.status_code = 200

self._client.send_heartbeat()

mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/heartbeat")
self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"])
self.assertEqual(kwargs["timeout"], 10)

@patch("requests.post")
def test_notify_shutdown(self, mock_post):
"""notify_shutdown should POST to /api/v1/agent/shutdown with correct headers."""
mock_post.return_value.status_code = 200

self._client.notify_shutdown()

mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/shutdown")
self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"])
self.assertEqual(kwargs["timeout"], 15)

@patch("requests.post")
def test_notify_shutdown_raises_on_failure(self, mock_post):
"""notify_shutdown should raise on HTTP errors."""
mock_post.return_value.raise_for_status.side_effect = Exception("500")

with self.assertRaises(Exception):
self._client.notify_shutdown()


class BackendClientURLTests(TestCase):
Expand Down
Loading