diff --git a/mocket/__init__.py b/mocket/__init__.py index 8678d4fd..550a1f5b 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -4,7 +4,7 @@ from mocket.async_mocket import async_mocketize from mocket.entry import MocketEntry from mocket.mocketizer import Mocketizer, mocketize -from mocket.ssl import MocketSSLContext +from mocket.ssl.context import MocketSSLContext from mocket.state import MocketState diff --git a/mocket/inject.py b/mocket/inject.py index d97a1b6a..e5274506 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -25,14 +25,16 @@ true_socket, true_socketpair, true_urllib3_match_hostname, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, ) -from mocket.ssl import ( +from mocket.ssl.context import ( MocketSSLContext, true_ssl_context, true_ssl_wrap_socket, ) +from mocket.ssl.socket import ( + true_urllib3_ssl_wrap_socket, + true_urllib3_wrap_socket, +) try: # pragma: no cover from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 diff --git a/mocket/socket.py b/mocket/socket.py index f1dab5ea..100df41e 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -9,15 +9,12 @@ import os import select import socket -import ssl -from datetime import datetime, timedelta from json.decoder import JSONDecodeError from types import TracebackType from typing import Any, Type from typing_extensions import Buffer, Self from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket import mocket.state from mocket.compat import decode_from_bytes, encode_to_bytes @@ -27,16 +24,10 @@ ReadableBuffer, WriteableBuffer, _Address, - _PeerCertRetDictType, _RetAddress, ) from mocket.utils import hexdump, hexload -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - xxh32 = None try: from xxhash import xxh32 @@ -55,8 +46,6 @@ true_socketpair = socket.socketpair true_urllib3_match_hostname = urllib3_match_hostname -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_wrap_socket = urllib3_wrap_socket def create_connection( @@ -108,6 +97,7 @@ def _hash_request(h, req): return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() +# TODO rename to MocketSocketIO class MocketSocketCore(io.BytesIO): def __init__(self, address: Address) -> None: self._address = address @@ -124,20 +114,8 @@ def write(self, content: Buffer) -> int: class MocketSocket: - timeout = None - _fd = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION _mode = None _bufsize = None - _secure_socket = False - _did_handshake = False _sent_non_empty_bytes = False _io = None @@ -149,14 +127,22 @@ def __init__( fileno: int | None = None, **kwargs: Any, ): - self.true_socket = true_socket(family, type, proto) - self._buflen = 65536 - self._entry = None self.family = int(family) self.type = int(type) self.proto = int(proto) + + self._kwargs = kwargs + self._true_socket = true_socket(family, type, proto) self._truesocket_recording_dir = None - self.kwargs = kwargs + + self._timeout: float | None = None + self._buflen = 65536 + self._entry = None + + # TODO remove host and port with address everywhere + self._host = None + self._port = None + self._address = None def __str__(self) -> str: return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" @@ -187,7 +173,7 @@ def fileno(self) -> int: return r_fd def gettimeout(self) -> float | None: - return self.timeout + return self._timeout # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` def setsockopt(self, family: int, type: int, proto: int) -> None: @@ -195,19 +181,16 @@ def setsockopt(self, family: int, type: int, proto: int) -> None: self.type = type self.proto = proto - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) + if self._true_socket: + self._true_socket.setsockopt(family, type, proto) def settimeout(self, timeout: float | None) -> None: - self.timeout = timeout + self._timeout = timeout @staticmethod def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self) -> None: - self._did_handshake = True - def getpeername(self) -> _RetAddress: return self._address @@ -220,29 +203,6 @@ def getblocking(self) -> bool: def getsockname(self) -> _RetAddress: return socket.gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: - if not (self._host and self._port): - self._address = self._host, self._port = mocket.state.state._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self) -> Self: - return self - def write(self, data: ReadableBuffer) -> int | None: return self.send(encode_to_bytes(data)) @@ -255,6 +215,7 @@ def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: self._bufsize = bufsize return self.io + # TODO def get_entry(self, data): return mocket.state.state.get_entry(self._host, self._port, data) @@ -274,14 +235,6 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize: int | None = None) -> bytes: - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - def recv_into( self, buffer: WriteableBuffer, @@ -309,6 +262,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: exc.args = (0,) raise exc + # TODO def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -360,23 +314,17 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, - ) - with contextlib.suppress(OSError, ValueError): # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) + self._true_socket.connect((host, port)) + self._true_socket.sendall(data, *args, **kwargs) encoded_response = b"" # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] if not more_to_read and encoded_response: break - new_content = self.true_socket.recv(self._buflen) + new_content = self._true_socket.recv(self._buflen) if not new_content: break encoded_response += new_content @@ -415,10 +363,9 @@ def send( return len(data) def close(self) -> None: - # TODO might be better to use self.true_socket.fileno() instead of internal api. - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() - self._fd = None + # TODO might be better to use self._true_socket.fileno() instead of internal api. + if self._true_socket and not self._true_socket._closed: + self._true_socket.close() def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" diff --git a/mocket/ssl/__init__.py b/mocket/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl.py b/mocket/ssl/context.py similarity index 92% rename from mocket/ssl.py rename to mocket/ssl/context.py index 7fda12a2..bfda222a 100644 --- a/mocket/ssl.py +++ b/mocket/ssl/context.py @@ -5,6 +5,7 @@ from typing import Any from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket from mocket.types import ReadableBuffer, StrOrBytesPath true_ssl_wrap_socket = getattr( @@ -84,11 +85,10 @@ def check_hostname(self, _: bool) -> None: self._check_hostname = False @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock.kwargs = kwargs - sock._secure_socket = True - return sock + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + return MocketSSLSocket._create(sock=sock) + # TODO this should actually return a SSLObject, not a socket @staticmethod def wrap_bio( incoming: Any, # _ssl.MemoryBIO diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..55b6488b --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from typing import Any + +from devtools import debug +from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket + +import mocket.state +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + +try: + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket +except ImportError: + urllib3_wrap_socket = None + +true_urllib3_wrap_socket = urllib3_wrap_socket +true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket + + +class MocketSSLSocket(MocketSocket): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self._did_handshake = False + + def read(self, buffersize: int | None = None) -> bytes: + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def do_handshake(self) -> None: + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + if not (self._host and self._port): + self._address = self._host, self._port = mocket.state.state._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + return ("ADH", "AES256", "SHA") + + def compression(self) -> str | None: + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + return self + + @classmethod + def _create(cls, sock: MocketSocket) -> MocketSSLSocket: + kwargs = dict( + family=sock.family, + type=sock.type, + proto=sock.proto, + ) + + ssl_socket = MocketSSLSocket(**kwargs) + + ssl_socket._kwargs = sock._kwargs + ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( + sock._true_socket, + **sock._kwargs, + ) + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._timeout = sock._timeout + + # ssl_socket._entry = sock._entry + + return ssl_socket