diff --git a/ada_verona/database/epsilon_status.py b/ada_verona/database/epsilon_status.py index 9a65577..0f6321f 100644 --- a/ada_verona/database/epsilon_status.py +++ b/ada_verona/database/epsilon_status.py @@ -15,6 +15,8 @@ from dataclasses import dataclass +import numpy as np + from ada_verona.database.verification_result import CompleteVerificationData, VerificationResult @@ -40,15 +42,23 @@ def set_values(self, complete_verification_data: CompleteVerificationData): """ self.result = complete_verification_data.result self.time = complete_verification_data.took - self.obtained_labels = complete_verification_data.obtained_labels + self.obtained_labels = getattr(complete_verification_data, "obtained_labels", None) def to_dict(self) -> dict: - """ - Convert the EpsilonStatus to a dictionary. + """Convert the EpsilonStatus to a dictionary.""" + obtained_labels_value = None + if self.obtained_labels is not None: + if isinstance(self.obtained_labels, np.ndarray): + obtained_labels_value = self.obtained_labels.flatten().tolist() + elif isinstance(self.obtained_labels, list): + obtained_labels_value = self.obtained_labels + else: + obtained_labels_value = [self.obtained_labels] - Returns: - dict: The dictionary representation of the EpsilonStatus. - """ - return dict(epsilon_value=self.value, result=self.result, time=self.time, verifier=self.verifier, - obtained_labels=self.obtained_labels.flatten().tolist() if self.obtained_labels is not None - else None) + return dict( + epsilon_value=self.value, + result=self.result, + time=self.time, + verifier=self.verifier, + obtained_labels=obtained_labels_value, + ) diff --git a/ada_verona/verification_module/auto_verify_module.py b/ada_verona/verification_module/auto_verify_module.py index d8764a0..49547a7 100644 --- a/ada_verona/verification_module/auto_verify_module.py +++ b/ada_verona/verification_module/auto_verify_module.py @@ -28,7 +28,6 @@ logger = logging.getLogger(__name__) - class AutoVerifyModule(VerificationModule): """ A module for automatically verifying the robustness of a model using a specified verifier. @@ -46,7 +45,7 @@ def __init__(self, verifier: Verifier, timeout: float, config: Path = None) -> N self.verifier = verifier self.timeout = timeout self.config = config - self.name = f"AutoVerifyModule ({verifier.name})" + self.name = f"AutoVerifyModule ({verifier.name})" def verify(self, verification_context: VerificationContext, epsilon: float) -> str | CompleteVerificationData: """ @@ -78,11 +77,20 @@ def verify(self, verification_context: VerificationContext, epsilon: float) -> s if isinstance(result, Ok): outcome = result.unwrap() + if outcome.result == "SAT" and outcome.counter_example: + try: + predicted_label = parse_counter_example_label(result) + outcome.obtained_labels = [str(predicted_label)] + except Exception as e: + logger.warning(f"Failed to parse counter example label: {e}") + if not hasattr(outcome, "obtained_labels"): + outcome.obtained_labels = None return outcome elif isinstance(result, Err): logger.info(f"Error during verification: {result.unwrap_err()}") return result.unwrap_err() + def parse_counter_example(result: Ok, verification_context: VerificationContext) -> np.ndarray: """ Parse the counter example from the verification result. @@ -95,11 +103,11 @@ def parse_counter_example(result: Ok, verification_context: VerificationContext) """ string_list_without_sat = [x for x in result.unwrap().counter_example.split("\n") if "sat" not in x] numbers = [x.replace("(", "").replace(")", "") for x in string_list_without_sat if "Y" not in x] - counter_example_array = np.array([float(re.sub(r'X_\d*', '', x).strip()) for x in numbers if x.strip()]) - + counter_example_array = np.array([float(re.sub(r"X_\d*", "", x).strip()) for x in numbers if x.strip()]) return counter_example_array.reshape(verification_context.data_point.data.shape) + def parse_counter_example_label(result: Ok) -> int: """ Parse the counter example label from the verification result. @@ -112,6 +120,6 @@ def parse_counter_example_label(result: Ok) -> int: """ string_list_without_sat = [x for x in result.unwrap().counter_example.split("\n") if "sat" not in x] numbers = [x.replace("(", "").replace(")", "") for x in string_list_without_sat if "X" not in x] - counter_example_array = np.array([float(re.sub(r'Y_\d*', '', x).strip()) for x in numbers if x.strip()]) + counter_example_array = np.array([float(re.sub(r"Y_\d*", "", x).strip()) for x in numbers if x.strip()]) - return int(np.argmax(counter_example_array)) \ No newline at end of file + return int(np.argmax(counter_example_array)) diff --git a/tests/test_database/test_epsilon_status.py b/tests/test_database/test_epsilon_status.py index dc7427a..0e87d13 100644 --- a/tests/test_database/test_epsilon_status.py +++ b/tests/test_database/test_epsilon_status.py @@ -13,14 +13,15 @@ # limitations under the License. # ============================================================================== +import numpy as np + from ada_verona.database.epsilon_status import EpsilonStatus from ada_verona.database.verification_result import VerificationResult def test_epsilon_status_initialization(): - epsilon_value = 0.5 - result = VerificationResult.SAT + result = VerificationResult.SAT time_taken = 1.23 epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken) @@ -32,14 +33,12 @@ def test_epsilon_status_initialization(): def test_epsilon_status_to_dict(): epsilon_value = 0.5 - result = VerificationResult.UNSAT + result = VerificationResult.UNSAT time_taken = 2.34 epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken) - result_dict = epsilon_status.to_dict() - assert result_dict == { "epsilon_value": epsilon_value, "result": result, @@ -48,9 +47,10 @@ def test_epsilon_status_to_dict(): "obtained_labels": None, } + def test_set_values(complete_verification_data): epsilon_value = 0.5 - result = VerificationResult.UNSAT + result = VerificationResult.UNSAT time_taken = 2.34 epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken) @@ -60,3 +60,53 @@ def test_set_values(complete_verification_data): assert epsilon_status.result == complete_verification_data.result assert epsilon_status.time == complete_verification_data.took + +def test_epsilon_status_to_dict_with_numpy_array(): + """Test to_dict() when obtained_labels is a numpy array.""" + epsilon_value = 0.5 + result = VerificationResult.SAT + time_taken = 1.25 + obtained_labels = np.array([1, 2, 3]) + + epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken, obtained_labels=obtained_labels) + + result_dict = epsilon_status.to_dict() + + assert result_dict["obtained_labels"] == [1, 2, 3] + assert result_dict["epsilon_value"] == epsilon_value + assert result_dict["result"] == result + assert result_dict["time"] == time_taken + + +def test_epsilon_status_to_dict_with_list(): + """Test to_dict() when obtained_labels is a list.""" + epsilon_value = 0.5 + result = VerificationResult.SAT + time_taken = 1.23 + obtained_labels = [1, 2, 3] + + epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken, obtained_labels=obtained_labels) + + result_dict = epsilon_status.to_dict() + + assert result_dict["obtained_labels"] == [1, 2, 3] + assert result_dict["epsilon_value"] == epsilon_value + assert result_dict["result"] == result + assert result_dict["time"] == time_taken + + +def test_epsilon_status_to_dict_with_string(): + """Test to_dict() when obtained_labels is a string.""" + epsilon_value = 0.5 + result = VerificationResult.SAT + time_taken = 1.23 + obtained_labels = "1" + + epsilon_status = EpsilonStatus(value=epsilon_value, result=result, time=time_taken, obtained_labels=obtained_labels) + + result_dict = epsilon_status.to_dict() + + assert result_dict["obtained_labels"] == ["1"] + assert result_dict["epsilon_value"] == epsilon_value + assert result_dict["result"] == result + assert result_dict["time"] == time_taken diff --git a/tests/test_verification_module/test_auto_verify_module.py b/tests/test_verification_module/test_auto_verify_module.py index 7d56ea5..aaebb01 100644 --- a/tests/test_verification_module/test_auto_verify_module.py +++ b/tests/test_verification_module/test_auto_verify_module.py @@ -14,9 +14,11 @@ # ============================================================================== from pathlib import Path +from unittest.mock import MagicMock import numpy as np import pytest +from result import Err, Ok from ada_verona.database.verification_context import VerificationContext from ada_verona.database.verification_result import CompleteVerificationData @@ -36,14 +38,17 @@ def property_generator(request): return request.param + @pytest.fixture def tmp_path(): return Path("/tmp") + @pytest.fixture def verification_context(network, datapoint, tmp_path, property_generator): return VerificationContext(network, datapoint, tmp_path, property_generator) + @pytest.fixture def auto_verify_module_fixture(request, auto_verify_module, auto_verify_module_config): if request.param == "auto_verify_module": @@ -51,6 +56,7 @@ def auto_verify_module_fixture(request, auto_verify_module, auto_verify_module_c elif request.param == "auto_verify_module_config": return auto_verify_module_config + def test_auto_verify_module_initialization(auto_verify_module, verifier): assert auto_verify_module.verifier == verifier assert auto_verify_module.timeout == 60 @@ -58,30 +64,91 @@ def test_auto_verify_module_initialization(auto_verify_module, verifier): @pytest.mark.parametrize( - "auto_verify_module_fixture", - ["auto_verify_module", "auto_verify_module_config"], - indirect=True + "auto_verify_module_fixture", ["auto_verify_module", "auto_verify_module_config"], indirect=True ) def test_auto_verify_module_verify(auto_verify_module_fixture, verification_context): result = auto_verify_module_fixture.verify(verification_context, 0.6) assert isinstance(result, CompleteVerificationData) assert result.result == "SAT" - + result = auto_verify_module_fixture.verify(verification_context, 0.01) assert isinstance(result, CompleteVerificationData) assert result.result == "UNSAT" - + def test_parse_counter_example(result, verification_context): counter_example = parse_counter_example(result, verification_context) - assert isinstance(counter_example, np.ndarray) + assert isinstance(counter_example, np.ndarray) assert counter_example.shape == verification_context.data_point.data.shape def test_parse_counter_example_label(result): label = parse_counter_example_label(result) + assert isinstance(label, int) - assert label == 0 \ No newline at end of file + assert label == 0 + + +def test_auto_verify_module_verify_sat_with_counter_example(auto_verify_module, verification_context, datapoint): + """Test that SAT results with counter_example parse the label correctly.""" + + formatted_strings = [f"(X_{i} {datapoint.data.flatten()[i]:.4f})" for i in range(28 * 28)] + counter_example = "\n".join(formatted_strings) + counter_example += "\n(Y_0 0.1)\n(Y_1 0.9)" + + mock_result = Ok(CompleteVerificationData(result="SAT", counter_example=counter_example, took=10.0)) + auto_verify_module.verifier.verify_property = MagicMock(return_value=mock_result) + + result = auto_verify_module.verify(verification_context, 0.6) + + assert isinstance(result, CompleteVerificationData) + assert result.result == "SAT" + assert result.obtained_labels == ["1"] + + +def test_auto_verify_module_verify_sat_with_counter_example_parse_error(auto_verify_module, verification_context): + """Test that exception during label parsing is handled gracefully.""" + + counter_example = "invalid format that cannot be parsed" + + mock_result = Ok(CompleteVerificationData(result="SAT", counter_example=counter_example, took=10.0)) + auto_verify_module.verifier.verify_property = MagicMock(return_value=mock_result) + + result = auto_verify_module.verify(verification_context, 0.6) + + assert isinstance(result, CompleteVerificationData) + assert result.result == "SAT" + assert result.obtained_labels is None + + +def test_auto_verify_module_verify_error_result(auto_verify_module, verification_context): + """Test that Err results are handled correctly.""" + error_message = "Verification failed with error" + mock_result = Err(error_message) + auto_verify_module.verifier.verify_property = MagicMock(return_value=mock_result) + + result = auto_verify_module.verify(verification_context, 0.6) + + assert result == error_message + + +def test_auto_verify_module_verify_unsat_sets_obtained_labels_none(auto_verify_module, verification_context): + """Test that UNSAT results without obtained_labels attribute get it set to None.""" + class MockOutcome: + def __init__(self): + self.result = "UNSAT" + self.took = 10.0 + self.counter_example = None + + outcome = MockOutcome() + mock_result = Ok(outcome) + auto_verify_module.verifier.verify_property = MagicMock(return_value=mock_result) + + result = auto_verify_module.verify(verification_context, 0.01) + + assert hasattr(result, "obtained_labels") + assert result.obtained_labels is None + assert result.result == "UNSAT"