Skip to content

Commit

Permalink
wip cleanly split socket and SSLSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Nov 17, 2024
1 parent 1206e2e commit 0eb6474
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 87 deletions.
2 changes: 1 addition & 1 deletion mocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mocket.async_mocket import async_mocketize
from mocket.entry import MocketEntry
from mocket.mocketizer import Mocketizer, mocketize
from mocket.ssl import MocketSSLContext
from mocket.ssl.context import MocketSSLContext
from mocket.state import MocketState


Expand Down
8 changes: 5 additions & 3 deletions mocket/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@
true_socket,
true_socketpair,
true_urllib3_match_hostname,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)
from mocket.ssl import (
from mocket.ssl.context import (
MocketSSLContext,
true_ssl_context,
true_ssl_wrap_socket,
)
from mocket.ssl.socket import (
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
Expand Down
105 changes: 26 additions & 79 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,12 @@
import os
import select
import socket
import ssl
from datetime import datetime, timedelta
from json.decoder import JSONDecodeError
from types import TracebackType
from typing import Any, Type

from typing_extensions import Buffer, Self
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

import mocket.state
from mocket.compat import decode_from_bytes, encode_to_bytes
Expand All @@ -27,16 +24,10 @@
ReadableBuffer,
WriteableBuffer,
_Address,
_PeerCertRetDictType,
_RetAddress,
)
from mocket.utils import hexdump, hexload

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None

xxh32 = None
try:
from xxhash import xxh32
Expand All @@ -55,8 +46,6 @@
true_socketpair = socket.socketpair

true_urllib3_match_hostname = urllib3_match_hostname
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_wrap_socket = urllib3_wrap_socket


def create_connection(
Expand Down Expand Up @@ -108,6 +97,7 @@ def _hash_request(h, req):
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()


# TODO rename to MocketSocketIO
class MocketSocketCore(io.BytesIO):
def __init__(self, address: Address) -> None:
self._address = address
Expand All @@ -124,20 +114,8 @@ def write(self, content: Buffer) -> int:


class MocketSocket:
timeout = None
_fd = None
family = None
type = None
proto = None
_host = None
_port = None
_address = None
cipher = lambda s: ("ADH", "AES256", "SHA")
compression = lambda s: ssl.OP_NO_COMPRESSION
_mode = None
_bufsize = None
_secure_socket = False
_did_handshake = False
_sent_non_empty_bytes = False
_io = None

Expand All @@ -149,14 +127,22 @@ def __init__(
fileno: int | None = None,
**kwargs: Any,
):
self.true_socket = true_socket(family, type, proto)
self._buflen = 65536
self._entry = None
self.family = int(family)
self.type = int(type)
self.proto = int(proto)

self._kwargs = kwargs
self._true_socket = true_socket(family, type, proto)
self._truesocket_recording_dir = None
self.kwargs = kwargs

self._timeout: float | None = None
self._buflen = 65536
self._entry = None

# TODO remove host and port with address everywhere
self._host = None
self._port = None
self._address = None

def __str__(self) -> str:
return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
Expand Down Expand Up @@ -187,27 +173,24 @@ def fileno(self) -> int:
return r_fd

def gettimeout(self) -> float | None:
return self.timeout
return self._timeout

# FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
def setsockopt(self, family: int, type: int, proto: int) -> None:
self.family = family
self.type = type
self.proto = proto

if self.true_socket:
self.true_socket.setsockopt(family, type, proto)
if self._true_socket:
self._true_socket.setsockopt(family, type, proto)

def settimeout(self, timeout: float | None) -> None:
self.timeout = timeout
self._timeout = timeout

@staticmethod
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
return socket.SOCK_STREAM

def do_handshake(self) -> None:
self._did_handshake = True

def getpeername(self) -> _RetAddress:
return self._address

Expand All @@ -220,29 +203,6 @@ def getblocking(self) -> bool:
def getsockname(self) -> _RetAddress:
return socket.gethostbyname(self._address[0]), self._address[1]

def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
if not (self._host and self._port):
self._address = self._host, self._port = mocket.state.state._address

now = datetime.now()
shift = now + timedelta(days=30 * 12)
return {
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
"subjectAltName": (
("DNS", f"*.{self._host}"),
("DNS", self._host),
("DNS", "*"),
),
"subject": (
(("organizationName", f"*.{self._host}"),),
(("organizationalUnitName", "Domain Control Validated"),),
(("commonName", f"*.{self._host}"),),
),
}

def unwrap(self) -> Self:
return self

def write(self, data: ReadableBuffer) -> int | None:
return self.send(encode_to_bytes(data))

Expand All @@ -255,6 +215,7 @@ def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore:
self._bufsize = bufsize
return self.io

# TODO
def get_entry(self, data):
return mocket.state.state.get_entry(self._host, self._port, data)

Expand All @@ -274,14 +235,6 @@ def sendall(self, data, entry=None, *args, **kwargs):
self.io.truncate()
self.io.seek(0)

def read(self, buffersize: int | None = None) -> bytes:
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def recv_into(
self,
buffer: WriteableBuffer,
Expand Down Expand Up @@ -309,6 +262,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes:
exc.args = (0,)
raise exc

# TODO
def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
if not MocketMode().is_allowed((self._host, self._port)):
MocketMode.raise_not_allowed()
Expand Down Expand Up @@ -360,23 +314,17 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
host, port = self._host, self._port
host = true_gethostbyname(host)

if isinstance(self.true_socket, true_socket) and self._secure_socket:
self.true_socket = true_urllib3_ssl_wrap_socket(
self.true_socket,
**self.kwargs,
)

with contextlib.suppress(OSError, ValueError):
# already connected
self.true_socket.connect((host, port))
self.true_socket.sendall(data, *args, **kwargs)
self._true_socket.connect((host, port))
self._true_socket.sendall(data, *args, **kwargs)
encoded_response = b""
# https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
while True:
more_to_read = select.select([self.true_socket], [], [], 0.1)[0]
more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
if not more_to_read and encoded_response:
break
new_content = self.true_socket.recv(self._buflen)
new_content = self._true_socket.recv(self._buflen)
if not new_content:
break
encoded_response += new_content
Expand Down Expand Up @@ -415,10 +363,9 @@ def send(
return len(data)

def close(self) -> None:
# TODO might be better to use self.true_socket.fileno() instead of internal api.
if self.true_socket and not self.true_socket._closed:
self.true_socket.close()
self._fd = None
# TODO might be better to use self._true_socket.fileno() instead of internal api.
if self._true_socket and not self._true_socket._closed:
self._true_socket.close()

def __getattr__(self, name: str) -> Any:
"""Do nothing catchall function, for methods like shutdown()"""
Expand Down
Empty file added mocket/ssl/__init__.py
Empty file.
8 changes: 4 additions & 4 deletions mocket/ssl.py → mocket/ssl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

from mocket.socket import MocketSocket
from mocket.ssl.socket import MocketSSLSocket
from mocket.types import ReadableBuffer, StrOrBytesPath

true_ssl_wrap_socket = getattr(
Expand Down Expand Up @@ -84,11 +85,10 @@ def check_hostname(self, _: bool) -> None:
self._check_hostname = False

@staticmethod
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket:
sock.kwargs = kwargs
sock._secure_socket = True
return sock
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
return MocketSSLSocket._create(sock=sock)

# TODO this should actually return a SSLObject, not a socket
@staticmethod
def wrap_bio(
incoming: Any, # _ssl.MemoryBIO
Expand Down
93 changes: 93 additions & 0 deletions mocket/ssl/socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import ssl
from datetime import datetime, timedelta
from typing import Any

from devtools import debug
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

import mocket.state
from mocket.socket import MocketSocket
from mocket.types import _PeerCertRetDictType

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None

true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket


class MocketSSLSocket(MocketSocket):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self._did_handshake = False

def read(self, buffersize: int | None = None) -> bytes:
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def do_handshake(self) -> None:
self._did_handshake = True

def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
if not (self._host and self._port):
self._address = self._host, self._port = mocket.state.state._address

now = datetime.now()
shift = now + timedelta(days=30 * 12)
return {
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
"subjectAltName": (
("DNS", f"*.{self._host}"),
("DNS", self._host),
("DNS", "*"),
),
"subject": (
(("organizationName", f"*.{self._host}"),),
(("organizationalUnitName", "Domain Control Validated"),),
(("commonName", f"*.{self._host}"),),
),
}

def ciper(self) -> tuple[str, str, str]:
return ("ADH", "AES256", "SHA")

def compression(self) -> str | None:
return ssl.OP_NO_COMPRESSION

def unwrap(self) -> MocketSocket:
return self

@classmethod
def _create(cls, sock: MocketSocket) -> MocketSSLSocket:
kwargs = dict(
family=sock.family,
type=sock.type,
proto=sock.proto,
)

ssl_socket = MocketSSLSocket(**kwargs)

ssl_socket._kwargs = sock._kwargs
ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
sock._true_socket,
**sock._kwargs,
)

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._timeout = sock._timeout

# ssl_socket._entry = sock._entry

return ssl_socket

0 comments on commit 0eb6474

Please sign in to comment.