Skip to content

Commit d22e5cb

Browse files
committed
address code review findings
1 parent e62d664 commit d22e5cb

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

apollo/egress/agent/service/base_egress_service.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import signal
44
import sys
5+
import threading
56
import uuid
67
from abc import ABC, abstractmethod
78
from dataclasses import dataclass
@@ -195,7 +196,7 @@ def __init__(
195196
config_manager=config_manager,
196197
push_metrics_handler=self._push_metrics,
197198
)
198-
self._shutting_down = False
199+
self._shutdown_lock = threading.Lock()
199200
self._operations_mapping = [
200201
OperationMapping(
201202
path="/api/v1/agent/execute/",
@@ -271,12 +272,11 @@ def _trigger_graceful_shutdown(self):
271272
"""Notify orchestrator, stop all threads, and terminate the process.
272273
273274
Safe to call from any thread. Cleanup runs exactly once (guarded by
274-
_shutting_down flag). In-flight operations are abandoned — the
275-
orchestrator requeues them via the shutdown notification.
275+
_shutdown_lock). In-flight operations are abandoned — the orchestrator
276+
requeues them via the shutdown notification.
276277
"""
277-
if self._shutting_down:
278+
if not self._shutdown_lock.acquire(blocking=False):
278279
return
279-
self._shutting_down = True
280280
try:
281281
self._backend_client.notify_shutdown()
282282
logger.info("Notified orchestrator of shutdown")
@@ -302,8 +302,7 @@ def register_signal_handlers(self):
302302
def _signal_handler(signum: int, frame: Any):
303303
sig_name = signal.Signals(signum).name
304304
logger.info(f"Received {sig_name}, shutting down")
305-
if not self._shutting_down:
306-
self._trigger_graceful_shutdown()
305+
self._trigger_graceful_shutdown()
307306
sys.exit(0)
308307

309308
signal.signal(signal.SIGTERM, _signal_handler)

tests/test_backend_client.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,78 @@
11
from unittest import TestCase
22
from unittest.mock import Mock, patch
33

4-
from apollo.egress.agent.backend.backend_client import BackendClient
4+
from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER
5+
6+
7+
class BackendClientTests(TestCase):
8+
def setUp(self):
9+
self._login_token_provider = Mock()
10+
self._login_token_provider.get_token.return_value = {
11+
"x-mcd-id": "test-id",
12+
"x-mcd-token": "test-token",
13+
}
14+
self._client = BackendClient(
15+
backend_service_url="https://orchestrator.test",
16+
login_token_provider=self._login_token_provider,
17+
)
18+
19+
def test_instance_id_is_generated(self):
20+
"""Instance ID should be a non-empty string, unique per client."""
21+
self.assertIsInstance(self._client.instance_id, str)
22+
self.assertTrue(len(self._client.instance_id) > 0)
23+
24+
other = BackendClient(
25+
backend_service_url="https://orchestrator.test",
26+
login_token_provider=self._login_token_provider,
27+
)
28+
self.assertNotEqual(self._client.instance_id, other.instance_id)
29+
30+
def test_headers_include_instance_id(self):
31+
"""All requests should include the instance ID header."""
32+
headers = self._client._headers()
33+
self.assertEqual(headers[INSTANCE_ID_HEADER], self._client.instance_id)
34+
self.assertEqual(headers["x-mcd-id"], "test-id")
35+
self.assertEqual(headers["x-mcd-token"], "test-token")
36+
37+
def test_headers_include_extra(self):
38+
"""Extra headers should be merged."""
39+
headers = self._client._headers(**{"Content-Type": "application/json"})
40+
self.assertEqual(headers["Content-Type"], "application/json")
41+
self.assertIn(INSTANCE_ID_HEADER, headers)
42+
43+
@patch("requests.post")
44+
def test_send_heartbeat(self, mock_post):
45+
"""send_heartbeat should POST to /api/v1/agent/heartbeat with correct headers."""
46+
mock_post.return_value.status_code = 200
47+
48+
self._client.send_heartbeat()
49+
50+
mock_post.assert_called_once()
51+
args, kwargs = mock_post.call_args
52+
self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/heartbeat")
53+
self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"])
54+
self.assertEqual(kwargs["timeout"], 10)
55+
56+
@patch("requests.post")
57+
def test_notify_shutdown(self, mock_post):
58+
"""notify_shutdown should POST to /api/v1/agent/shutdown with correct headers."""
59+
mock_post.return_value.status_code = 200
60+
61+
self._client.notify_shutdown()
62+
63+
mock_post.assert_called_once()
64+
args, kwargs = mock_post.call_args
65+
self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/shutdown")
66+
self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"])
67+
self.assertEqual(kwargs["timeout"], 15)
68+
69+
@patch("requests.post")
70+
def test_notify_shutdown_raises_on_failure(self, mock_post):
71+
"""notify_shutdown should raise on HTTP errors."""
72+
mock_post.return_value.raise_for_status.side_effect = Exception("500")
73+
74+
with self.assertRaises(Exception):
75+
self._client.notify_shutdown()
576

677

778
class BackendClientURLTests(TestCase):

tests/test_operations_poller.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,25 @@ def can_accept_work():
220220
self.assertGreater(call_count[0], 1)
221221
poller.stop()
222222

223+
def test_backpressure_sends_heartbeat(self):
224+
"""Test that poller sends heartbeat to orchestrator during backpressure."""
225+
self._backend_client.get_next_operation.return_value = None
226+
227+
self._config_manager.get_int_value.return_value = 0.1
228+
229+
poller = OperationsPoller(
230+
backend_client=self._backend_client,
231+
config_manager=self._config_manager,
232+
operation_handler=self._operation_handler,
233+
can_accept_work=lambda: False, # always backpressured
234+
)
235+
236+
poller.start()
237+
time.sleep(0.3)
238+
239+
self._backend_client.send_heartbeat.assert_called()
240+
poller.stop()
241+
223242
def test_no_backpressure_when_can_accept_work_is_none(self):
224243
"""Test that poller fetches normally when can_accept_work is not provided."""
225244
op1 = {"operation_id": "op-1", "path": "/test1", "operation": {}}

0 commit comments

Comments
 (0)