Skip to content

Commit

Permalink
refactor: split ssl-functionality of MocketSocket into MocketSSLSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Nov 18, 2024
1 parent 90eb5db commit d6c1701
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 58 deletions.
52 changes: 0 additions & 52 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ def _hash_request(h, req):


class MocketSocket:
cipher = lambda s: ("ADH", "AES256", "SHA")
compression = lambda s: ssl.OP_NO_COMPRESSION

def __init__(
self,
family: socket.AddressFamily | int = socket.AF_INET,
Expand All @@ -117,10 +114,6 @@ def __init__(
self._buflen = 65536
self._timeout: float | None = None

self._secure_socket = False
self._did_handshake = False
self._sent_non_empty_bytes = False

self._host = None
self._port = None
self._address = None
Expand Down Expand Up @@ -187,9 +180,6 @@ def settimeout(self, timeout: float | None) -> None:
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
return socket.SOCK_STREAM

def do_handshake(self) -> None:
self._did_handshake = True

def getpeername(self) -> _RetAddress:
return self._address

Expand All @@ -202,32 +192,6 @@ def getblocking(self) -> bool:
def getsockname(self) -> _RetAddress:
return true_gethostbyname(self._address[0]), self._address[1]

def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
if not (self._host and self._port):
self._address = self._host, self._port = Mocket._address

now = datetime.now()
shift = now + timedelta(days=30 * 12)
return {
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
"subjectAltName": (
("DNS", f"*.{self._host}"),
("DNS", self._host),
("DNS", "*"),
),
"subject": (
(("organizationName", f"*.{self._host}"),),
(("organizationalUnitName", "Domain Control Validated"),),
(("commonName", f"*.{self._host}"),),
),
}

def unwrap(self) -> MocketSocket:
return self

def write(self, data: bytes) -> int | None:
return self.send(encode_to_bytes(data))

def connect(self, address: Address) -> None:
self._address = self._host, self._port = address
Mocket._address = address
Expand All @@ -254,14 +218,6 @@ def sendall(self, data, entry=None, *args, **kwargs):
self.io.truncate()
self.io.seek(0)

def read(self, buffersize: int | None = None) -> bytes:
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def recv_into(
self,
buffer: WriteableBuffer,
Expand Down Expand Up @@ -344,14 +300,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
host, port = self._host, self._port
host = true_gethostbyname(host)

if isinstance(self._true_socket, true_socket) and self._secure_socket:
from mocket.ssl.context import true_urllib3_ssl_wrap_socket

self._true_socket = true_urllib3_ssl_wrap_socket(
self._true_socket,
**self._kwargs,
)

with contextlib.suppress(OSError, ValueError):
# already connected
self._true_socket.connect((host, port))
Expand Down
29 changes: 23 additions & 6 deletions mocket/ssl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import urllib3.util.ssl_

from mocket.socket import MocketSocket
from mocket.ssl.socket import MocketSSLSocket

true_ssl_context = ssl.SSLContext

Expand Down Expand Up @@ -70,18 +71,34 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any:
setattr(self, m, dummy_method)

@staticmethod
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket:
sock._kwargs = kwargs
sock._secure_socket = True
return sock
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock

ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
sock._true_socket,
**kwargs,
)
ssl_socket._kwargs = kwargs

ssl_socket._timeout = sock._timeout

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._io = sock._io
ssl_socket._entry = sock._entry

return ssl_socket

@staticmethod
def wrap_bio(
incoming: Any, # _ssl.MemoryBIO
outgoing: Any, # _ssl.MemoryBIO
server_side: bool = False,
server_hostname: str | bytes | None = None,
) -> MocketSocket:
ssl_obj = MocketSocket()
) -> MocketSSLSocket:
ssl_obj = MocketSSLSocket()
ssl_obj._host = server_hostname
return ssl_obj
61 changes: 61 additions & 0 deletions mocket/ssl/socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import ssl
from datetime import datetime, timedelta
from typing import Any

from mocket.mocket import Mocket
from mocket.socket import MocketSocket
from mocket.types import _PeerCertRetDictType


class MocketSSLSocket(MocketSocket):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self._original_socket: MocketSocket = self
self._did_handshake = False
self._sent_non_empty_bytes = False

def read(self, buffersize: int | None = None) -> bytes:
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def write(self, data: bytes) -> int | None:
return self.send(data)

def do_handshake(self) -> None:
self._did_handshake = True

def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
if not (self._host and self._port):
self._address = self._host, self._port = Mocket._address

now = datetime.now()
shift = now + timedelta(days=30 * 12)
return {
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
"subjectAltName": (
("DNS", f"*.{self._host}"),
("DNS", self._host),
("DNS", "*"),
),
"subject": (
(("organizationName", f"*.{self._host}"),),
(("organizationalUnitName", "Domain Control Validated"),),
(("commonName", f"*.{self._host}"),),
),
}

def ciper(self) -> tuple[str, str, str]:
return ("ADH", "AES256", "SHA")

def compression(self) -> str | None:
return ssl.OP_NO_COMPRESSION

def unwrap(self) -> MocketSocket:
return self._original_socket

0 comments on commit d6c1701

Please sign in to comment.