|
1 | 1 | from unittest import TestCase |
2 | 2 | from unittest.mock import Mock, patch |
3 | 3 |
|
4 | | -from apollo.egress.agent.backend.backend_client import BackendClient |
| 4 | +from apollo.egress.agent.backend.backend_client import BackendClient, INSTANCE_ID_HEADER |
| 5 | + |
| 6 | + |
| 7 | +class BackendClientTests(TestCase): |
| 8 | + def setUp(self): |
| 9 | + self._login_token_provider = Mock() |
| 10 | + self._login_token_provider.get_token.return_value = { |
| 11 | + "x-mcd-id": "test-id", |
| 12 | + "x-mcd-token": "test-token", |
| 13 | + } |
| 14 | + self._client = BackendClient( |
| 15 | + backend_service_url="https://orchestrator.test", |
| 16 | + login_token_provider=self._login_token_provider, |
| 17 | + ) |
| 18 | + |
| 19 | + def test_instance_id_is_generated(self): |
| 20 | + """Instance ID should be a non-empty string, unique per client.""" |
| 21 | + self.assertIsInstance(self._client.instance_id, str) |
| 22 | + self.assertTrue(len(self._client.instance_id) > 0) |
| 23 | + |
| 24 | + other = BackendClient( |
| 25 | + backend_service_url="https://orchestrator.test", |
| 26 | + login_token_provider=self._login_token_provider, |
| 27 | + ) |
| 28 | + self.assertNotEqual(self._client.instance_id, other.instance_id) |
| 29 | + |
| 30 | + def test_headers_include_instance_id(self): |
| 31 | + """All requests should include the instance ID header.""" |
| 32 | + headers = self._client._headers() |
| 33 | + self.assertEqual(headers[INSTANCE_ID_HEADER], self._client.instance_id) |
| 34 | + self.assertEqual(headers["x-mcd-id"], "test-id") |
| 35 | + self.assertEqual(headers["x-mcd-token"], "test-token") |
| 36 | + |
| 37 | + def test_headers_include_extra(self): |
| 38 | + """Extra headers should be merged.""" |
| 39 | + headers = self._client._headers(**{"Content-Type": "application/json"}) |
| 40 | + self.assertEqual(headers["Content-Type"], "application/json") |
| 41 | + self.assertIn(INSTANCE_ID_HEADER, headers) |
| 42 | + |
| 43 | + @patch("requests.post") |
| 44 | + def test_send_heartbeat(self, mock_post): |
| 45 | + """send_heartbeat should POST to /api/v1/agent/heartbeat with correct headers.""" |
| 46 | + mock_post.return_value.status_code = 200 |
| 47 | + |
| 48 | + self._client.send_heartbeat() |
| 49 | + |
| 50 | + mock_post.assert_called_once() |
| 51 | + args, kwargs = mock_post.call_args |
| 52 | + self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/heartbeat") |
| 53 | + self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"]) |
| 54 | + self.assertEqual(kwargs["timeout"], 10) |
| 55 | + |
| 56 | + @patch("requests.post") |
| 57 | + def test_notify_shutdown(self, mock_post): |
| 58 | + """notify_shutdown should POST to /api/v1/agent/shutdown with correct headers.""" |
| 59 | + mock_post.return_value.status_code = 200 |
| 60 | + |
| 61 | + self._client.notify_shutdown() |
| 62 | + |
| 63 | + mock_post.assert_called_once() |
| 64 | + args, kwargs = mock_post.call_args |
| 65 | + self.assertEqual(args[0], "https://orchestrator.test/api/v1/agent/shutdown") |
| 66 | + self.assertIn(INSTANCE_ID_HEADER, kwargs["headers"]) |
| 67 | + self.assertEqual(kwargs["timeout"], 15) |
| 68 | + |
| 69 | + @patch("requests.post") |
| 70 | + def test_notify_shutdown_raises_on_failure(self, mock_post): |
| 71 | + """notify_shutdown should raise on HTTP errors.""" |
| 72 | + mock_post.return_value.raise_for_status.side_effect = Exception("500") |
| 73 | + |
| 74 | + with self.assertRaises(Exception): |
| 75 | + self._client.notify_shutdown() |
5 | 76 |
|
6 | 77 |
|
7 | 78 | class BackendClientURLTests(TestCase): |
|
0 commit comments