diff --git a/setup.py b/setup.py index b9f67754..f4bad6cb 100644 --- a/setup.py +++ b/setup.py @@ -94,7 +94,7 @@ def get_cryptography_version(): "asyncstdlib", "async-timeout", "typing_extensions", - "python-auditor==0.5.0", + "python-auditor==0.10.0", "tzlocal", "pyotp", *REST_REQUIRES, diff --git a/tardis/plugins/auditor.py b/tardis/plugins/auditor.py index b843f0b6..29c9f14a 100644 --- a/tardis/plugins/auditor.py +++ b/tardis/plugins/auditor.py @@ -52,13 +52,33 @@ def __init__(self): self._group = getattr(config_auditor, "group", "tardis") auditor_timeout = getattr(config_auditor, "timeout", 30) self._local_timezone = get_localzone() - self._client = ( + + tls_config = getattr(config_auditor, "tls_config", {}) + use_tls = bool(getattr(tls_config, "use_tls", False)) + + client_builder = ( pyauditor.AuditorClientBuilder() .address(config_auditor.host, config_auditor.port) .timeout(auditor_timeout) - .build() ) + if use_tls: + required_tls_paths = ["client_cert_path", "client_key_path", "ca_cert_path"] + + for path_key in required_tls_paths: + if not hasattr(tls_config, path_key): + raise ValueError(f"Missing TLS configuration: {path_key}") + + self._client = client_builder.with_tls( + tls_config.client_cert_path, + tls_config.client_key_path, + tls_config.ca_cert_path, + ).build() + self.logger.debug("TLS configuration completed successfully") + + else: + self._client = client_builder.build() + async def notify(self, state: State, resource_attributes: AttributeDict) -> None: """ Pushes a record to an Auditor instance when the drone is in state diff --git a/tests/plugins_t/test_auditor.py b/tests/plugins_t/test_auditor.py index f3dc9dcc..63d5c358 100644 --- a/tests/plugins_t/test_auditor.py +++ b/tests/plugins_t/test_auditor.py @@ -21,6 +21,8 @@ from tests.utilities.utilities import async_return from tests.utilities.utilities import run_async +from unittest.mock import AsyncMock + class TestAuditor(TestCase): @classmethod @@ -49,6 +51,10 @@ def setUp(self): self.memory = 100 self.drone_uuid = "test-drone" self.machine_type = "test_machine_type" + self.use_tls = True + self.ca_cert_path = "/path/to/ca.crt" + self.client_cert_path = "/path/to/client.crt" + self.client_key_path = "/path/to/client.key" config = self.mock_config.return_value config.Plugins.Auditor = AttributeDict( host=self.address, @@ -62,6 +68,12 @@ def setUp(self): Memory=AttributeDict(BLUBB=1.4), ) ), + tls_config=AttributeDict( + use_tls=self.use_tls, + ca_cert_path=self.ca_cert_path, + client_cert_path=self.client_cert_path, + client_key_path=self.client_key_path, + ), ) config.Sites = [AttributeDict(name=self.site)] config.testsite.MachineTypes = [self.machine_type] @@ -85,8 +97,12 @@ def setUp(self): self.client.add.return_value = async_return() self.client.update.return_value = async_return() + self.client = AsyncMock() + self.mock_auditorclientbuilder.return_value.build.return_value = self.client + self.config = config self.plugin = Auditor() + self.plugin._client = self.client def test_default_fields(self): # Potential future race condition ahead. @@ -100,7 +116,29 @@ def test_default_fields(self): self.assertEqual(plugin._user, "tardis") self.assertEqual(plugin._group, "tardis") - def test_notify(self): + def test_tls_configuration(self): + tls_config = self.config.Plugins.Auditor.tls_config + + self.assertTrue(tls_config.use_tls, "TLS should be enabled") + + self.assertIsNotNone(tls_config, "Client should be created") + + def test_tls_disabled(self): + self.config.Plugins.Auditor.tls_config.use_tls = False + + builder_mock = self.mock_auditorclientbuilder.return_value + timeout_mock = builder_mock.address.return_value.timeout.return_value + + with patch.object(timeout_mock, "with_tls") as mock_with_tls: + mock_with_tls.return_value = timeout_mock + + mock_with_tls.assert_not_called() + + @patch( + "tardis.plugins.auditor.pyauditor.AuditorClientBuilder.build", + new_callable=AsyncMock, + ) + def test_notify(self, mock_build): self.mock_auditorclientbuilder.return_value.address.assert_called_with( self.address, self.port, @@ -108,21 +146,29 @@ def test_notify(self): self.mock_auditorclientbuilder.return_value.address.return_value.timeout.assert_called_with( self.timeout, ) + self.mock_auditorclientbuilder.return_value.address.return_value.timeout.return_value.with_tls.assert_called_with( + self.client_cert_path, self.client_key_path, self.ca_cert_path + ) + mock_client = AsyncMock() + mock_build.return_value = mock_client + self.plugin._client = mock_client + run_async( self.plugin.notify, state=AvailableState(), resource_attributes=self.test_param, ) - self.assertEqual(self.client.add.call_count, 1) - self.assertEqual(self.client.update.call_count, 0) + mock_client.add.assert_called_once() + mock_client.update.assert_not_called() run_async( self.plugin.notify, state=DownState(), resource_attributes=self.test_param, ) - self.assertEqual(self.client.add.call_count, 1) - self.assertEqual(self.client.update.call_count, 1) + + self.assertEqual(mock_client.add.call_count, 1) + self.assertEqual(mock_client.update.call_count, 1) # test for no-op for state in [ @@ -140,11 +186,11 @@ def test_notify(self): state=state, resource_attributes=self.test_param, ) - self.assertEqual(self.client.add.call_count, 1) - self.assertEqual(self.client.update.call_count, 1) + self.assertEqual(mock_client.add.call_count, 1) + self.assertEqual(mock_client.update.call_count, 1) # test exception handling - self.client.update.side_effect = RuntimeError( + mock_client.update.side_effect = RuntimeError( "Reqwest Error: HTTP status client error (404 Not Found) " "for url (http://127.0.0.1:8000/update)" ) @@ -154,7 +200,7 @@ def test_notify(self): resource_attributes=self.test_param, ) - self.client.update.side_effect = RuntimeError( + mock_client.update.side_effect = RuntimeError( "Reqwest Error: HTTP status client error (403 Forbidden) " "for url (http://127.0.0.1:8000/update)" ) @@ -165,7 +211,7 @@ def test_notify(self): resource_attributes=self.test_param, ) - self.client.update.side_effect = RuntimeError("Does not match RegEx") + mock_client.update.side_effect = RuntimeError("Does not match RegEx") with self.assertRaises(RuntimeError): run_async( self.plugin.notify, @@ -173,7 +219,7 @@ def test_notify(self): resource_attributes=self.test_param, ) - self.client.update.side_effect = ValueError("Other exception") + mock_client.update.side_effect = ValueError("Other exception") with self.assertRaises(ValueError): run_async( self.plugin.notify,