Skip to content

Commit

Permalink
Merge pull request #265 from betaboon/refactor-split-socket-and-ssl-s…
Browse files Browse the repository at this point in the history
…ocket

Refactor split socket and ssl socket
  • Loading branch information
mindflayer authored Nov 20, 2024
2 parents 89055e8 + 636951f commit 8e7b3b6
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 220 deletions.
2 changes: 1 addition & 1 deletion mocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mocket.entry import MocketEntry
from mocket.mocket import Mocket
from mocket.mocketizer import Mocketizer, mocketize
from mocket.ssl import FakeSSLContext
from mocket.ssl.context import FakeSSLContext

__all__ = (
"async_mocketize",
Expand Down
73 changes: 35 additions & 38 deletions mocket/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
import ssl

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None


try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
Expand All @@ -21,30 +13,23 @@
except ImportError:
pyopenssl_override = False

true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
ssl, "wrap_socket", None
) # from Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext
from mocket.socket import (
MocketSocket,
mock_create_connection,
mock_getaddrinfo,
mock_gethostbyname,
mock_gethostname,
mock_inet_pton,
mock_socketpair,
mock_urllib3_match_hostname,
)
from mocket.ssl.context import FakeSSLContext

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir
Expand All @@ -56,20 +41,16 @@ def enable(
socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = create_connection
socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1"
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
lambda host, port, family=None, socktype=None, proto=None, flags=None: [
(2, 1, 6, "", (host, port))
]
socket.create_connection = socket.__dict__["create_connection"] = (
mock_create_connection
)
socket.socketpair = socket.__dict__["socketpair"] = socketpair
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
"\x7f\x00\x00\x01", "utf-8"
)
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
FakeSSLContext.wrap_socket
)
Expand All @@ -84,14 +65,30 @@ def enable(
] = FakeSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = lambda *args: None
] = mock_urllib3_match_hostname
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
from mocket.mocket import Mocket
from mocket.socket import (
true_create_connection,
true_getaddrinfo,
true_gethostbyname,
true_gethostname,
true_inet_pton,
true_socket,
true_socketpair,
true_urllib3_match_hostname,
)
from mocket.ssl.context import (
true_ssl_context,
true_ssl_wrap_socket,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
Expand Down
Loading

0 comments on commit 8e7b3b6

Please sign in to comment.