diff --git a/docs/source/implementation_details.rst b/docs/source/implementation_details.rst index 6306cac..afc9729 100644 --- a/docs/source/implementation_details.rst +++ b/docs/source/implementation_details.rst @@ -95,4 +95,5 @@ checks which, on their hand, might check on yet-to-be-defined instance attribute def some_func(self) -> int: return 1984 + .. _functools.update_wrapper: https://docs.python.org/3/library/functools.html#functools.update_wrapper \ No newline at end of file diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 393874b..d014e4f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -153,6 +153,9 @@ To save you some typing, we introduced the shortcut, :attr:`InvariantCheckEvent. The property getters and setters are considered "normal" methods. If you want to check the invariants at property getters and/or setters, make sure to include :attr:`InvariantCheckEvent.CALL` in ``check_on``. +In addition, we treat ``__setstate__`` as a constructor. +That is, the invariants are checked *after* the call, but not before and during the call, as unpickling results in intermediate object states which might be invalid. + The following examples show various cases when an invariant is breached. After the initialization: diff --git a/icontract/_checkers.py b/icontract/_checkers.py index 24fa3d4..1f4c994 100644 --- a/icontract/_checkers.py +++ b/icontract/_checkers.py @@ -1203,6 +1203,12 @@ def add_invariant_checks(cls: ClassT) -> None: """Decorate each of the class functions with invariant checks if not already decorated.""" # Candidates for the decoration as list of (name, dir() value) init_func = None # type: Optional[Callable[..., None]] + + # NOTE (mristin): + # We also have to disable the invariant checks *before* and *during* the call to + # __setstate__ function as the invariants can not hold while unpickling the object. + setstate_func = None # type: Optional[Callable[..., None]] + names_funcs = [] # type: List[Tuple[str, Callable[..., None]]] names_properties = [] # type: List[Tuple[str, property]] @@ -1246,6 +1252,14 @@ def add_invariant_checks(cls: ClassT) -> None: init_func = value continue + if name == "__setstate__": + assert inspect.isfunction( + value + ), "Expected __setstate__ to be a function, but got: {}".format(type(value)) + + setstate_func = value + continue + if ( name != "__setattr__" and InvariantCheckEvent.CALL not in last_invariant.check_on @@ -1365,6 +1379,16 @@ def __init__(self: Any, *args: Any, **kwargs: Any) -> None: wrapper = _decorate_with_invariants(func=init_func, cls=cls, is_init=True) setattr(cls, init_func.__name__, wrapper) + if setstate_func is not None: + assert setstate_func.__name__ == "__setstate__" + + # NOTE (mristin): + # We make the decoration of __setstate__ the same as for the init function since + # we want to disable the invariant checks before and during the call, but we + # need to check the invariants *after* the call. + wrapper = _decorate_with_invariants(func=setstate_func, cls=cls, is_init=True) + setattr(cls, setstate_func.__name__, wrapper) + for name, func in names_funcs: wrapper = _decorate_with_invariants(func=func, cls=cls, is_init=False) setattr(cls, name, wrapper) diff --git a/tests/error.py b/tests/error.py index 159c22a..d0f847d 100644 --- a/tests/error.py +++ b/tests/error.py @@ -3,11 +3,12 @@ import re _LOCATION_RE = re.compile( - r"\AFile [^\n]+, line [0-9]+ in [a-zA-Z_0-9]+:\n(.*)\Z", + r"\AFile [^\n]+, line [0-9]+ in ([a-zA-Z_0-9]+|):\n(.*)\Z", flags=re.MULTILINE | re.DOTALL, ) +# pylint: disable=line-too-long def wo_mandatory_location(text: str) -> str: r""" Strip the location of the contract from the text of the error. @@ -19,17 +20,20 @@ def wo_mandatory_location(text: str) -> str: >>> wo_mandatory_location(text='File /some/file.py, line 233 in some_module:\nsome\ntext') 'some\ntext' + >>> wo_mandatory_location(text='File /some/file.py, line 233 in :\nsome\ntext') + 'some\ntext' + >>> wo_mandatory_location(text='a text') Traceback (most recent call last): ... - AssertionError: Expected the text to match \AFile [^\n]+, line [0-9]+ in [a-zA-Z_0-9]+:\n(.*)\Z, but got: 'a text' + AssertionError: Expected the text to match \AFile [^\n]+, line [0-9]+ in ([a-zA-Z_0-9]+|):\n(.*)\Z, but got: 'a text' """ - mtch = _LOCATION_RE.match(text) - if not mtch: + match = _LOCATION_RE.match(text) + if not match: raise AssertionError( "Expected the text to match {}, but got: {!r}".format( _LOCATION_RE.pattern, text ) ) - return mtch.group(1) + return match.group(2) diff --git a/tests/test_invariant.py b/tests/test_invariant.py index e40ecd9..517fb38 100644 --- a/tests/test_invariant.py +++ b/tests/test_invariant.py @@ -1,6 +1,8 @@ # pylint: disable=missing-docstring # pylint: disable=invalid-name # pylint: disable=unused-argument +import io +import pickle import textwrap import time import unittest @@ -18,6 +20,84 @@ import tests.mock +# NOTE (mristin): +# We need to introduce a global class so that we can perform tests with pickling. Pickle module does +# not support local classes. +@icontract.invariant(lambda self: self.x > 0) +class AMadeForPickling: + def __init__(self, x: int) -> None: + self.x = x + + +@icontract.invariant(lambda self: self.x > 0) +@icontract.invariant(lambda self: self.y > 0) +class AMadeForPicklingWithSetState: + def __init__(self, x: int) -> None: + self.x = x + + # NOTE (mristin): + # The attribute ``y`` is computed here, and we will not pickle it intentionally. + self._compute_internal_state() + + def _compute_internal_state(self) -> None: + self.y = self.x + 10 + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + + # NOTE (mristin): + # We intentionally do not want to pickle ``y``. + state.pop("y", None) + + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + # NOTE (mristin): + # The invariants should not be checked in __setstate__ as the object + # will not be in the correct state. + + self.__dict__.update(state) + + # NOTE (mristin): + # We have to re-compute the internal state. + self._compute_internal_state() + + +@icontract.invariant(lambda self: self.x > 0) +@icontract.invariant(lambda self: self.y > 0) +class AMadeForPicklingWithInvalidSetState: + def __init__(self, x: int) -> None: + self.x = x + + # NOTE (mristin): + # The creation does not violate the invariant. + self.y = 1000 + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + + # NOTE (mristin): + # We intentionally do not want to pickle ``y``. + state.pop("y", None) + + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + # NOTE (mristin): + # The invariants should not be checked before and during the __setstate__ + # as the object will not be in the correct state, but they should be + # checked after the call. + + self.__dict__.update(state) + + # NOTE (mristin): + # We wrongly re-compute the internal state; this violates the invariant. + self.y = -1000 + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + class TestOK(unittest.TestCase): def test_init(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -275,6 +355,34 @@ def __str__(self) -> str: # 4 checks after the methods. self.assertEqual(9, counter) + def test_pickle(self) -> None: + a = AMadeForPickling(x=2) + + buffer = io.BytesIO() + # noinspection PyTypeChecker + pickle.dump(a, buffer) + + buffer.seek(0) + _ = pickle.load(buffer) + + # NOTE (mristin): + # No invariant violation expected even though __setstate__ will result in + # a temporarily invalid object state. + + def test_pickle_with_setstate(self) -> None: + a = AMadeForPicklingWithSetState(x=2) + + buffer = io.BytesIO() + # noinspection PyTypeChecker + pickle.dump(a, buffer) + + buffer.seek(0) + _ = pickle.load(buffer) + + # NOTE (mristin): + # No invariant violation expected even though __setstate__ will result in + # a temporarily invalid object state. + class TestViolation(unittest.TestCase): def test_init(self) -> None: @@ -650,6 +758,33 @@ def __repr__(self) -> str: tests.error.wo_mandatory_location(str(violation_error)), ) + def test_pickle_with_invalid_set_state(self) -> None: + a = AMadeForPicklingWithInvalidSetState(x=2) + + buffer = io.BytesIO() + # noinspection PyTypeChecker + pickle.dump(a, buffer) + + buffer.seek(0) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + # NOTE (mristin): + # The __setstate__ is expected to violate the invariant. + _ = pickle.load(buffer) + + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + """\ +self.y > 0: +self was an instance of AMadeForPicklingWithInvalidSetState +self.y was -1000""", + tests.error.wo_mandatory_location(str(violation_error)), + ) + class TestProperty(unittest.TestCase): def test_property_getter(self) -> None: