Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor split socket and ssl socket #265

Merged
2 changes: 1 addition & 1 deletion mocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mocket.entry import MocketEntry
from mocket.mocket import Mocket
from mocket.mocketizer import Mocketizer, mocketize
from mocket.ssl import FakeSSLContext
from mocket.ssl.context import FakeSSLContext

__all__ = (
"async_mocketize",
Expand Down
73 changes: 35 additions & 38 deletions mocket/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
import ssl

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket

try:
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
except ImportError:
urllib3_wrap_socket = None


try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
Expand All @@ -21,30 +13,23 @@
except ImportError:
pyopenssl_override = False

true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
ssl, "wrap_socket", None
) # from Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import MocketSocket, create_connection, socketpair
from mocket.ssl import FakeSSLContext
from mocket.socket import (
MocketSocket,
mock_create_connection,
mock_getaddrinfo,
mock_gethostbyname,
mock_gethostname,
mock_inet_pton,
mock_socketpair,
mock_urllib3_match_hostname,
)
from mocket.ssl.context import FakeSSLContext

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir
Expand All @@ -56,20 +41,16 @@ def enable(
socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = create_connection
socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1"
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
lambda host, port, family=None, socktype=None, proto=None, flags=None: [
(2, 1, 6, "", (host, port))
]
socket.create_connection = socket.__dict__["create_connection"] = (
mock_create_connection
)
socket.socketpair = socket.__dict__["socketpair"] = socketpair
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
"\x7f\x00\x00\x01", "utf-8"
)
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
FakeSSLContext.wrap_socket
)
Expand All @@ -84,14 +65,30 @@ def enable(
] = FakeSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = lambda *args: None
] = mock_urllib3_match_hostname
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
from mocket.mocket import Mocket
from mocket.socket import (
true_create_connection,
true_getaddrinfo,
true_gethostbyname,
true_gethostname,
true_inet_pton,
true_socket,
true_socketpair,
true_urllib3_match_hostname,
)
from mocket.ssl.context import (
true_ssl_context,
true_ssl_wrap_socket,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
Expand Down
Loading