diff --git a/mocket/__init__.py b/mocket/__init__.py index 53064434..8c841ec1 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,20 +1,48 @@ -from mocket.async_mocket import async_mocketize -from mocket.entry import MocketEntry -from mocket.mocket import Mocket -from mocket.mocketizer import Mocketizer, mocketize -from mocket.ssl.context import MocketSSLContext - -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.ssl.context import MocketSSLContext as FakeSSLContext +from mocket.bytes import ( + MocketBytesEntry, + MocketBytesRequest, + MocketBytesResponse, +) +from mocket.compat import FakeSSLContext, MocketEntry +from mocket.core.async_mocket import async_mocketize +from mocket.core.mocket import Mocket +from mocket.core.mocketizer import Mocketizer, mocketize +from mocket.core.socket import MocketSocket +from mocket.core.ssl.context import MocketSSLContext +from mocket.core.ssl.socket import MocketSSLSocket +from mocket.http import ( + MocketHttpEntry, + MocketHttpMethod, + MocketHttpRequest, + MocketHttpResponse, +) +from mocket.redis import ( + MocketRedisEntry, + MocketRedisRequest, + MocketRedisResponse, +) -__all__ = ( - "async_mocketize", - "mocketize", +__all__ = [ "Mocket", - "MocketEntry", - "Mocketizer", + "MocketBytesEntry", + "MocketBytesRequest", + "MocketBytesResponse", + "MocketHttpEntry", + "MocketHttpMethod", + "MocketHttpRequest", + "MocketHttpResponse", + "MocketRedisEntry", + "MocketRedisRequest", + "MocketRedisResponse", "MocketSSLContext", + "MocketSSLSocket", + "MocketSocket", + "Mocketizer", + "async_mocketize", + "mocketize", + # NOTE this is here for backwards-compat to keep old import-paths working "FakeSSLContext", -) + "MocketEntry", +] __version__ = "3.13.2" diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index 709d225f..3d996df6 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -1,22 +1,6 @@ -from mocket.mocketizer import Mocketizer -from mocket.utils import get_mocketize +from mocket.core.async_mocket import async_mocketize - -async def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - async with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return await test(*args, **kwargs) - - -async_mocketize = get_mocketize(wrapper_=wrapper) - - -__all__ = ("async_mocketize",) +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "async_mocketize", +] diff --git a/mocket/bytes.py b/mocket/bytes.py new file mode 100644 index 00000000..4ecdb3ca --- /dev/null +++ b/mocket/bytes.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Sequence + +from typing_extensions import Self + +from mocket.core.entry import MocketBaseEntry, MocketBaseRequest, MocketBaseResponse +from mocket.core.mocket import Mocket +from mocket.core.types import Address + + +class MocketBytesRequest(MocketBaseRequest): + def __init__(self) -> None: + self._data = b"" + + @property + def data(self) -> bytes: + return self._data + + @classmethod + def from_data(cls: type[Self], data: bytes) -> Self: + request = cls() + request._data = data + return request + + +class MocketBytesResponse(MocketBaseResponse): + def __init__(self, data: bytes | str | bool) -> None: + if isinstance(data, str): + data = data.encode() + elif isinstance(data, bool): + data = bytes(data) + self._data = data + + @property + def data(self) -> bytes: + return self._data + + +class MocketBytesEntry(MocketBaseEntry): + request_cls = MocketBytesRequest + response_cls = MocketBytesResponse + + def __init__( + self, + address: Address, + responses: Sequence[MocketBytesResponse | Exception], + ) -> None: + if not len(responses): + responses = [MocketBytesResponse(data=b"")] + + super().__init__( + address=address, + responses=responses, + ) + + @classmethod + def register_response( + cls, + address: Address, + response: MocketBytesResponse | Exception, + ) -> None: + entry = cls( + address=address, + responses=[response], + ) + Mocket.register(entry) + + @classmethod + def register_responses( + cls, + address: Address, + responses: Sequence[MocketBytesResponse | Exception], + ) -> None: + entry = cls( + address=address, + responses=responses, + ) + Mocket.register(entry) diff --git a/mocket/compat/__init__.py b/mocket/compat/__init__.py new file mode 100644 index 00000000..01bdd533 --- /dev/null +++ b/mocket/compat/__init__.py @@ -0,0 +1,8 @@ +from mocket.compat.entry import MocketEntry, Response +from mocket.core.ssl.context import MocketSSLContext as FakeSSLContext + +__all__ = [ + "FakeSSLContext", + "MocketEntry", + "Response", +] diff --git a/mocket/compat/entry.py b/mocket/compat/entry.py new file mode 100644 index 00000000..1ea7cece --- /dev/null +++ b/mocket/compat/entry.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from mocket.bytes import MocketBytesEntry, MocketBytesResponse +from mocket.core.types import Address + + +class Response(MocketBytesResponse): + def __init__(self, data: bytes | str | bool) -> None: + if isinstance(data, str): + data = data.encode() + elif isinstance(data, bool): + data = bytes(data) + self._data = data + + +class MocketEntry(MocketBytesEntry): + def __init__( + self, + location: Address, + responses: list[MocketBytesResponse | Exception | bytes | str | bool] + | MocketBytesResponse + | Exception + | bytes + | str + | bool, + ) -> None: + if not isinstance(responses, list): + responses = [responses] + + _responses = [] + for response in responses: + if not isinstance(response, (MocketBytesResponse, Exception)): + response = MocketBytesResponse(response) + _responses.append(response) + + super().__init__(address=location, responses=_responses) diff --git a/mocket/compat/mockhttp.py b/mocket/compat/mockhttp.py new file mode 100644 index 00000000..36b36d82 --- /dev/null +++ b/mocket/compat/mockhttp.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from io import BufferedReader +from typing import Any + +from mocket.http import ( + MocketHttpEntry, + MocketHttpMethod, + MocketHttpRequest, + MocketHttpResponse, +) +from mocket.mocket import Mocket + + +class Response(MocketHttpResponse): + def __init__( + self, + body: str | bytes | BufferedReader = b"", + status: int = 200, + headers: dict[str, str] | None = None, + ) -> None: + super().__init__( + status_code=status, + headers=headers, + body=body, + ) + + @property + def status(self) -> int: + return self.status_code + + +class Request(MocketHttpRequest): + @property + def body(self) -> str | None: # type: ignore + body = super().body + if body is None: + return None + return body.decode() + + +class Entry(MocketHttpEntry): + request_cls = Request + response_cls = Response # type: ignore[assignment] + + CONNECT = MocketHttpMethod.CONNECT + DELETE = MocketHttpMethod.DELETE + GET = MocketHttpMethod.GET + HEAD = MocketHttpMethod.HEAD + OPTIONS = MocketHttpMethod.OPTIONS + PATCH = MocketHttpMethod.PATCH + POST = MocketHttpMethod.POST + PUT = MocketHttpMethod.PUT + TRACE = MocketHttpMethod.TRACE + + METHODS = list(MocketHttpMethod) + + def __init__( + self, + uri: str, + method: MocketHttpMethod, + responses: list[Response | Exception], + match_querystring: bool = True, + add_trailing_slash: bool = True, + ) -> None: + super().__init__( + method=method, + uri=uri, + responses=responses, + match_querystring=match_querystring, + add_trailing_slash=add_trailing_slash, + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"method='{self.method.name}', " + f"schema='{self.schema}', " + f"location={self.address}, " + f"path='{self.path}', " + f"query='{self.query}'" + ")" + ) + + @property + def schema(self) -> str: + return self.scheme + + @classmethod + def register( + cls, + method: MocketHttpMethod, + uri: str, + *responses: Response | Exception, + **config: Any, + ) -> None: + if "body" in config or "status" in config: + raise AttributeError("Did you mean `Entry.single_register(...)`?") + + if isinstance(config, dict): + match_querystring = config.get("match_querystring", True) + add_trailing_slash = config.get("add_trailing_slash", True) + + entry = cls( + method=method, + uri=uri, + responses=list(responses), + match_querystring=match_querystring, + add_trailing_slash=add_trailing_slash, + ) + Mocket.register(entry) + + @classmethod + def single_register( + cls, + method: MocketHttpMethod, + uri: str, + body: str | bytes | BufferedReader = b"", + status: int = 200, + headers: dict[str, str] | None = None, + match_querystring: bool = True, + exception: Exception | None = None, + ) -> None: + response: Response | Exception + if exception is not None: + response = exception + else: + response = Response( + body=body, + status=status, + headers=headers, + ) + + cls.register( + method, + uri, + response, + match_querystring=match_querystring, + ) diff --git a/mocket/compat/mockredis.py b/mocket/compat/mockredis.py new file mode 100644 index 00000000..d7dd364a --- /dev/null +++ b/mocket/compat/mockredis.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import Sequence + +from mocket.core.mocket import Mocket +from mocket.core.types import Address +from mocket.redis import MocketRedisEntry, MocketRedisResponse + +DEFAULT_ADDRESS = ("localhost", 6379) + + +class Entry(MocketRedisEntry): + def __init__( + self, + addr: Address | None, + command: str | bytes, + responses: Sequence[MocketRedisResponse | Exception], + ) -> None: + super().__init__( + address=addr or DEFAULT_ADDRESS, + command=command, + responses=responses, + ) + + @property + def command(self) -> list[bytes]: # type: ignore[override] + return self._command_tokens + + @staticmethod + def _convert_response( + response: str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | Exception + | MocketRedisResponse, + ) -> MocketRedisResponse | Exception: + if isinstance(response, (MocketRedisResponse, Exception)): + return response + + return MocketRedisResponse(data=response) + + @classmethod + def register( + cls, + addr: Address | None, + command: str | bytes, + *responses: str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | Exception + | MocketRedisResponse, + ) -> None: + cls.register_responses( + command=command, + responses=responses, + addr=addr, + ) + + @classmethod + def register_response( # type: ignore[override] + cls, + command: str | bytes, + response: str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | Exception + | MocketRedisResponse, + addr: Address | None = None, + ) -> None: + response = Entry._convert_response(response) + entry = cls( + addr=addr or DEFAULT_ADDRESS, + command=command, + responses=[response], + ) + Mocket.register(entry) + + @classmethod + def register_responses( # type: ignore[override] + cls, + command: str | bytes, + responses: Sequence[ + str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | Exception + | MocketRedisResponse + ], + addr: Address | None = None, + ) -> None: + _responses = [] + for response in responses: + response = Entry._convert_response(response) + _responses.append(response) + + entry = cls( + addr=addr or DEFAULT_ADDRESS, + command=command, + responses=_responses, + ) + Mocket.register(entry) diff --git a/mocket/ssl/__init__.py b/mocket/core/__init__.py similarity index 100% rename from mocket/ssl/__init__.py rename to mocket/core/__init__.py diff --git a/mocket/core/async_mocket.py b/mocket/core/async_mocket.py new file mode 100644 index 00000000..8117ad51 --- /dev/null +++ b/mocket/core/async_mocket.py @@ -0,0 +1,19 @@ +from mocket.core.mocketizer import Mocketizer +from mocket.core.utils import get_mocketize + + +async def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + async with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return await test(*args, **kwargs) + + +async_mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/compat.py b/mocket/core/compat.py similarity index 94% rename from mocket/compat.py rename to mocket/core/compat.py index 1ac2fc89..cb9295a1 100644 --- a/mocket/compat.py +++ b/mocket/core/compat.py @@ -27,7 +27,7 @@ def shsplit(s: str | bytes) -> list[str]: return shlex.split(s) -def do_the_magic(body): +def do_the_magic(body: str | bytes) -> str: try: magic = puremagic.magic_string(body) except puremagic.PureError: diff --git a/mocket/core/entry.py b/mocket/core/entry.py new file mode 100644 index 00000000..cc2c0ff2 --- /dev/null +++ b/mocket/core/entry.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Sequence + +from typing_extensions import Self + +from mocket.core.mocket import Mocket +from mocket.core.types import Address + + +class MocketBaseRequest(ABC): + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data='{self.data!r}')" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, MocketBaseRequest): + return self.data == other.data + + if isinstance(other, bytes): + return self.data == other + + return False + + @property + @abstractmethod + def data(self) -> bytes: + raise NotImplementedError() + + @classmethod + @abstractmethod + def from_data(cls: type[Self], data: bytes) -> Self: + raise NotImplementedError() + + +class MocketBaseResponse(ABC): + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data='{self.data!r}')" + + @property + @abstractmethod + def data(self) -> bytes: + raise NotImplementedError() + + +class MocketBaseEntry(ABC): + request_cls: ClassVar[type[MocketBaseRequest]] + response_cls: ClassVar[type[MocketBaseResponse]] + + def __init__( + self, + address: Address, + responses: Sequence[MocketBaseResponse | Exception], + ) -> None: + self._address = address + self._responses = responses + self._served_response = False + self._current_response_index = 0 + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(address={self.address})" + + @property + def address(self) -> Address: + return self._address + + @property + def responses(self) -> Sequence[MocketBaseResponse | Exception]: + return self._responses + + @property + def served_response(self) -> bool: + return self._served_response + + def can_handle(self, data: bytes) -> bool: + return True + + def collect(self, data: bytes) -> bool: + request = self.request_cls.from_data(data) + Mocket.collect(request) + return True + + def get_response(self) -> bytes: + response = self._responses[self._current_response_index] + + self._served_response = True + + self._current_response_index = min( + self._current_response_index + 1, + len(self._responses) - 1, + ) + + if isinstance(response, BaseException): + raise response + + return response.data diff --git a/mocket/core/exceptions.py b/mocket/core/exceptions.py new file mode 100644 index 00000000..f5537568 --- /dev/null +++ b/mocket/core/exceptions.py @@ -0,0 +1,6 @@ +class MocketException(Exception): + pass + + +class StrictMocketException(MocketException): + pass diff --git a/mocket/inject.py b/mocket/core/inject.py similarity index 92% rename from mocket/inject.py rename to mocket/core/inject.py index 866ee563..0a9edbbb 100644 --- a/mocket/inject.py +++ b/mocket/core/inject.py @@ -23,7 +23,7 @@ def _restore(module: ModuleType, name: str) -> None: def enable() -> None: - from mocket.socket import ( + from mocket.core.socket import ( MocketSocket, mock_create_connection, mock_getaddrinfo, @@ -32,11 +32,11 @@ def enable() -> None: mock_inet_pton, mock_socketpair, ) - from mocket.ssl.context import MocketSSLContext, mock_wrap_socket - from mocket.urllib3 import ( + from mocket.core.ssl.context import MocketSSLContext, mock_wrap_socket + from mocket.core.urllib3 import ( mock_match_hostname as mock_urllib3_match_hostname, ) - from mocket.urllib3 import ( + from mocket.core.urllib3 import ( mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket, ) diff --git a/mocket/core/io.py b/mocket/core/io.py new file mode 100644 index 00000000..f9be52b4 --- /dev/null +++ b/mocket/core/io.py @@ -0,0 +1,22 @@ +import io +import os + +from typing_extensions import Buffer + +from mocket.core.mocket import Mocket +from mocket.core.types import Address + + +class MocketSocketIO(io.BytesIO): + def __init__(self, address: Address) -> None: + self._address = address + super().__init__() + + def write(self, content: Buffer) -> int: + bytes_written = super().write(content) + + _, w_fd = Mocket.get_pair(self._address) + if w_fd: + return os.write(w_fd, content) + + return bytes_written diff --git a/mocket/core/mocket.py b/mocket/core/mocket.py new file mode 100644 index 00000000..7352837b --- /dev/null +++ b/mocket/core/mocket.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import collections +import itertools +import os +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +import mocket.core.inject +from mocket.core.recording import MocketRecordStorage + +if TYPE_CHECKING: + from mocket.core.entry import MocketBaseEntry, MocketBaseRequest + from mocket.core.types import Address + + +class Mocket: + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address] = (None, None) + _entries: ClassVar[dict[Address, list[MocketBaseEntry]]] = collections.defaultdict( + list + ) + _requests: ClassVar[list[MocketBaseRequest]] = [] + _last_entry: ClassVar[MocketBaseEntry | None] = None + _record_storage: ClassVar[MocketRecordStorage | None] = None + + @classmethod + def enable( + cls, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + ) -> None: + if namespace is None: + namespace = str(id(cls._entries)) + + if truesocket_recording_dir is not None: + recording_dir = Path(truesocket_recording_dir) + + if not recording_dir.is_dir(): + # JSON dumps will be saved here + raise AssertionError + + cls._record_storage = MocketRecordStorage( + directory=recording_dir, + namespace=namespace, + ) + + mocket.core.inject.enable() + + @classmethod + def disable(cls) -> None: + cls.reset() + + mocket.core.inject.disable() + + @classmethod + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: + """ + Given the id() of the caller, return a pair of file descriptors + as a tuple of two integers: (, ) + """ + return cls._socket_pairs.get(address, (None, None)) + + @classmethod + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: + """ + Store a pair of file descriptors under the key `id_` + as a tuple of two integers: (, ) + """ + cls._socket_pairs[address] = pair + + @classmethod + def register(cls, *entries: MocketBaseEntry) -> None: + for entry in entries: + cls._entries[entry.address].append(entry) + + @classmethod + def get_entry(cls, host: str, port: int, data) -> MocketBaseEntry | None: + host = host or cls._address[0] + port = port or cls._address[1] + entries = cls._entries.get((host, port), []) + for entry in entries: + if entry.can_handle(data): + return entry + return None + + @classmethod + def collect(cls, data) -> None: + cls._requests.append(data) + + @classmethod + def reset(cls) -> None: + for r_fd, w_fd in cls._socket_pairs.values(): + os.close(r_fd) + os.close(w_fd) + cls._socket_pairs = {} + cls._entries = collections.defaultdict(list) + cls._requests = [] + cls._record_storage = None + + @classmethod + def last_request(cls) -> MocketBaseRequest | None: + if cls.has_requests(): + return cls._requests[-1] + return None + + @classmethod + def request_list(cls) -> list[MocketBaseRequest]: + return cls._requests + + @classmethod + def remove_last_request(cls) -> None: + if cls.has_requests(): + del cls._requests[-1] + + @classmethod + def has_requests(cls) -> bool: + return bool(cls.request_list()) + + @classmethod + def get_namespace(cls) -> str | None: + if not cls._record_storage: + return None + return cls._record_storage.namespace + + @classmethod + def get_truesocket_recording_dir(cls) -> str | None: + if not cls._record_storage: + return None + return str(cls._record_storage.directory) + + @classmethod + def assert_fail_if_entries_not_served(cls) -> None: + """Mocket checks that all entries have been served at least once.""" + if not all( + entry.served_response for entry in itertools.chain(*cls._entries.values()) + ): + raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/core/mocketizer.py b/mocket/core/mocketizer.py new file mode 100644 index 00000000..a70a1d89 --- /dev/null +++ b/mocket/core/mocketizer.py @@ -0,0 +1,95 @@ +from mocket.core.mocket import Mocket +from mocket.core.mode import MocketMode +from mocket.core.utils import get_mocketize + + +class Mocketizer: + def __init__( + self, + instance=None, + namespace=None, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + ): + self.instance = instance + self.truesocket_recording_dir = truesocket_recording_dir + self.namespace = namespace or str(id(self)) + MocketMode().STRICT = strict_mode + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) + + def enter(self): + Mocket.enable( + namespace=self.namespace, + truesocket_recording_dir=self.truesocket_recording_dir, + ) + if self.instance: + self.check_and_call("mocketize_setup") + + def __enter__(self): + self.enter() + return self + + def exit(self): + if self.instance: + self.check_and_call("mocketize_teardown") + + Mocket.disable() + + def __exit__(self, type, value, tb): + self.exit() + + async def __aenter__(self, *args, **kwargs): + self.enter() + return self + + async def __aexit__(self, *args, **kwargs): + self.exit() + + def check_and_call(self, method_name): + method = getattr(self.instance, method_name, None) + if callable(method): + method() + + @staticmethod + def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + instance = args[0] if args else None + namespace = None + if truesocket_recording_dir: + namespace = ".".join( + ( + instance.__class__.__module__, + instance.__class__.__name__, + test.__name__, + ) + ) + + return Mocketizer( + instance, + namespace=namespace, + truesocket_recording_dir=truesocket_recording_dir, + strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, + ) + + +def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return test(*args, **kwargs) + + +mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/core/mode.py b/mocket/core/mode.py new file mode 100644 index 00000000..a4d6df69 --- /dev/null +++ b/mocket/core/mode.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from mocket.core.exceptions import StrictMocketException +from mocket.core.mocket import Mocket + +if TYPE_CHECKING: # pragma: no cover + from typing import NoReturn + + +class MocketMode: + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None + + def __init__(self) -> None: + self.__dict__ = self.__shared_state + + def is_allowed(self, location: str | tuple[str, int]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locations to perform real `socket` calls + """ + if not self.STRICT: + return True + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed() -> NoReturn: + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/mocket/recording.py b/mocket/core/recording.py similarity index 96% rename from mocket/recording.py rename to mocket/core/recording.py index 97d2adbe..6d1b7289 100644 --- a/mocket/recording.py +++ b/mocket/core/recording.py @@ -7,9 +7,9 @@ from dataclasses import dataclass from pathlib import Path -from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.types import Address -from mocket.utils import hexdump, hexload +from mocket.core.compat import decode_from_bytes, encode_to_bytes +from mocket.core.types import Address +from mocket.core.utils import hexdump, hexload hash_function = hashlib.md5 diff --git a/mocket/socket.py b/mocket/core/socket.py similarity index 95% rename from mocket/socket.py rename to mocket/core/socket.py index 3b1862e2..ee03e573 100644 --- a/mocket/socket.py +++ b/mocket/core/socket.py @@ -10,11 +10,11 @@ from typing_extensions import Self -from mocket.entry import MocketEntry -from mocket.io import MocketSocketIO -from mocket.mocket import Mocket -from mocket.mode import MocketMode -from mocket.types import ( +from mocket.core.entry import MocketBaseEntry +from mocket.core.io import MocketSocketIO +from mocket.core.mocket import Mocket +from mocket.core.mode import MocketMode +from mocket.core.types import ( Address, ReadableBuffer, WriteableBuffer, @@ -167,7 +167,7 @@ def connect(self, address: Address) -> None: def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: return self.io - def get_entry(self, data: bytes) -> MocketEntry | None: + def get_entry(self, data: bytes) -> MocketBaseEntry | None: return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -271,8 +271,8 @@ def send( self.sendall(data, *args, **kwargs) else: req = Mocket.last_request() - if hasattr(req, "add_data"): - req.add_data(data) + if hasattr(req, "_add_data"): + req._add_data(data) self._entry = entry return len(data) diff --git a/mocket/core/ssl/__init__.py b/mocket/core/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl/context.py b/mocket/core/ssl/context.py similarity index 95% rename from mocket/ssl/context.py rename to mocket/core/ssl/context.py index 6d5e7307..161fc574 100644 --- a/mocket/ssl/context.py +++ b/mocket/core/ssl/context.py @@ -2,8 +2,8 @@ from typing import Any -from mocket.socket import MocketSocket -from mocket.ssl.socket import MocketSSLSocket +from mocket.core.socket import MocketSocket +from mocket.core.ssl.socket import MocketSSLSocket class _MocketSSLContext: diff --git a/mocket/ssl/socket.py b/mocket/core/ssl/socket.py similarity index 93% rename from mocket/ssl/socket.py rename to mocket/core/ssl/socket.py index 6dcd7817..aee2be65 100644 --- a/mocket/ssl/socket.py +++ b/mocket/core/ssl/socket.py @@ -5,10 +5,10 @@ from ssl import Options from typing import Any -from mocket.compat import encode_to_bytes -from mocket.mocket import Mocket -from mocket.socket import MocketSocket -from mocket.types import _PeerCertRetDictType +from mocket.core.compat import encode_to_bytes +from mocket.core.mocket import Mocket +from mocket.core.socket import MocketSocket +from mocket.core.types import _PeerCertRetDictType class MocketSSLSocket(MocketSocket): diff --git a/mocket/types.py b/mocket/core/types.py similarity index 100% rename from mocket/types.py rename to mocket/core/types.py diff --git a/mocket/urllib3.py b/mocket/core/urllib3.py similarity index 68% rename from mocket/urllib3.py rename to mocket/core/urllib3.py index e89bc7b5..eebc982e 100644 --- a/mocket/urllib3.py +++ b/mocket/core/urllib3.py @@ -2,9 +2,9 @@ from typing import Any -from mocket.socket import MocketSocket -from mocket.ssl.context import MocketSSLContext -from mocket.ssl.socket import MocketSSLSocket +from mocket.core.socket import MocketSocket +from mocket.core.ssl.context import MocketSSLContext +from mocket.core.ssl.socket import MocketSSLSocket def mock_match_hostname(*args: Any) -> None: diff --git a/mocket/utils.py b/mocket/core/utils.py similarity index 94% rename from mocket/utils.py rename to mocket/core/utils.py index ab293776..1d6d61bd 100644 --- a/mocket/utils.py +++ b/mocket/core/utils.py @@ -3,7 +3,7 @@ import binascii from typing import Callable -from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.core.compat import decode_from_bytes, encode_to_bytes def hexdump(binary_string: bytes) -> str: diff --git a/mocket/entry.py b/mocket/entry.py deleted file mode 100644 index 9dbbf442..00000000 --- a/mocket/entry.py +++ /dev/null @@ -1,58 +0,0 @@ -import collections.abc - -from mocket.compat import encode_to_bytes -from mocket.mocket import Mocket - - -class MocketEntry: - class Response(bytes): - @property - def data(self): - return self - - response_index = 0 - request_cls = bytes - response_cls = Response - responses = None - _served = None - - def __init__(self, location, responses): - self._served = False - self.location = location - - if not isinstance(responses, collections.abc.Iterable): - responses = [responses] - - if not responses: - self.responses = [self.response_cls(encode_to_bytes(""))] - else: - self.responses = [] - for r in responses: - if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, str): - r = encode_to_bytes(r) - r = self.response_cls(r) - self.responses.append(r) - - def __repr__(self): - return f"{self.__class__.__name__}(location={self.location})" - - @staticmethod - def can_handle(data): - return True - - def collect(self, data): - req = self.request_cls(data) - Mocket.collect(req) - - def get_response(self): - response = self.responses[self.response_index] - if self.response_index < len(self.responses) - 1: - self.response_index += 1 - - self._served = True - - if isinstance(response, BaseException): - raise response - - return response.data diff --git a/mocket/exceptions.py b/mocket/exceptions.py index f5537568..589339ea 100644 --- a/mocket/exceptions.py +++ b/mocket/exceptions.py @@ -1,6 +1,7 @@ -class MocketException(Exception): - pass +from mocket.core.exceptions import MocketException, StrictMocketException - -class StrictMocketException(MocketException): - pass +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "MocketException", + "StrictMocketException", +] diff --git a/mocket/http.py b/mocket/http.py new file mode 100644 index 00000000..d823de33 --- /dev/null +++ b/mocket/http.py @@ -0,0 +1,395 @@ +from __future__ import annotations + +import contextlib +import time +from enum import Enum +from http.server import BaseHTTPRequestHandler +from io import BufferedReader +from typing import ClassVar, Sequence +from urllib.parse import parse_qs, unquote, urlsplit + +import h11 +from typing_extensions import Self + +from mocket.core.compat import ENCODING, do_the_magic +from mocket.core.entry import MocketBaseEntry, MocketBaseRequest, MocketBaseResponse +from mocket.core.mocket import Mocket + +STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} +CRLF = "\r\n" +ASCII = "ascii" + + +class MocketHttpMethod(str, Enum): + CONNECT = "CONNECT" + DELETE = "DELETE" + GET = "GET" + HEAD = "HEAD" + OPTIONS = "OPTIONS" + PATCH = "PATCH" + POST = "POST" + PUT = "PUT" + TRACE = "TRACE" + + +class MocketHttpRequest(MocketBaseRequest): + def __init__(self) -> None: + super().__init__() + + self._parser = h11.Connection(h11.SERVER) + + self._method: MocketHttpMethod | None = None + self._path: str | None = None + self._querystring: dict[str, list[str]] | None = None + self._headers: dict[str, str] | None = None + self._body: bytes | None = None + + self._has_start_line: bool = False + self._has_body: bool = False + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"method='{self.method.name if self.method else None}', " + f"path='{self.path}', " + f"headers={self.headers}" + ")" + ) + + @property + def data(self) -> bytes: + return bytes(self._parser._receive_buffer) + + @property + def has_start_line(self) -> bool: + return self._has_start_line + + @property + def has_body(self) -> bool: + return self._has_body + + @property + def method(self) -> MocketHttpMethod | None: + return self._method + + @property + def path(self) -> str | None: + return self._path + + @property + def querystring(self) -> dict[str, list[str]] | None: + return self._querystring + + @property + def headers(self) -> dict[str, str] | None: + return self._headers + + @property + def body(self) -> bytes | None: + return self._body + + def _add_data(self, data: bytes) -> None: + self._parser.receive_data(data) + while True: + event = self._parser.next_event() + if isinstance(event, h11.Request): + self._set_h11_request(event) + elif isinstance(event, h11.Data): + self._set_h11_data(event) + else: + return + + def _set_h11_request(self, request: h11.Request) -> None: + self._has_start_line = True + self._method = MocketHttpMethod(request.method.decode(ASCII)) + self._path = request.target.decode(ASCII) + self._querystring = self._parse_querystring(self._path) + self._headers = {k.decode(ASCII): v.decode(ASCII) for k, v in request.headers} + + def _set_h11_data(self, data: h11.Data) -> None: + self._has_body = True + self._body = data.data + + @staticmethod + def _parse_querystring(path: str) -> dict[str, list[str]]: + parts = path.split("?", 1) + return ( + parse_qs(unquote(parts[1]), keep_blank_values=True) + if len(parts) == 2 + else {} + ) + + @classmethod + def from_data(cls: type[Self], data: bytes) -> Self: + request = cls() + request._add_data(data) + return request + + +class MocketHttpResponse(MocketBaseResponse): + server: ClassVar[str] = "Python/Mocket" + protocol: ClassVar[str] = "HTTP/1.1" + + def __init__( + self, + status_code: int = 200, + headers: dict[str, str] | None = None, + body: str | bytes | BufferedReader = b"", + ): + body_from_file = False + if isinstance(body, str): + body = body.encode() + elif isinstance(body, BufferedReader): + # File Objects + body = body.read() + body_from_file = True + + self._status_code = status_code + self._body = body + self._headers: dict[str, str] = {} + + base_headers = self._get_base_headers( + status_code=status_code, + body=body, + body_from_file=body_from_file, + ) + + self.set_headers(base_headers) + self.add_headers(headers or {}) + + super().__init__() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"status_code={self.status_code}, " + f"headers={self.headers}, " + f"body={self.body!r}" + ")" + ) + + @property + def data(self) -> bytes: + return self._get_http_message( + status_code=self._status_code, + headers=self._headers, + body=self._body, + ) + + @property + def status_code(self) -> int: + return self._status_code + + @property + def headers(self) -> dict[str, str]: + return self._headers + + @property + def body(self) -> bytes: + return self._body + + def set_headers(self, headers: dict[str, str]) -> None: + self._headers = {} + self.add_headers(headers) + + def add_headers(self, headers: dict[str, str]) -> None: + for k, v in headers.items(): + formatted_key = self._format_header_key(k) + self._headers[formatted_key] = v + + def set_extra_headers(self, headers: dict[str, str]) -> None: + r""" + >>> from mocket.core.utils import encode_to_bytes + >>> r = MocketHttpResponse(body="") + >>> len(r.headers.keys()) + 6 + >>> r.set_extra_headers({"foo-bar": "Foobar"}) + >>> len(r.headers.keys()) + 7 + >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar") + True + """ + self.add_headers(headers) + + @classmethod + def _get_base_headers( + cls, + status_code: int, + body: bytes, + body_from_file: bool, + ) -> dict[str, str]: + if body_from_file: + content_type = do_the_magic(body) + else: + content_type = f"text/plain; charset={ENCODING}" + + return { + "Status": str(status_code), + "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), + "Server": cls.server, + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": content_type, + } + + @classmethod + def _format_header_key(cls, key: str) -> str: + return "-".join(token.capitalize() for token in key.split("-")) + + @classmethod + def _get_http_message( + cls, + status_code: int, + headers: dict[str, str], + body: bytes, + ) -> bytes: + protocol = cls.protocol + status_text = STATUS[status_code] + status_line = f"{protocol} {status_code} {status_text}" + header_lines = [f"{k}: {v}" for k, v in headers.items()] + head_lines = [status_line] + header_lines + [CRLF] + head = CRLF.join(head_lines).encode(ENCODING) + return head + body + + +class MocketHttpEntry(MocketBaseEntry): + request_cls = MocketHttpRequest + response_cls = MocketHttpResponse + + def __init__( + self, + method: MocketHttpMethod, + uri: str, + responses: Sequence[MocketHttpResponse | Exception], + match_querystring: bool = True, + add_trailing_slash: bool = True, + ) -> None: + uri_split = urlsplit(uri) + + host = uri_split.hostname or "" + port = uri_split.port or (443 if uri_split.scheme == "https" else 80) + + responses = responses or [self.response_cls()] + + self._method = method + self._scheme = uri_split.scheme + self._path = uri_split.path or ("/" if add_trailing_slash else "") + # TODO should this be query-string and be parsed as in request? + self._query = uri_split.query + self._match_querystring = match_querystring + self._sent_data = b"" + + super().__init__(address=(host, port), responses=responses) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"address={self.address}, " + f"method='{self.method}', " + f"scheme='{self.scheme}', " + f"path='{self.path}', " + f"query='{self.query}'" + ")" + ) + + @property + def method(self) -> MocketHttpMethod: + return self._method + + @property + def scheme(self) -> str: + return self._scheme + + @property + def path(self) -> str: + return self._path + + @property + def query(self) -> str: + return self._query + + def can_handle(self, data: bytes) -> bool: + request = None + with contextlib.suppress(h11.RemoteProtocolError): + # add a CRLF so that this _could_ be considered a complete http-head + request = self.request_cls.from_data(data + CRLF.encode()) + + if request is None or not request.has_start_line: + return self is getattr(Mocket, "_last_entry", None) + + uri = urlsplit(request.path) + path_match = uri.path == self._path + method_match = request.method == self._method + query_match = True + + if self._match_querystring: + self_querystring = parse_qs(self._query, keep_blank_values=True) + query_match = request.querystring == self_querystring + + can_handle = path_match and method_match and query_match + if can_handle: + Mocket._last_entry = self + return can_handle + + # TODO dunno if i like this method here + def collect(self, data: bytes) -> bool: + consume_response = True + + methods = tuple([n.value.encode() for n in MocketHttpMethod]) + if data.startswith(methods): + self._sent_data = data + else: + Mocket.remove_last_request() + self._sent_data += data + consume_response = False + + request = self.request_cls.from_data(self._sent_data) + Mocket.collect(request) + + return consume_response + + @classmethod + def register_response( + cls, + method: MocketHttpMethod, + uri: str, + body: str | bytes | BufferedReader = b"", + status_code: int = 200, + headers: dict[str, str] | None = None, + match_querystring: bool = True, + exception: Exception | None = None, + ) -> None: + response: MocketHttpResponse | Exception + if exception is not None: + response = exception + else: + response = MocketHttpResponse( + body=body, + status_code=status_code, + headers=headers, + ) + + cls.register_responses( + method=method, + uri=uri, + responses=[response], + match_querystring=match_querystring, + ) + + @classmethod + def register_responses( + cls, + method: MocketHttpMethod, + uri: str, + responses: Sequence[MocketHttpResponse | Exception], + match_querystring: bool = True, + add_trailing_slash: bool = True, + ) -> None: + entry = cls( + method=method, + uri=uri, + responses=responses, + match_querystring=match_querystring, + add_trailing_slash=add_trailing_slash, + ) + Mocket.register(entry) diff --git a/mocket/io.py b/mocket/io.py deleted file mode 100644 index 0334410b..00000000 --- a/mocket/io.py +++ /dev/null @@ -1,17 +0,0 @@ -import io -import os - -from mocket.mocket import Mocket - - -class MocketSocketIO(io.BytesIO): - def __init__(self, address) -> None: - self._address = address - super().__init__() - - def write(self, content): - super().write(content) - - _, w_fd = Mocket.get_pair(self._address) - if w_fd: - os.write(w_fd, content) diff --git a/mocket/mocket.py b/mocket/mocket.py index a01a7b46..8b72b52b 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,135 +1,6 @@ -from __future__ import annotations - -import collections -import itertools -import os -from pathlib import Path -from typing import TYPE_CHECKING, ClassVar - -import mocket.inject -from mocket.recording import MocketRecordStorage +from mocket.core.mocket import Mocket # NOTE this is here for backwards-compat to keep old import-paths working -# from mocket.socket import MocketSocket as MocketSocket - -if TYPE_CHECKING: - from mocket.entry import MocketEntry - from mocket.types import Address - - -class Mocket: - _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} - _address: ClassVar[Address] = (None, None) - _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) - _requests: ClassVar[list] = [] - _record_storage: ClassVar[MocketRecordStorage | None] = None - - @classmethod - def enable( - cls, - namespace: str | None = None, - truesocket_recording_dir: str | None = None, - ) -> None: - if namespace is None: - namespace = str(id(cls._entries)) - - if truesocket_recording_dir is not None: - recording_dir = Path(truesocket_recording_dir) - - if not recording_dir.is_dir(): - # JSON dumps will be saved here - raise AssertionError - - cls._record_storage = MocketRecordStorage( - directory=recording_dir, - namespace=namespace, - ) - - mocket.inject.enable() - - @classmethod - def disable(cls) -> None: - cls.reset() - - mocket.inject.disable() - - @classmethod - def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: - """ - Given the id() of the caller, return a pair of file descriptors - as a tuple of two integers: (, ) - """ - return cls._socket_pairs.get(address, (None, None)) - - @classmethod - def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: - """ - Store a pair of file descriptors under the key `id_` - as a tuple of two integers: (, ) - """ - cls._socket_pairs[address] = pair - - @classmethod - def register(cls, *entries: MocketEntry) -> None: - for entry in entries: - cls._entries[entry.location].append(entry) - - @classmethod - def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: - host = host or cls._address[0] - port = port or cls._address[1] - entries = cls._entries.get((host, port), []) - for entry in entries: - if entry.can_handle(data): - return entry - return None - - @classmethod - def collect(cls, data) -> None: - cls._requests.append(data) - - @classmethod - def reset(cls) -> None: - for r_fd, w_fd in cls._socket_pairs.values(): - os.close(r_fd) - os.close(w_fd) - cls._socket_pairs = {} - cls._entries = collections.defaultdict(list) - cls._requests = [] - cls._record_storage = None - - @classmethod - def last_request(cls): - if cls.has_requests(): - return cls._requests[-1] - - @classmethod - def request_list(cls): - return cls._requests - - @classmethod - def remove_last_request(cls) -> None: - if cls.has_requests(): - del cls._requests[-1] - - @classmethod - def has_requests(cls) -> bool: - return bool(cls.request_list()) - - @classmethod - def get_namespace(cls) -> str | None: - if not cls._record_storage: - return None - return cls._record_storage.namespace - - @classmethod - def get_truesocket_recording_dir(cls) -> str | None: - if not cls._record_storage: - return None - return str(cls._record_storage.directory) - - @classmethod - def assert_fail_if_entries_not_served(cls) -> None: - """Mocket checks that all entries have been served at least once.""" - if not all(entry._served for entry in itertools.chain(*cls._entries.values())): - raise AssertionError("Some Mocket entries have not been served") +__all__ = [ + "Mocket", +] diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py index 2bf2b9cd..3c1fbf5e 100644 --- a/mocket/mocketizer.py +++ b/mocket/mocketizer.py @@ -1,95 +1,7 @@ -from mocket.mocket import Mocket -from mocket.mode import MocketMode -from mocket.utils import get_mocketize +from mocket.core.mocketizer import Mocketizer, mocketize - -class Mocketizer: - def __init__( - self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): - self.instance = instance - self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or str(id(self)) - MocketMode().STRICT = strict_mode - if strict_mode: - MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] - elif strict_mode_allowed: - raise ValueError( - "Allowed locations are only accepted when STRICT mode is active." - ) - - def enter(self): - Mocket.enable( - namespace=self.namespace, - truesocket_recording_dir=self.truesocket_recording_dir, - ) - if self.instance: - self.check_and_call("mocketize_setup") - - def __enter__(self): - self.enter() - return self - - def exit(self): - if self.instance: - self.check_and_call("mocketize_teardown") - - Mocket.disable() - - def __exit__(self, type, value, tb): - self.exit() - - async def __aenter__(self, *args, **kwargs): - self.enter() - return self - - async def __aexit__(self, *args, **kwargs): - self.exit() - - def check_and_call(self, method_name): - method = getattr(self.instance, method_name, None) - if callable(method): - method() - - @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): - instance = args[0] if args else None - namespace = None - if truesocket_recording_dir: - namespace = ".".join( - ( - instance.__class__.__module__, - instance.__class__.__name__, - test.__name__, - ) - ) - - return Mocketizer( - instance, - namespace=namespace, - truesocket_recording_dir=truesocket_recording_dir, - strict_mode=strict_mode, - strict_mode_allowed=strict_mode_allowed, - ) - - -def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return test(*args, **kwargs) - - -mocketize = get_mocketize(wrapper_=wrapper) +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "Mocketizer", + "mocketize", +] diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index 245a11af..b40721fc 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -1,263 +1,8 @@ -import re -import time -from functools import cached_property -from http.server import BaseHTTPRequestHandler -from urllib.parse import parse_qs, unquote, urlsplit - -from h11 import SERVER, Connection, Data -from h11 import Request as H11Request - -from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from mocket.entry import MocketEntry -from mocket.mocket import Mocket - -STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} -CRLF = "\r\n" -ASCII = "ascii" - - -class Request: - _parser = None - _event = None - - def __init__(self, data): - self._parser = Connection(SERVER) - self.add_data(data) - - def add_data(self, data): - self._parser.receive_data(data) - - @property - def event(self): - if not self._event: - self._event = self._parser.next_event() - return self._event - - @cached_property - def method(self): - return self.event.method.decode(ASCII) - - @cached_property - def path(self): - return self.event.target.decode(ASCII) - - @cached_property - def headers(self): - return {k.decode(ASCII): v.decode(ASCII) for k, v in self.event.headers} - - @cached_property - def querystring(self): - parts = self.path.split("?", 1) - return ( - parse_qs(unquote(parts[1]), keep_blank_values=True) - if len(parts) == 2 - else {} - ) - - @cached_property - def body(self): - while True: - event = self._parser.next_event() - if isinstance(event, H11Request): - self._event = event - elif isinstance(event, Data): - return event.data.decode(ENCODING) - - def __str__(self): - return f"{self.method} - {self.path} - {self.headers}" - - -class Response: - headers = None - is_file_object = False - - def __init__(self, body="", status=200, headers=None): - headers = headers or {} - try: - # File Objects - self.body = body.read() - self.is_file_object = True - except AttributeError: - self.body = encode_to_bytes(body) - self.status = status - - self.set_base_headers() - - if headers is not None: - self.set_extra_headers(headers) - - self.data = self.get_protocol_data() + self.body - - def get_protocol_data(self, str_format_fun_name="capitalize"): - status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}" - header_lines = CRLF.join( - ( - f"{getattr(k, str_format_fun_name)()}: {v}" - for k, v in self.headers.items() - ) - ) - return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING) - - def set_base_headers(self): - self.headers = { - "Status": str(self.status), - "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), - "Server": "Python/Mocket", - "Connection": "close", - "Content-Length": str(len(self.body)), - } - if not self.is_file_object: - self.headers["Content-Type"] = f"text/plain; charset={ENCODING}" - else: - self.headers["Content-Type"] = do_the_magic(self.body) - - def set_extra_headers(self, headers): - r""" - >>> r = Response(body="") - >>> len(r.headers.keys()) - 6 - >>> r.set_extra_headers({"foo-bar": "Foobar"}) - >>> len(r.headers.keys()) - 7 - >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar") - True - """ - for k, v in headers.items(): - self.headers["-".join(token.capitalize() for token in k.split("-"))] = v - - -class Entry(MocketEntry): - CONNECT = "CONNECT" - DELETE = "DELETE" - GET = "GET" - HEAD = "HEAD" - OPTIONS = "OPTIONS" - PATCH = "PATCH" - POST = "POST" - PUT = "PUT" - TRACE = "TRACE" - - METHODS = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) - - request_cls = Request - response_cls = Response - - def __init__(self, uri, method, responses, match_querystring=True): - uri = urlsplit(uri) - - port = uri.port - if not port: - port = 443 if uri.scheme == "https" else 80 - - super().__init__((uri.hostname, port), responses) - self.schema = uri.scheme - self.path = uri.path - self.query = uri.query - self.method = method.upper() - self._sent_data = b"" - self._match_querystring = match_querystring - - def __repr__(self): - return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})" - - def collect(self, data): - consume_response = True - - decoded_data = decode_from_bytes(data) - if not decoded_data.startswith(Entry.METHODS): - Mocket.remove_last_request() - self._sent_data += data - consume_response = False - else: - self._sent_data = data - - super().collect(self._sent_data) - - return consume_response - - def can_handle(self, data): - r""" - >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) - >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') - False - >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) - >>> e.can_handle(b'GET /?bar=foo&foobar HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') - True - """ - try: - requestline, _ = decode_from_bytes(data).split(CRLF, 1) - method, path, _ = self._parse_requestline(requestline) - except ValueError: - return self is getattr(Mocket, "_last_entry", None) - - uri = urlsplit(path) - can_handle = uri.path == self.path and method == self.method - if self._match_querystring: - kw = dict(keep_blank_values=True) - can_handle = can_handle and parse_qs(uri.query, **kw) == parse_qs( - self.query, **kw - ) - if can_handle: - Mocket._last_entry = self - return can_handle - - @staticmethod - def _parse_requestline(line): - """ - http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 - - >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0') - True - >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1') - True - >>> Entry._parse_requestline('Im not a RequestLine') - Traceback (most recent call last): - ... - ValueError: Not a Request-Line - """ - m = re.match( - r"({})\s+(.*)\s+HTTP/(1.[0|1])".format("|".join(Entry.METHODS)), line, re.I - ) - if m: - return m.group(1).upper(), m.group(2), m.group(3) - raise ValueError("Not a Request-Line") - - @classmethod - def register(cls, method, uri, *responses, **config): - if "body" in config or "status" in config: - raise AttributeError("Did you mean `Entry.single_register(...)`?") - - default_config = dict(match_querystring=True, add_trailing_slash=True) - default_config.update(config) - config = default_config - - if config["add_trailing_slash"] and not urlsplit(uri).path: - uri += "/" - - Mocket.register( - cls(uri, method, responses, match_querystring=config["match_querystring"]) - ) - - @classmethod - def single_register( - cls, - method, - uri, - body="", - status=200, - headers=None, - match_querystring=True, - exception=None, - ): - response = ( - exception - if exception - else cls.response_cls(body=body, status=status, headers=headers) - ) - - cls.register( - method, - uri, - response, - match_querystring=match_querystring, - ) +from mocket.compat.mockhttp import Entry, Request, Response + +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "Entry", + "Request", + "Response", +] diff --git a/mocket/mockredis.py b/mocket/mockredis.py index fc386e2d..6c71f116 100644 --- a/mocket/mockredis.py +++ b/mocket/mockredis.py @@ -1,91 +1,15 @@ -from itertools import chain - -from mocket.compat import ( - decode_from_bytes, - encode_to_bytes, - shsplit, -) -from mocket.entry import MocketEntry -from mocket.mocket import Mocket - - -class Request: - def __init__(self, data): - self.data = data - - -class Response: - def __init__(self, data=None): - self.data = Redisizer.redisize(data or OK) - - -class Redisizer(bytes): - @staticmethod - def tokens(iterable): - iterable = [encode_to_bytes(x) for x in iterable] - return [f"*{len(iterable)}".encode()] + list( - chain(*zip([f"${len(x)}".encode() for x in iterable], iterable)) - ) - - @staticmethod - def redisize(data): - def get_conversion(t): - return { - dict: lambda x: b"\r\n".join( - Redisizer.tokens(list(chain(*tuple(x.items())))) - ), - int: lambda x: f":{x}".encode(), - str: lambda x: "${}\r\n{}".format(len(x.encode("utf-8")), x).encode( - "utf-8" - ), - list: lambda x: b"\r\n".join(Redisizer.tokens(x)), - }[t] - - if isinstance(data, Redisizer): - return data - if isinstance(data, bytes): - data = decode_from_bytes(data) - return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") - - @staticmethod - def command(description, _type="+"): - return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8")) - - @staticmethod - def error(description): - return Redisizer.command(description, _type="-") - - -OK = Redisizer.command("OK") -QUEUED = Redisizer.command("QUEUED") -ERROR = Redisizer.error - - -class Entry(MocketEntry): - request_cls = Request - response_cls = Response - - def __init__(self, addr, command, responses): - super().__init__(addr or ("localhost", 6379), responses) - d = shsplit(command) - d[0] = d[0].upper() - self.command = Redisizer.tokens(d) - - def can_handle(self, data): - return data.splitlines() == self.command - - @classmethod - def register(cls, addr, command, *responses): - responses = [ - r if isinstance(r, BaseException) else cls.response_cls(r) - for r in responses - ] - Mocket.register(cls(addr, command, responses)) - - @classmethod - def register_response(cls, command, response, addr=None): - cls.register(addr, command, response) - - @classmethod - def register_responses(cls, command, responses, addr=None): - cls.register(addr, command, *responses) +from mocket.compat.mockredis import Entry +from mocket.redis import ERROR, OK, QUEUED, Redisizer +from mocket.redis import MocketRedisRequest as Request +from mocket.redis import MocketRedisResponse as Response + +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "ERROR", + "Entry", + "OK", + "QUEUED", + "Redisizer", + "Request", + "Response", +] diff --git a/mocket/mode.py b/mocket/mode.py index e1da7955..c609023a 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -1,45 +1,6 @@ -from __future__ import annotations +from mocket.core.mode import MocketMode -from typing import TYPE_CHECKING, Any, ClassVar - -from mocket.exceptions import StrictMocketException -from mocket.mocket import Mocket - -if TYPE_CHECKING: # pragma: no cover - from typing import NoReturn - - -class MocketMode: - __shared_state: ClassVar[dict[str, Any]] = {} - STRICT: ClassVar = None - STRICT_ALLOWED: ClassVar = None - - def __init__(self) -> None: - self.__dict__ = self.__shared_state - - def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ - Checks if (`host`, `port`) or at least `host` - are allowed locations to perform real `socket` calls - """ - if not self.STRICT: - return True - - host_allowed = False - if isinstance(location, tuple): - host_allowed = location[0] in self.STRICT_ALLOWED - return host_allowed or location in self.STRICT_ALLOWED - - @staticmethod - def raise_not_allowed() -> NoReturn: - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [f" {location}:\n {entries}" for location, entries in current_entries] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while STRICT mode was active.\n" - f"Registered entries:\n{formatted_entries}" - ) +# NOTE this is here for backwards-compat to keep old import-paths working +__all__ = [ + "MocketMode", +] diff --git a/mocket/plugins/httpretty.py b/mocket/plugins/httpretty.py new file mode 100644 index 00000000..90201c94 --- /dev/null +++ b/mocket/plugins/httpretty.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from mocket.core.async_mocket import async_mocketize +from mocket.core.mocket import Mocket +from mocket.core.mocketizer import mocketize +from mocket.http import ( + MocketHttpEntry, + MocketHttpMethod, + MocketHttpRequest, + MocketHttpResponse, +) + + +class MocketHttprettyResponse(MocketHttpResponse): + server = "Python/HTTPretty" + + def __init__( + self, + body: str | bytes = "", + status: int = 200, + headers: dict[str, str] | None = None, + ) -> None: + super().__init__( + status_code=status, + headers=headers, + body=body, + ) + + @property + def status(self) -> int: + return self.status_code + + @classmethod + def _format_header_key(cls, key: str) -> str: + return key.lower().replace("_", "-") + + +class MocketHttprettyEntry(MocketHttpEntry): + response_cls = MocketHttprettyResponse + + +class MocketHTTPretty: + Response = MocketHttprettyResponse + + CONNECT = MocketHttpMethod.CONNECT + DELETE = MocketHttpMethod.DELETE + GET = MocketHttpMethod.GET + HEAD = MocketHttpMethod.HEAD + OPTIONS = MocketHttpMethod.OPTIONS + PATCH = MocketHttpMethod.PATCH + POST = MocketHttpMethod.POST + PUT = MocketHttpMethod.PUT + TRACE = MocketHttpMethod.TRACE + + @property + def latest_requests(self) -> list[MocketHttpRequest]: + return Mocket.request_list() + + @property + def last_request(self) -> MocketHttpRequest: + return Mocket.last_request() + + def register_uri( + self, + method: MocketHttpMethod, + uri: str, + body: str | bytes = "HTTPretty :)", + adding_headers: dict[str, str] | None = None, + forcing_headers: dict[str, str] | None = None, + status: int = 200, + responses: list[MocketHttpResponse] | None = None, + match_querystring: bool = False, + priority: int = 0, + **headers: str, + ) -> None: + if adding_headers is not None: + headers.update(adding_headers) + + if responses is None: + response = MocketHttprettyResponse( + body=body, + status=status, + headers=headers, + ) + responses = [response] + + if forcing_headers is not None: + for r in responses: + r.set_headers(forcing_headers) + + MocketHttpEntry.register_responses( + method=method, + uri=uri, + responses=responses, + match_querystring=match_querystring, + ) + + +HTTPretty = MocketHTTPretty() +httpretty = HTTPretty + +Response = HTTPretty.Response + +CONNECT = HTTPretty.CONNECT +DELETE = HTTPretty.DELETE +GET = HTTPretty.GET +HEAD = HTTPretty.HEAD +OPTIONS = HTTPretty.OPTIONS +PATCH = HTTPretty.PATCH +POST = HTTPretty.POST +PUT = HTTPretty.PUT +TRACE = HTTPretty.TRACE + +activate = mocketize +httprettified = mocketize +async_httprettified = async_mocketize +register_uri = HTTPretty.register_uri + +enable = Mocket.enable +disable = Mocket.disable +reset = Mocket.reset + + +__all__ = [ + "HTTPretty", + "httpretty", + "activate", + "httprettified", + "async_httprettified", + "register_uri", + "enable", + "disable", + "reset", + "CONNECT", + "DELETE", + "GET", + "HEAD", + "OPTIONS", + "PATCH", + "POST", + "PUT", + "TRACE", + "Response", +] diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py deleted file mode 100644 index fac61840..00000000 --- a/mocket/plugins/httpretty/__init__.py +++ /dev/null @@ -1,135 +0,0 @@ -from mocket import mocketize -from mocket.async_mocket import async_mocketize -from mocket.compat import ENCODING -from mocket.mocket import Mocket -from mocket.mockhttp import Entry as MocketHttpEntry -from mocket.mockhttp import Request as MocketHttpRequest -from mocket.mockhttp import Response as MocketHttpResponse - - -def httprettifier_headers(headers): - return {k.lower().replace("_", "-"): v for k, v in headers.items()} - - -class Request(MocketHttpRequest): - @property - def body(self): - return super().body.encode(ENCODING) - - @property - def headers(self): - return httprettifier_headers(super().headers) - - -class Response(MocketHttpResponse): - def get_protocol_data(self, str_format_fun_name="lower"): - if "server" in self.headers and self.headers["server"] == "Python/Mocket": - self.headers["server"] = "Python/HTTPretty" - return super().get_protocol_data(str_format_fun_name=str_format_fun_name) - - def set_base_headers(self): - super().set_base_headers() - self.headers = httprettifier_headers(self.headers) - - original_set_base_headers = set_base_headers - - def set_extra_headers(self, headers): - self.headers.update(headers) - - -class Entry(MocketHttpEntry): - request_cls = Request - response_cls = Response - - -activate = mocketize -httprettified = mocketize -async_httprettified = async_mocketize - -enable = Mocket.enable -disable = Mocket.disable -reset = Mocket.reset - -GET = Entry.GET -PUT = Entry.PUT -POST = Entry.POST -DELETE = Entry.DELETE -HEAD = Entry.HEAD -PATCH = Entry.PATCH -OPTIONS = Entry.OPTIONS - - -def register_uri( - method, - uri, - body="HTTPretty :)", - adding_headers=None, - forcing_headers=None, - status=200, - responses=None, - match_querystring=False, - priority=0, - **headers, -): - headers = httprettifier_headers(headers) - - if adding_headers is not None: - headers.update(httprettifier_headers(adding_headers)) - - if forcing_headers is not None: - - def force_headers(self): - self.headers = httprettifier_headers(forcing_headers) - - Response.set_base_headers = force_headers - else: - Response.set_base_headers = Response.original_set_base_headers - - if responses: - Entry.register(method, uri, *responses) - else: - Entry.single_register( - method, - uri, - body=body, - status=status, - headers=headers, - match_querystring=match_querystring, - ) - - -class MocketHTTPretty: - Response = Response - - def __getattr__(self, name): - if name == "last_request": - return Mocket.last_request() - if name == "latest_requests": - return Mocket.request_list() - return getattr(Entry, name) - - -HTTPretty = MocketHTTPretty() -HTTPretty.register_uri = register_uri -httpretty = HTTPretty - -__all__ = ( - "HTTPretty", - "httpretty", - "activate", - "async_httprettified", - "httprettified", - "enable", - "disable", - "reset", - "Response", - "GET", - "PUT", - "POST", - "DELETE", - "HEAD", - "PATCH", - "register_uri", - "str", - "bytes", -) diff --git a/mocket/plugins/pook_mock_engine.py b/mocket/plugins/pook_mock_engine.py index 549f5509..d519fe35 100644 --- a/mocket/plugins/pook_mock_engine.py +++ b/mocket/plugins/pook_mock_engine.py @@ -1,83 +1,105 @@ +from __future__ import annotations + +from typing import Any, Sequence + +from mocket.core.mocket import Mocket +from mocket.http import MocketHttpEntry, MocketHttpMethod, MocketHttpResponse + try: - from pook.engine import MockEngine + from pook import Engine as PookEngine + from pook import Mock as PookMock + from pook import MockEngine as PookMockEngine + from pook import Request as PookRequest + from pook.interceptors.base import BaseInterceptor as PookBaseInterceptor except ModuleNotFoundError: - MockEngine = object + PookEngine = object + PookMock = object + PookMockEngine = object + PookRequest = object + PookBaseInterceptor = object -from mocket.mocket import Mocket -from mocket.mockhttp import Entry, Response - -class MocketPookEntry(Entry): +class MocketPookEntry(MocketHttpEntry): pook_request = None pook_engine = None - def can_handle(self, data): - can_handle = super().can_handle(data) - - if can_handle: - self.pook_engine.match(self.pook_request) - return can_handle - - @classmethod - def single_register( - cls, - method, - uri, - body="", - status=200, - headers=None, - match_querystring=True, - exception=None, - ): - entry = cls( - uri, - method, - [Response(body=body, status=status, headers=headers)], + def __init__( + self, + method: MocketHttpMethod, + uri: str, + responses: Sequence[MocketHttpResponse | Exception], + pook_engine: PookEngine, + pook_request: PookRequest, + match_querystring: bool = True, + add_trailing_slash: bool = True, + ) -> None: + super().__init__( + method=method, + uri=uri, + responses=responses, match_querystring=match_querystring, + add_trailing_slash=add_trailing_slash, ) - Mocket.register(entry) - return entry + self._pook_engine = pook_engine + self._pook_request = pook_request + def can_handle(self, data: bytes) -> bool: + can_handle = super().can_handle(data) -class MocketEngine(MockEngine): - def __init__(self, engine): - def mocket_mock_fun(*args, **kwargs): - mock = self.pook_mock_fun(*args, **kwargs) - - request = mock._request - method = request.method - url = request.rawurl - - response = mock._response - body = response._body - status = response._status - headers = response._headers - - entry = MocketPookEntry.single_register(method, url, body, status, headers) - entry.pook_engine = self.engine - entry.pook_request = request + if can_handle: + self._pook_engine.match(self._pook_request) + return can_handle - return mock - from pook.interceptors.base import BaseInterceptor +class MocketInterceptor(PookBaseInterceptor): # type: ignore[misc] + @staticmethod + def activate() -> None: + Mocket.disable() + Mocket.enable() - class MocketInterceptor(BaseInterceptor): - @staticmethod - def activate(): - Mocket.disable() - Mocket.enable() + @staticmethod + def disable() -> None: + Mocket.disable() - @staticmethod - def disable(): - Mocket.disable() +class MocketEngine(PookMockEngine): # type: ignore[misc] + def __init__(self, engine: PookEngine) -> None: # Store plugins engine self.engine = engine # Store HTTP client interceptors - self.interceptors = [] + self.interceptors: list[PookBaseInterceptor] = [] # Self-register MocketInterceptor self.add_interceptor(MocketInterceptor) # mocking pook.mock() self.pook_mock_fun = self.engine.mock - self.engine.mock = mocket_mock_fun + self.engine.mock = self.mocket_mock_fun + + def mocket_mock_fun(self, *args: Any, **kwargs: Any) -> PookMock: + mock = self.pook_mock_fun(*args, **kwargs) + + request = mock._request + method = request.method + url = request.rawurl + + response = mock._response + body = response._body + status = response._status + headers = response._headers + + entry = MocketPookEntry( + method=method, + uri=url, + responses=[ + MocketHttpResponse( + status_code=status, + headers=headers, + body=body, + ) + ], + pook_engine=self.engine, + pook_request=request, + ) + Mocket.register(entry) + + return mock diff --git a/mocket/redis.py b/mocket/redis.py new file mode 100644 index 00000000..48b3c864 --- /dev/null +++ b/mocket/redis.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from itertools import chain +from typing import Sequence + +from mocket.bytes import MocketBytesRequest, MocketBytesResponse +from mocket.core.compat import encode_to_bytes, shsplit +from mocket.core.entry import MocketBaseEntry +from mocket.core.mocket import Mocket +from mocket.core.types import Address + +CRLF = "\r\n" + + +class MocketRedisCommand(bytes): ... + + +class Redisizer(bytes): + @staticmethod + def tokens(iterable: Sequence[str | bytes]) -> list[bytes]: + _iterable = [encode_to_bytes(x) for x in iterable] + return [f"*{len(iterable)}".encode()] + list( + chain(*zip([f"${len(x)}".encode() for x in _iterable], _iterable)) + ) + + @staticmethod + def redisize( + data: str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | MocketRedisCommand, + ) -> bytes: + if isinstance(data, MocketRedisCommand): + return data + + if isinstance(data, bytes): + data = data.decode() + + if isinstance(data, str): + data_len = len(data.encode()) + data = f"${data_len}{CRLF}{data}".encode() + + elif isinstance(data, int): + data = f":{data}".encode() + + elif isinstance(data, list): + tokens = Redisizer.tokens(data) + data = CRLF.encode().join(tokens) + + elif isinstance(data, dict): + tokens = Redisizer.tokens(list(chain(*tuple(data.items())))) # type: ignore[arg-type] + data = CRLF.encode().join(tokens) + + return data + CRLF.encode() + + @staticmethod + def command(description: str, _type: str = "+") -> MocketRedisCommand: + return MocketRedisCommand(f"{_type}{description}{CRLF}".encode()) + + @staticmethod + def error(description: str) -> MocketRedisCommand: + return Redisizer.command(description, _type="-") + + +OK = Redisizer.command("OK") +QUEUED = Redisizer.command("QUEUED") +ERROR = Redisizer.error + + +class MocketRedisRequest(MocketBytesRequest): ... + + +class MocketRedisResponse(MocketBytesResponse): + def __init__( + self, + data: str + | bytes + | int + | list[str] + | list[bytes] + | dict[str, str] + | dict[bytes, bytes] + | MocketRedisCommand = OK, + ) -> None: + data = Redisizer.redisize(data) + super().__init__(data=data) + + +class MocketRedisEntry(MocketBaseEntry): + request_cls = MocketRedisRequest + response_cls = MocketRedisResponse + + def __init__( + self, + address: Address, + command: str | bytes, + responses: Sequence[MocketRedisResponse | Exception], + ) -> None: + self._command = command + self._command_tokens = MocketRedisEntry._tokenize_command(command) + + super().__init__(address=address, responses=responses) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"address={self.address}, " + f"command='{self.command!r}" + ")" + ) + + # TODO should this always be str? + @property + def command(self) -> str | bytes: + return self._command + + def can_handle(self, data: bytes) -> bool: + return data.splitlines() == self._command_tokens + + @staticmethod + def _tokenize_command(command: str | bytes) -> list[bytes]: + parts = shsplit(command) + parts[0] = parts[0].upper() + return Redisizer.tokens(parts) + + @classmethod + def register_response( + cls, + address: Address, + command: str | bytes, + response: MocketRedisResponse | Exception, + ) -> None: + entry = cls( + address=address, + command=command, + responses=[response], + ) + Mocket.register(entry) + + @classmethod + def register_responses( + cls, + address: Address, + command: str | bytes, + responses: Sequence[MocketRedisResponse | Exception], + ) -> None: + entry = cls( + address=address, + command=command, + responses=responses, + ) + Mocket.register(entry) diff --git a/pyproject.toml b/pyproject.toml index 77d1f5d4..f167bf54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,9 +118,9 @@ max-complexity = 8 [tool.mypy] python_version = "3.8" files = [ - "mocket/exceptions.py", - "mocket/compat.py", - "mocket/utils.py", + "mocket/core/compat.py", + "mocket/core/exceptions.py", + "mocket/core/utils.py", # "tests/" ] strict = true diff --git a/tests/test_bytes.py b/tests/test_bytes.py new file mode 100644 index 00000000..6ff26ed8 --- /dev/null +++ b/tests/test_bytes.py @@ -0,0 +1,75 @@ +import socket + +from mocket import ( + Mocket, + MocketBytesEntry, + MocketBytesRequest, + MocketBytesResponse, + mocketize, +) + + +@mocketize +def test_bytes_register_response() -> None: + # arrange + address = ("example.com", 5000) + + MocketBytesEntry.register_response( + address=address, + response=MocketBytesResponse(b"test-response"), + ) + + # act + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(address) + sock.sendall(b"test-request") + response_data = sock.recv(4096) + sock.close() + + # assert + assert response_data == b"test-response" + + requests = Mocket.request_list() + assert len(requests) == 1 + assert type(requests[0]) is MocketBytesRequest + assert requests[0].data == b"test-request" + + +@mocketize +def test_bytes_register_responses() -> None: + # arrange + address = ("example.com", 5000) + + MocketBytesEntry.register_responses( + address=address, + responses=[ + MocketBytesResponse(b"test-response-1"), + MocketBytesResponse(b"test-response-2"), + MocketBytesResponse(b"test-response-3"), + ], + ) + + # act + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(address) + sock.sendall(b"test-request-1") + response_data_1 = sock.recv(4096) + sock.sendall(b"test-request-2") + response_data_2 = sock.recv(4096) + sock.sendall(b"test-request-3") + response_data_3 = sock.recv(4096) + sock.close() + + # assert + assert response_data_1 == b"test-response-1" + assert response_data_2 == b"test-response-2" + assert response_data_3 == b"test-response-3" + + requests = Mocket.request_list() + assert len(requests) == 3 + assert type(requests[0]) is MocketBytesRequest + assert type(requests[1]) is MocketBytesRequest + assert type(requests[2]) is MocketBytesRequest + assert requests[0].data == b"test-request-1" + assert requests[1].data == b"test-request-2" + assert requests[2].data == b"test-request-3" diff --git a/tests/test_compat.py b/tests/test_compat.py index 49b62ec7..30db14b4 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,4 +1,4 @@ -from mocket.compat import do_the_magic +from mocket.core.compat import do_the_magic def test_unknown_binary(): diff --git a/tests/test_mocket.py b/tests/test_mocket.py index 8d09f170..53b372f9 100644 --- a/tests/test_mocket.py +++ b/tests/test_mocket.py @@ -9,7 +9,7 @@ import pytest from mocket import Mocket, MocketEntry, Mocketizer, mocketize -from mocket.compat import encode_to_bytes +from mocket.core.compat import encode_to_bytes class MocketTestCase(TestCase): diff --git a/tests/test_mode.py b/tests/test_mode.py index ea5905b0..a2f83d69 100644 --- a/tests/test_mode.py +++ b/tests/test_mode.py @@ -2,9 +2,9 @@ import requests from mocket import Mocketizer, mocketize +from mocket.core.mode import MocketMode from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.mode import MocketMode @mocketize(strict_mode=True) diff --git a/tests/test_socket.py b/tests/test_socket.py index 112a9089..b3077f6f 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,7 +2,7 @@ import pytest -from mocket.socket import MocketSocket +from mocket import MocketSocket @pytest.mark.parametrize("blocking", (False, True))