From d6c1701973df81de507a8014368284018ba27256 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 10:10:41 +0100 Subject: [PATCH] refactor: split ssl-functionality of MocketSocket into MocketSSLSocket --- mocket/socket.py | 52 ------------------------------------ mocket/ssl/context.py | 29 +++++++++++++++----- mocket/ssl/socket.py | 61 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 58 deletions(-) create mode 100644 mocket/ssl/socket.py diff --git a/mocket/socket.py b/mocket/socket.py index 3cd68fe5..2ce74c09 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -96,9 +96,6 @@ def _hash_request(h, req): class MocketSocket: - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - def __init__( self, family: socket.AddressFamily | int = socket.AF_INET, @@ -117,10 +114,6 @@ def __init__( self._buflen = 65536 self._timeout: float | None = None - self._secure_socket = False - self._did_handshake = False - self._sent_non_empty_bytes = False - self._host = None self._port = None self._address = None @@ -187,9 +180,6 @@ def settimeout(self, timeout: float | None) -> None: def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self) -> None: - self._did_handshake = True - def getpeername(self) -> _RetAddress: return self._address @@ -202,32 +192,6 @@ def getblocking(self) -> bool: def getsockname(self) -> _RetAddress: return true_gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self) -> MocketSocket: - return self - - def write(self, data: bytes) -> int | None: - return self.send(encode_to_bytes(data)) - def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address @@ -254,14 +218,6 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize: int | None = None) -> bytes: - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - def recv_into( self, buffer: WriteableBuffer, @@ -344,14 +300,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self._true_socket, true_socket) and self._secure_socket: - from mocket.ssl.context import true_urllib3_ssl_wrap_socket - - self._true_socket = true_urllib3_ssl_wrap_socket( - self._true_socket, - **self._kwargs, - ) - with contextlib.suppress(OSError, ValueError): # already connected self._true_socket.connect((host, port)) diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index fccf5db4..e5f60c0a 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -7,6 +7,7 @@ import urllib3.util.ssl_ from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket true_ssl_context = ssl.SSLContext @@ -70,10 +71,26 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: setattr(self, m, dummy_method) @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock._kwargs = kwargs - sock._secure_socket = True - return sock + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + + ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( + sock._true_socket, + **kwargs, + ) + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket @staticmethod def wrap_bio( @@ -81,7 +98,7 @@ def wrap_bio( outgoing: Any, # _ssl.MemoryBIO server_side: bool = False, server_hostname: str | bytes | None = None, - ) -> MocketSocket: - ssl_obj = MocketSocket() + ) -> MocketSSLSocket: + ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..3ea928e7 --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from typing import Any + +from mocket.mocket import Mocket +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + + +class MocketSSLSocket(MocketSocket): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self._original_socket: MocketSocket = self + self._did_handshake = False + self._sent_non_empty_bytes = False + + def read(self, buffersize: int | None = None) -> bytes: + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def write(self, data: bytes) -> int | None: + return self.send(data) + + def do_handshake(self) -> None: + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + return ("ADH", "AES256", "SHA") + + def compression(self) -> str | None: + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + return self._original_socket