-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathinject.py
127 lines (115 loc) · 4.71 KB
/
inject.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations
import os
import socket
import ssl
import urllib3
try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
pyopenssl_override = True
except ImportError:
pyopenssl_override = False
def enable(
    namespace: str | None = None,
    truesocket_recording_dir: str | None = None,
    use_hex_encoding: bool = True,
) -> None:
    """Install Mocket's socket/SSL replacements into ``socket``, ``ssl`` and urllib3.

    Each patched name is written both as a module attribute and directly into
    the module's ``__dict__``; ``disable()`` writes the originals back the same
    way.

    Args:
        namespace: Stored on ``Mocket._namespace``.
        truesocket_recording_dir: Stored on ``Mocket._truesocket_recording_dir``;
            JSON dumps are saved here. When truthy it must be an existing
            directory.
        use_hex_encoding: Stored on ``Mocket._use_hex_encoding``.

    Raises:
        AssertionError: if ``truesocket_recording_dir`` is given but is not an
            existing directory. Raised before any global state is modified.
    """
    # Validate before touching anything global, so a bad path leaves Mocket and
    # every patched module exactly as they were. AssertionError is kept (rather
    # than e.g. ValueError) for callers that already catch it.
    if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
        raise AssertionError(
            "truesocket_recording_dir must be an existing directory "
            f"(JSON dumps are saved there), got {truesocket_recording_dir!r}"
        )

    # Function-scope imports; presumably deferred to avoid import cycles with
    # the mocket package -- NOTE(review): confirm.
    from mocket.mocket import Mocket
    from mocket.socket import (
        MocketSocket,
        mock_create_connection,
        mock_getaddrinfo,
        mock_gethostbyname,
        mock_gethostname,
        mock_inet_pton,
        mock_socketpair,
        mock_urllib3_match_hostname,
    )
    from mocket.ssl.context import MocketSSLContext

    Mocket._namespace = namespace
    Mocket._truesocket_recording_dir = truesocket_recording_dir
    Mocket._use_hex_encoding = use_hex_encoding

    # --- stdlib socket ---
    socket.socket = socket.__dict__["socket"] = MocketSocket
    socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
    socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
    socket.create_connection = socket.__dict__["create_connection"] = (
        mock_create_connection
    )
    socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
    socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
    socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
    socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
    socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton

    # --- stdlib ssl ---
    # Installed unconditionally, even where the running ssl module has no
    # wrap_socket of its own.
    ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
    ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext

    # --- urllib3: it keeps its own references to the SSL helpers ---
    urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
        MocketSSLContext.wrap_socket
    )
    urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
        "ssl_wrap_socket"
    ] = MocketSSLContext.wrap_socket
    urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
        MocketSSLContext.wrap_socket
    )
    urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
        "ssl_wrap_socket"
    ] = MocketSSLContext.wrap_socket
    urllib3.connection.match_hostname = urllib3.connection.__dict__[
        "match_hostname"
    ] = mock_urllib3_match_hostname

    if pyopenssl_override:  # pragma: no cover
        # Take out the pyopenssl version - use the default implementation
        extract_from_urllib3()
def disable() -> None:
    """Restore the real socket/SSL implementations that ``enable()`` replaced.

    Writes the saved originals back into ``socket``, ``ssl`` and urllib3 (as
    both module attributes and ``__dict__`` entries), calls ``Mocket.reset()``,
    and, if pyOpenSSL support was importable at module load, re-injects it into
    urllib3.
    """
    # Function-scope imports; presumably deferred to avoid import cycles with
    # the mocket package -- NOTE(review): confirm.
    from mocket.mocket import Mocket
    from mocket.socket import (
        true_create_connection,
        true_getaddrinfo,
        true_gethostbyname,
        true_gethostname,
        true_inet_pton,
        true_socket,
        true_socketpair,
        true_urllib3_match_hostname,
    )
    from mocket.ssl.context import (
        true_ssl_context,
        true_ssl_wrap_socket,
        true_urllib3_ssl_wrap_socket,
        true_urllib3_wrap_socket,
    )

    # --- stdlib socket ---
    socket.socket = socket.__dict__["socket"] = true_socket
    socket._socketobject = socket.__dict__["_socketobject"] = true_socket
    socket.SocketType = socket.__dict__["SocketType"] = true_socket
    socket.create_connection = socket.__dict__["create_connection"] = (
        true_create_connection
    )
    socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
    socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
    socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
    socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
    socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton

    # --- stdlib ssl ---
    if true_ssl_wrap_socket:
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
    else:
        # enable() installs ssl.wrap_socket unconditionally, but there is no
        # original to put back -- presumably because the running ssl module
        # has no wrap_socket (it was removed in Python 3.12). Remove the
        # injected mock instead of leaving it live after disable().
        ssl.__dict__.pop("wrap_socket", None)
    ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context

    # --- urllib3: it keeps its own references to the SSL helpers ---
    urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
        true_urllib3_wrap_socket
    )
    urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
        "ssl_wrap_socket"
    ] = true_urllib3_ssl_wrap_socket
    urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
        true_urllib3_ssl_wrap_socket
    )
    urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
        "ssl_wrap_socket"
    ] = true_urllib3_ssl_wrap_socket
    urllib3.connection.match_hostname = urllib3.connection.__dict__[
        "match_hostname"
    ] = true_urllib3_match_hostname

    Mocket.reset()

    if pyopenssl_override:  # pragma: no cover
        # Put the pyopenssl version back in place
        inject_into_urllib3()