Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_cryptography_version():
"asyncstdlib",
"async-timeout",
"typing_extensions",
"python-auditor==0.5.0",
"python-auditor==0.10.0",
"tzlocal",
"pyotp",
*REST_REQUIRES,
Expand Down
24 changes: 22 additions & 2 deletions tardis/plugins/auditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,33 @@ def __init__(self):
self._group = getattr(config_auditor, "group", "tardis")
auditor_timeout = getattr(config_auditor, "timeout", 30)
self._local_timezone = get_localzone()
self._client = (

tls_config = getattr(config_auditor, "tls_config", {})
use_tls = bool(getattr(tls_config, "use_tls", False))

client_builder = (
pyauditor.AuditorClientBuilder()
.address(config_auditor.host, config_auditor.port)
.timeout(auditor_timeout)
.build()
)

if use_tls:
required_tls_paths = ["client_cert_path", "client_key_path", "ca_cert_path"]

for path_key in required_tls_paths:
if not hasattr(tls_config, path_key):
raise ValueError(f"Missing TLS configuration: {path_key}")

self._client = client_builder.with_tls(
tls_config.client_cert_path,
tls_config.client_key_path,
tls_config.ca_cert_path,
).build()
self.logger.debug("TLS configuration completed successfully")

else:
self._client = client_builder.build()

async def notify(self, state: State, resource_attributes: AttributeDict) -> None:
"""
Pushes a record to an Auditor instance when the drone is in state
Expand Down
68 changes: 57 additions & 11 deletions tests/plugins_t/test_auditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from tests.utilities.utilities import async_return
from tests.utilities.utilities import run_async

from unittest.mock import AsyncMock


class TestAuditor(TestCase):
@classmethod
Expand Down Expand Up @@ -49,6 +51,10 @@ def setUp(self):
self.memory = 100
self.drone_uuid = "test-drone"
self.machine_type = "test_machine_type"
self.use_tls = True
self.ca_cert_path = "/path/to/ca.crt"
self.client_cert_path = "/path/to/client.crt"
self.client_key_path = "/path/to/client.key"
config = self.mock_config.return_value
config.Plugins.Auditor = AttributeDict(
host=self.address,
Expand All @@ -62,6 +68,12 @@ def setUp(self):
Memory=AttributeDict(BLUBB=1.4),
)
),
tls_config=AttributeDict(
use_tls=self.use_tls,
ca_cert_path=self.ca_cert_path,
client_cert_path=self.client_cert_path,
client_key_path=self.client_key_path,
),
)
config.Sites = [AttributeDict(name=self.site)]
config.testsite.MachineTypes = [self.machine_type]
Expand All @@ -85,8 +97,12 @@ def setUp(self):
self.client.add.return_value = async_return()
self.client.update.return_value = async_return()

self.client = AsyncMock()
self.mock_auditorclientbuilder.return_value.build.return_value = self.client

self.config = config
self.plugin = Auditor()
self.plugin._client = self.client

def test_default_fields(self):
# Potential future race condition ahead.
Expand All @@ -100,29 +116,59 @@ def test_default_fields(self):
self.assertEqual(plugin._user, "tardis")
self.assertEqual(plugin._group, "tardis")

def test_notify(self):
def test_tls_configuration(self):
tls_config = self.config.Plugins.Auditor.tls_config

self.assertTrue(tls_config.use_tls, "TLS should be enabled")

self.assertIsNotNone(tls_config, "Client should be created")

def test_tls_disabled(self):
self.config.Plugins.Auditor.tls_config.use_tls = False

builder_mock = self.mock_auditorclientbuilder.return_value
timeout_mock = builder_mock.address.return_value.timeout.return_value

with patch.object(timeout_mock, "with_tls") as mock_with_tls:
mock_with_tls.return_value = timeout_mock

mock_with_tls.assert_not_called()

@patch(
"tardis.plugins.auditor.pyauditor.AuditorClientBuilder.build",
new_callable=AsyncMock,
)
def test_notify(self, mock_build):
self.mock_auditorclientbuilder.return_value.address.assert_called_with(
self.address,
self.port,
)
self.mock_auditorclientbuilder.return_value.address.return_value.timeout.assert_called_with(
self.timeout,
)
self.mock_auditorclientbuilder.return_value.address.return_value.timeout.return_value.with_tls.assert_called_with(
self.client_cert_path, self.client_key_path, self.ca_cert_path
)
mock_client = AsyncMock()
mock_build.return_value = mock_client
self.plugin._client = mock_client

run_async(
self.plugin.notify,
state=AvailableState(),
resource_attributes=self.test_param,
)
self.assertEqual(self.client.add.call_count, 1)
self.assertEqual(self.client.update.call_count, 0)
mock_client.add.assert_called_once()
mock_client.update.assert_not_called()

run_async(
self.plugin.notify,
state=DownState(),
resource_attributes=self.test_param,
)
self.assertEqual(self.client.add.call_count, 1)
self.assertEqual(self.client.update.call_count, 1)

self.assertEqual(mock_client.add.call_count, 1)
self.assertEqual(mock_client.update.call_count, 1)

# test for no-op
for state in [
Expand All @@ -140,11 +186,11 @@ def test_notify(self):
state=state,
resource_attributes=self.test_param,
)
self.assertEqual(self.client.add.call_count, 1)
self.assertEqual(self.client.update.call_count, 1)
self.assertEqual(mock_client.add.call_count, 1)
self.assertEqual(mock_client.update.call_count, 1)

# test exception handling
self.client.update.side_effect = RuntimeError(
mock_client.update.side_effect = RuntimeError(
"Reqwest Error: HTTP status client error (404 Not Found) "
"for url (http://127.0.0.1:8000/update)"
)
Expand All @@ -154,7 +200,7 @@ def test_notify(self):
resource_attributes=self.test_param,
)

self.client.update.side_effect = RuntimeError(
mock_client.update.side_effect = RuntimeError(
"Reqwest Error: HTTP status client error (403 Forbidden) "
"for url (http://127.0.0.1:8000/update)"
)
Expand All @@ -165,15 +211,15 @@ def test_notify(self):
resource_attributes=self.test_param,
)

self.client.update.side_effect = RuntimeError("Does not match RegEx")
mock_client.update.side_effect = RuntimeError("Does not match RegEx")
with self.assertRaises(RuntimeError):
run_async(
self.plugin.notify,
state=DownState(),
resource_attributes=self.test_param,
)

self.client.update.side_effect = ValueError("Other exception")
mock_client.update.side_effect = ValueError("Other exception")
with self.assertRaises(ValueError):
run_async(
self.plugin.notify,
Expand Down