From 4dc38cafa50b499bf783998e55a87e8446b2dce5 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 10:16:24 +0100 Subject: [PATCH 1/9] refactor: type SuperFakeSSLContext and FakeSSLContext --- mocket/ssl.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/mocket/ssl.py b/mocket/ssl.py index e4ae44cf..9d9d5d3b 100644 --- a/mocket/ssl.py +++ b/mocket/ssl.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket + + class SuperFakeSSLContext: """For Python 3.6 and newer.""" class FakeSetter(int): - def __set__(self, *args): + def __set__(self, *args: Any) -> None: pass minimum_version = FakeSetter() @@ -24,33 +31,36 @@ class FakeSSLContext(SuperFakeSSLContext): _check_hostname = False @property - def check_hostname(self): + def check_hostname(self) -> bool: return self._check_hostname @check_hostname.setter - def check_hostname(self, _): + def check_hostname(self, _: bool) -> None: self._check_hostname = False - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._set_dummy_methods() - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): + def _set_dummy_methods(self) -> None: + def dummy_method(*args: Any, **kwargs: Any) -> Any: pass for m in self.DUMMY_METHODS: setattr(self, m, dummy_method) @staticmethod - def wrap_socket(sock, *args, **kwargs): + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: sock.kwargs = kwargs sock._secure_socket = True return sock @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - from mocket.socket import MocketSocket - + def wrap_bio( + incoming: Any, # _ssl.MemoryBIO + outgoing: Any, # _ssl.MemoryBIO + server_side: bool = False, + server_hostname: str | bytes | None = None, + ) -> MocketSocket: ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] + ssl_obj._host = server_hostname return ssl_obj From ba68b9cd4ac1941b3c333a79d32936e9f2193aef Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 17:40:19 +0100 Subject: [PATCH 2/9] refactor: move FakeSSLContext from mocket.ssl to mocket.ssl.context --- mocket/__init__.py | 2 +- mocket/inject.py | 2 +- mocket/ssl/__init__.py | 0 mocket/{ssl.py => ssl/context.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 mocket/ssl/__init__.py rename mocket/{ssl.py => ssl/context.py} (100%) diff --git a/mocket/__init__.py b/mocket/__init__.py index d64cb11d..58993a24 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -2,7 +2,7 @@ from mocket.entry import MocketEntry from mocket.mocket import Mocket from mocket.mocketizer import Mocketizer, mocketize -from mocket.ssl import FakeSSLContext +from mocket.ssl.context import FakeSSLContext __all__ = ( "async_mocketize", diff --git a/mocket/inject.py b/mocket/inject.py index cba0b40b..b39503ed 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -44,7 +44,7 @@ def enable( ) -> None: from mocket.mocket import Mocket from mocket.socket import MocketSocket, create_connection, socketpair - from mocket.ssl import FakeSSLContext + from mocket.ssl.context import FakeSSLContext Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir diff --git a/mocket/ssl/__init__.py b/mocket/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl.py b/mocket/ssl/context.py similarity index 100% rename from mocket/ssl.py rename to mocket/ssl/context.py From 942c33f379a1e0fc19122ecc9424ceeb6d270fef Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 18:15:48 +0100 Subject: [PATCH 3/9] refactor: move true_* from mocket.inject to mocket.socket and mocket.ssl.context --- mocket/inject.py | 71 +++++++++++++++++++++---------------------- mocket/socket.py | 69 ++++++++++++++++++++++++++++++++++++----- mocket/ssl/context.py | 3 ++ 3 files changed, 98 insertions(+), 45 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index b39503ed..5909cb93 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -5,14 +5,6 @@ import ssl import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - try: # pragma: no cover from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 @@ -21,29 +13,22 @@ except ImportError: pyopenssl_override = False -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # from Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname - def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: from mocket.mocket import Mocket - from mocket.socket import MocketSocket, create_connection, socketpair + from mocket.socket import ( + MocketSocket, + mock_create_connection, + mock_getaddrinfo, + mock_gethostbyname, + mock_gethostname, + mock_inet_pton, + mock_socketpair, + mock_urllib3_match_hostname, + ) from mocket.ssl.context import FakeSSLContext Mocket._namespace = namespace @@ -56,20 +41,16 @@ def enable( socket.socket = socket.__dict__["socket"] = MocketSocket socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = create_connection - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1" - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] + socket.create_connection = socket.__dict__["create_connection"] = ( + mock_create_connection ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair + socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname + socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo + socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( - "\x7f\x00\x00\x01", "utf-8" - ) + socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( FakeSSLContext.wrap_socket ) @@ -84,7 +65,7 @@ def enable( ] = FakeSSLContext.wrap_socket urllib3.connection.match_hostname = urllib3.connection.__dict__[ "match_hostname" - ] = lambda *args: None + ] = mock_urllib3_match_hostname if pyopenssl_override: # pragma: no cover # Take out the pyopenssl version - use the default implementation extract_from_urllib3() @@ -92,6 +73,22 @@ def enable( def disable() -> None: from mocket.mocket import Mocket + from mocket.socket import ( + true_create_connection, + true_getaddrinfo, + true_gethostbyname, + true_gethostname, + true_inet_pton, + true_socket, + true_socketpair, + true_ssl_wrap_socket, + true_urllib3_match_hostname, + true_urllib3_ssl_wrap_socket, + true_urllib3_wrap_socket, + ) + from mocket.ssl.context import ( + true_ssl_context, + ) socket.socket = socket.__dict__["socket"] = true_socket socket._socketobject = socket.__dict__["_socketobject"] = true_socket diff --git a/mocket/socket.py b/mocket/socket.py index e4be00b6..ab711f06 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import errno import hashlib @@ -8,18 +10,42 @@ import ssl from datetime import datetime, timedelta from json.decoder import JSONDecodeError +from typing import Any + +import urllib3.connection +import urllib3.util.ssl_ from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.inject import ( - true_gethostbyname, - true_socket, - true_urllib3_ssl_wrap_socket, -) from mocket.io import MocketSocketCore from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import hexdump, hexload +true_create_connection = socket.create_connection +true_getaddrinfo = socket.getaddrinfo +true_gethostbyname = socket.gethostbyname +true_gethostname = socket.gethostname +true_inet_pton = socket.inet_pton +true_socket = socket.socket +true_socketpair = socket.socketpair +true_ssl_wrap_socket = None + +true_urllib3_match_hostname = urllib3.connection.match_hostname +true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket +true_urllib3_wrap_socket = None + +with contextlib.suppress(ImportError): + # from Py3.12 it's only under SSLContext + from ssl import wrap_socket as ssl_wrap_socket + + true_ssl_wrap_socket = ssl_wrap_socket + +with contextlib.suppress(ImportError): + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket + + true_urllib3_wrap_socket = urllib3_wrap_socket + + xxh32 = None try: from xxhash import xxh32 @@ -29,7 +55,7 @@ hasher = xxh32 or hashlib.md5 -def create_connection(address, timeout=None, source_address=None): +def mock_create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: s.settimeout(timeout) @@ -37,13 +63,40 @@ def create_connection(address, timeout=None, source_address=None): return s -def socketpair(*args, **kwargs): +def mock_getaddrinfo( + host: str, + port: int, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[tuple[int, int, int, str, tuple[str, int]]]: + return [(2, 1, 6, "", (host, port))] + + +def mock_gethostbyname(hostname: str) -> str: + return "127.0.0.1" + + +def mock_gethostname() -> str: + return "localhost" + + +def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + return bytes("\x7f\x00\x00\x01", "utf-8") + + +def mock_socketpair(*args, **kwargs): """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" import _socket return _socket.socketpair(*args, **kwargs) +def mock_urllib3_match_hostname(*args: Any) -> None: + return None + + def _hash_request(h, req): return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() @@ -132,7 +185,7 @@ def getblocking(self): return self.gettimeout() is None def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] + return true_gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): if not (self._host and self._port): diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 9d9d5d3b..a327fbef 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,9 +1,12 @@ from __future__ import annotations +import ssl from typing import Any from mocket.socket import MocketSocket +true_ssl_context = ssl.SSLContext + class SuperFakeSSLContext: """For Python 3.6 and newer.""" From cfcd85c642cfa3847a7af1b5b81c9052846aa146 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 18:59:58 +0100 Subject: [PATCH 4/9] refactor: type MocketSocket --- mocket/socket.py | 93 ++++++++++++++++++++++++++++++++---------------- mocket/types.py | 17 ++++++++- 2 files changed, 78 insertions(+), 32 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index ab711f06..3743e6f2 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -10,15 +10,25 @@ import ssl from datetime import datetime, timedelta from json.decoder import JSONDecodeError -from typing import Any +from types import TracebackType +from typing import Any, Type import urllib3.connection import urllib3.util.ssl_ +from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.entry import MocketEntry from mocket.io import MocketSocketCore from mocket.mocket import Mocket from mocket.mode import MocketMode +from mocket.types import ( + Address, + ReadableBuffer, + WriteableBuffer, + _PeerCertRetDictType, + _RetAddress, +) from mocket.utils import hexdump, hexload true_create_connection = socket.create_connection @@ -120,8 +130,13 @@ class MocketSocket: _io = None def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, + **kwargs: Any, + ) -> None: self.true_socket = true_socket(family, type, proto) self._buflen = 65536 self._entry = None @@ -131,22 +146,27 @@ def __init__( self._truesocket_recording_dir = None self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + type_: Type[BaseException] | None, # noqa: UP006 + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.close() @property - def io(self): + def io(self) -> MocketSocketCore: if self._io is None: self._io = MocketSocketCore((self._host, self._port)) return self._io - def fileno(self): + def fileno(self) -> int: address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -154,10 +174,11 @@ def fileno(self): Mocket.set_pair(address, (r_fd, w_fd)) return r_fd - def gettimeout(self): + def gettimeout(self) -> float | None: return self.timeout - def setsockopt(self, family, type, proto): + # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` + def setsockopt(self, family: int, type: int, proto: int) -> None: self.family = family self.type = type self.proto = proto @@ -165,29 +186,29 @@ def setsockopt(self, family, type, proto): if self.true_socket: self.true_socket.setsockopt(family, type, proto) - def settimeout(self, timeout): + def settimeout(self, timeout: float | None) -> None: self.timeout = timeout @staticmethod - def getsockopt(level, optname, buflen=None): + def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self): + def do_handshake(self) -> None: self._did_handshake = True - def getpeername(self): + def getpeername(self) -> _RetAddress: return self._address - def setblocking(self, block): + def setblocking(self, block: bool) -> None: self.settimeout(None) if block else self.settimeout(0.0) - def getblocking(self): + def getblocking(self) -> bool: return self.gettimeout() is None - def getsockname(self): + def getsockname(self) -> _RetAddress: return true_gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, *args, **kwargs): + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -207,22 +228,22 @@ def getpeercert(self, *args, **kwargs): ), } - def unwrap(self): + def unwrap(self) -> MocketSocket: return self - def write(self, data): + def write(self, data: bytes) -> int | None: return self.send(encode_to_bytes(data)) - def connect(self, address): + def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address - def makefile(self, mode="r", bufsize=-1): + def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: self._mode = mode self._bufsize = bufsize return self.io - def get_entry(self, data): + def get_entry(self, data: bytes) -> MocketEntry | None: return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -241,7 +262,7 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize): + def read(self, buffersize: int | None = None) -> bytes: rv = self.io.read(buffersize) if rv: self._sent_non_empty_bytes = True @@ -249,7 +270,12 @@ def read(self, buffersize): raise ssl.SSLWantReadError("The operation did not complete (read)") return rv - def recv_into(self, buffer, buffersize=None, flags=None): + def recv_into( + self, + buffer: WriteableBuffer, + buffersize: int | None = None, + flags: int | None = None, + ) -> int: if hasattr(buffer, "write"): return buffer.write(self.read(buffersize)) # buffer is a memoryview @@ -258,7 +284,7 @@ def recv_into(self, buffer, buffersize=None, flags=None): buffer[: len(data)] = data return len(data) - def recv(self, buffersize, flags=None): + def recv(self, buffersize: int, flags: int | None = None) -> bytes: r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -271,7 +297,7 @@ def recv(self, buffersize, flags=None): exc.args = (0,) raise exc - def true_sendall(self, data, *args, **kwargs): + def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -359,7 +385,12 @@ def true_sendall(self, data, *args, **kwargs): # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO return encoded_response - def send(self, data, *args, **kwargs): # pragma: no cover + def send( + self, + data: ReadableBuffer, + *args: Any, + **kwargs: Any, + ) -> int: # pragma: no cover entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry @@ -371,15 +402,15 @@ def send(self, data, *args, **kwargs): # pragma: no cover self._entry = entry return len(data) - def close(self): + def close(self) -> None: if self.true_socket and not self.true_socket._closed: self.true_socket.close() self._fd = None - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" - def do_nothing(*args, **kwargs): + def do_nothing(*args: Any, **kwargs: Any) -> Any: pass return do_nothing diff --git a/mocket/types.py b/mocket/types.py index 61b7a4d5..562648c7 100644 --- a/mocket/types.py +++ b/mocket/types.py @@ -1,5 +1,20 @@ from __future__ import annotations -from typing import Tuple +from typing import Any, Dict, Tuple, Union + +from typing_extensions import Buffer, TypeAlias Address = Tuple[str, int] + +# adapted from typeshed/stdlib/_typeshed/__init__.pyi +WriteableBuffer: TypeAlias = Buffer +ReadableBuffer: TypeAlias = Buffer + +# from typeshed/stdlib/_socket.pyi +_Address: TypeAlias = Union[Tuple[Any, ...], str, ReadableBuffer] +_RetAddress: TypeAlias = Any + +# from typeshed/stdlib/ssl.pyi +_PCTRTT: TypeAlias = Tuple[Tuple[str, str], ...] +_PCTRTTT: TypeAlias = Tuple[_PCTRTT, ...] +_PeerCertRetDictType: TypeAlias = Dict[str, Union[str, _PCTRTTT, _PCTRTT]] From 9050127e34dcd121086e68ae657d05e51d414425 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 19:02:08 +0100 Subject: [PATCH 5/9] refactor: remove unused instance-variables from MocketSocket --- mocket/socket.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index 3743e6f2..c4b6a9a8 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -113,7 +113,6 @@ def _hash_request(h, req): class MocketSocket: timeout = None - _fd = None family = None type = None proto = None @@ -122,8 +121,6 @@ class MocketSocket: _address = None cipher = lambda s: ("ADH", "AES256", "SHA") compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None _secure_socket = False _did_handshake = False _sent_non_empty_bytes = False @@ -239,8 +236,6 @@ def connect(self, address: Address) -> None: Mocket._address = address def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: - self._mode = mode - self._bufsize = bufsize return self.io def get_entry(self, data: bytes) -> MocketEntry | None: @@ -405,7 +400,6 @@ def send( def close(self) -> None: if self.true_socket and not self.true_socket._closed: self.true_socket.close() - self._fd = None def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" From 1eb61cf55ea7a0445cfd7eee33a87b8fc936858c Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 19:13:12 +0100 Subject: [PATCH 6/9] refactor: MocketSocket - make instance-variables private and move into constructor --- mocket/socket.py | 81 +++++++++++++++++++++++++------------------ mocket/ssl/context.py | 2 +- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index c4b6a9a8..0b345572 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -112,19 +112,8 @@ def _hash_request(h, req): class MocketSocket: - timeout = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None cipher = lambda s: ("ADH", "AES256", "SHA") compression = lambda s: ssl.OP_NO_COMPRESSION - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None def __init__( self, @@ -134,14 +123,26 @@ def __init__( fileno: int | None = None, **kwargs: Any, ) -> None: - self.true_socket = true_socket(family, type, proto) + self._family = family + self._type = type + self._proto = proto + + self._kwargs = kwargs + self._true_socket = true_socket(family, type, proto) + self._buflen = 65536 + self._timeout: float | None = None + + self._secure_socket = False + self._did_handshake = False + self._sent_non_empty_bytes = False + + self._host = None + self._port = None + self._address = None + + self._io = None self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs def __str__(self) -> str: return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" @@ -157,6 +158,18 @@ def __exit__( ) -> None: self.close() + @property + def family(self) -> int: + return self._family + + @property + def type(self) -> int: + return self._type + + @property + def proto(self) -> int: + return self._proto + @property def io(self) -> MocketSocketCore: if self._io is None: @@ -172,19 +185,19 @@ def fileno(self) -> int: return r_fd def gettimeout(self) -> float | None: - return self.timeout + return self._timeout # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` def setsockopt(self, family: int, type: int, proto: int) -> None: - self.family = family - self.type = type - self.proto = proto + self._family = family + self._type = type + self._proto = proto - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) + if self._true_socket: + self._true_socket.setsockopt(family, type, proto) def settimeout(self, timeout: float | None) -> None: - self.timeout = timeout + self._timeout = timeout @staticmethod def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: @@ -343,23 +356,23 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, + if isinstance(self._true_socket, true_socket) and self._secure_socket: + self._true_socket = true_urllib3_ssl_wrap_socket( + self._true_socket, + **self._kwargs, ) with contextlib.suppress(OSError, ValueError): # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) + self._true_socket.connect((host, port)) + self._true_socket.sendall(data, *args, **kwargs) encoded_response = b"" # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] if not more_to_read and encoded_response: break - new_content = self.true_socket.recv(self._buflen) + new_content = self._true_socket.recv(self._buflen) if not new_content: break encoded_response += new_content @@ -398,8 +411,8 @@ def send( return len(data) def close(self) -> None: - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() + if self._true_socket and not self._true_socket._closed: + self._true_socket.close() def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index a327fbef..a830c1e7 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -53,7 +53,7 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: @staticmethod def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock.kwargs = kwargs + sock._kwargs = kwargs sock._secure_socket = True return sock From 0eff8f1ec935124b0d6097ecc366d8e758220eda Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 09:29:23 +0100 Subject: [PATCH 7/9] refactor: move true-ssl-methods to mocket.ssl.context --- mocket/inject.py | 6 +++--- mocket/socket.py | 18 ++---------------- mocket/ssl/context.py | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index 5909cb93..b733dd3c 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -81,13 +81,13 @@ def disable() -> None: true_inet_pton, true_socket, true_socketpair, - true_ssl_wrap_socket, true_urllib3_match_hostname, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, ) from mocket.ssl.context import ( true_ssl_context, + true_ssl_wrap_socket, + true_urllib3_ssl_wrap_socket, + true_urllib3_wrap_socket, ) socket.socket = socket.__dict__["socket"] = true_socket diff --git a/mocket/socket.py b/mocket/socket.py index 0b345572..c3bed15f 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -14,7 +14,6 @@ from typing import Any, Type import urllib3.connection -import urllib3.util.ssl_ from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes @@ -38,22 +37,7 @@ true_inet_pton = socket.inet_pton true_socket = socket.socket true_socketpair = socket.socketpair -true_ssl_wrap_socket = None - true_urllib3_match_hostname = urllib3.connection.match_hostname -true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket -true_urllib3_wrap_socket = None - -with contextlib.suppress(ImportError): - # from Py3.12 it's only under SSLContext - from ssl import wrap_socket as ssl_wrap_socket - - true_ssl_wrap_socket = ssl_wrap_socket - -with contextlib.suppress(ImportError): - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket - - true_urllib3_wrap_socket = urllib3_wrap_socket xxh32 = None @@ -357,6 +341,8 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host = true_gethostbyname(host) if isinstance(self._true_socket, true_socket) and self._secure_socket: + from mocket.ssl.context import true_urllib3_ssl_wrap_socket + self._true_socket = true_urllib3_ssl_wrap_socket( self._true_socket, **self._kwargs, diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index a830c1e7..fccf5db4 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,12 +1,30 @@ from __future__ import annotations +import contextlib import ssl from typing import Any +import urllib3.util.ssl_ + from mocket.socket import MocketSocket true_ssl_context = ssl.SSLContext +true_ssl_wrap_socket = None +true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket +true_urllib3_wrap_socket = None + +with contextlib.suppress(ImportError): + # from Py3.12 it's only under SSLContext + from ssl import wrap_socket as ssl_wrap_socket + + true_ssl_wrap_socket = ssl_wrap_socket + +with contextlib.suppress(ImportError): + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket + + true_urllib3_wrap_socket = urllib3_wrap_socket + class SuperFakeSSLContext: """For Python 3.6 and newer.""" From 90eb5db6929f12793413ac3894b53fc175b269c2 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 09:38:27 +0100 Subject: [PATCH 8/9] refactor: prepare for removal of read and write from MocketSocket --- mocket/socket.py | 10 +++++++--- tests/test_http.py | 12 ++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index c3bed15f..3cd68fe5 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -269,9 +269,13 @@ def recv_into( flags: int | None = None, ) -> int: if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) + return buffer.write(self.recv(buffersize)) + # buffer is a memoryview - data = self.read(buffersize) + if buffersize is None: + buffersize = len(buffer) + + data = self.recv(buffersize) if data: buffer[: len(data)] = data return len(data) @@ -280,7 +284,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) - data = self.read(buffersize) + data = self.io.read(buffersize) if data: return data # used by Redis mock diff --git a/tests/test_http.py b/tests/test_http.py index d516068b..afa31185 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -359,12 +359,12 @@ def test_sockets(self): sock = socket.socket(address[0], address[1], address[2]) sock.connect(address[-1]) - sock.write(f"{method} {path} HTTP/1.0\r\n") - sock.write(f"Host: {host}\r\n") - sock.write("Content-Type: application/json\r\n") - sock.write("Content-Length: %d\r\n" % len(data)) - sock.write("Connection: close\r\n\r\n") - sock.write(data) + sock.send(f"{method} {path} HTTP/1.0\r\n".encode()) + sock.send(f"Host: {host}\r\n".encode()) + sock.send(b"Content-Type: application/json\r\n") + sock.send(b"Content-Length: %d\r\n" % len(data)) + sock.send(b"Connection: close\r\n\r\n") + sock.send(data.encode()) sock.close() # Proof that worked. From 636951f2f9ea47139539b346c0e3bbc9067e86f0 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 10:10:41 +0100 Subject: [PATCH 9/9] refactor: split ssl-functionality of MocketSocket into MocketSSLSocket --- mocket/socket.py | 52 ------------------------------------ mocket/ssl/context.py | 29 +++++++++++++++----- mocket/ssl/socket.py | 62 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 mocket/ssl/socket.py diff --git a/mocket/socket.py b/mocket/socket.py index 3cd68fe5..2ce74c09 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -96,9 +96,6 @@ def _hash_request(h, req): class MocketSocket: - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - def __init__( self, family: socket.AddressFamily | int = socket.AF_INET, @@ -117,10 +114,6 @@ def __init__( self._buflen = 65536 self._timeout: float | None = None - self._secure_socket = False - self._did_handshake = False - self._sent_non_empty_bytes = False - self._host = None self._port = None self._address = None @@ -187,9 +180,6 @@ def settimeout(self, timeout: float | None) -> None: def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self) -> None: - self._did_handshake = True - def getpeername(self) -> _RetAddress: return self._address @@ -202,32 +192,6 @@ def getblocking(self) -> bool: def getsockname(self) -> _RetAddress: return true_gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self) -> MocketSocket: - return self - - def write(self, data: bytes) -> int | None: - return self.send(encode_to_bytes(data)) - def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address @@ -254,14 +218,6 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize: int | None = None) -> bytes: - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - def recv_into( self, buffer: WriteableBuffer, @@ -344,14 +300,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self._true_socket, true_socket) and self._secure_socket: - from mocket.ssl.context import true_urllib3_ssl_wrap_socket - - self._true_socket = true_urllib3_ssl_wrap_socket( - self._true_socket, - **self._kwargs, - ) - with contextlib.suppress(OSError, ValueError): # already connected self._true_socket.connect((host, port)) diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index fccf5db4..e5f60c0a 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -7,6 +7,7 @@ import urllib3.util.ssl_ from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket true_ssl_context = ssl.SSLContext @@ -70,10 +71,26 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: setattr(self, m, dummy_method) @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock._kwargs = kwargs - sock._secure_socket = True - return sock + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + + ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( + sock._true_socket, + **kwargs, + ) + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket @staticmethod def wrap_bio( @@ -81,7 +98,7 @@ def wrap_bio( outgoing: Any, # _ssl.MemoryBIO server_side: bool = False, server_hostname: str | bytes | None = None, - ) -> MocketSocket: - ssl_obj = MocketSocket() + ) -> MocketSSLSocket: + ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..e50b7320 --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from typing import Any + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + + +class MocketSSLSocket(MocketSocket): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self._did_handshake = False + self._sent_non_empty_bytes = False + self._original_socket: MocketSocket = self + + def read(self, buffersize: int | None = None) -> bytes: + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def write(self, data: bytes) -> int | None: + return self.send(encode_to_bytes(data)) + + def do_handshake(self) -> None: + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + return ("ADH", "AES256", "SHA") + + def compression(self) -> str | None: + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + return self._original_socket