Skip to content

Commit 0eb6474

Browse files
committed
wip cleanly split socket and SSLSocket
1 parent 1206e2e commit 0eb6474

File tree

6 files changed

+129
-87
lines changed

6 files changed

+129
-87
lines changed

mocket/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mocket.async_mocket import async_mocketize
55
from mocket.entry import MocketEntry
66
from mocket.mocketizer import Mocketizer, mocketize
7-
from mocket.ssl import MocketSSLContext
7+
from mocket.ssl.context import MocketSSLContext
88
from mocket.state import MocketState
99

1010

mocket/inject.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@
2525
true_socket,
2626
true_socketpair,
2727
true_urllib3_match_hostname,
28-
true_urllib3_ssl_wrap_socket,
29-
true_urllib3_wrap_socket,
3028
)
31-
from mocket.ssl import (
29+
from mocket.ssl.context import (
3230
MocketSSLContext,
3331
true_ssl_context,
3432
true_ssl_wrap_socket,
3533
)
34+
from mocket.ssl.socket import (
35+
true_urllib3_ssl_wrap_socket,
36+
true_urllib3_wrap_socket,
37+
)
3638

3739
try: # pragma: no cover
3840
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3

mocket/socket.py

Lines changed: 26 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,12 @@
99
import os
1010
import select
1111
import socket
12-
import ssl
13-
from datetime import datetime, timedelta
1412
from json.decoder import JSONDecodeError
1513
from types import TracebackType
1614
from typing import Any, Type
1715

1816
from typing_extensions import Buffer, Self
1917
from urllib3.connection import match_hostname as urllib3_match_hostname
20-
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
2118

2219
import mocket.state
2320
from mocket.compat import decode_from_bytes, encode_to_bytes
@@ -27,16 +24,10 @@
2724
ReadableBuffer,
2825
WriteableBuffer,
2926
_Address,
30-
_PeerCertRetDictType,
3127
_RetAddress,
3228
)
3329
from mocket.utils import hexdump, hexload
3430

35-
try:
36-
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
37-
except ImportError:
38-
urllib3_wrap_socket = None
39-
4031
xxh32 = None
4132
try:
4233
from xxhash import xxh32
@@ -55,8 +46,6 @@
5546
true_socketpair = socket.socketpair
5647

5748
true_urllib3_match_hostname = urllib3_match_hostname
58-
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
59-
true_urllib3_wrap_socket = urllib3_wrap_socket
6049

6150

6251
def create_connection(
@@ -108,6 +97,7 @@ def _hash_request(h, req):
10897
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()
10998

11099

100+
# TODO rename to MocketSocketIO
111101
class MocketSocketCore(io.BytesIO):
112102
def __init__(self, address: Address) -> None:
113103
self._address = address
@@ -124,20 +114,8 @@ def write(self, content: Buffer) -> int:
124114

125115

126116
class MocketSocket:
127-
timeout = None
128-
_fd = None
129-
family = None
130-
type = None
131-
proto = None
132-
_host = None
133-
_port = None
134-
_address = None
135-
cipher = lambda s: ("ADH", "AES256", "SHA")
136-
compression = lambda s: ssl.OP_NO_COMPRESSION
137117
_mode = None
138118
_bufsize = None
139-
_secure_socket = False
140-
_did_handshake = False
141119
_sent_non_empty_bytes = False
142120
_io = None
143121

@@ -149,14 +127,22 @@ def __init__(
149127
fileno: int | None = None,
150128
**kwargs: Any,
151129
):
152-
self.true_socket = true_socket(family, type, proto)
153-
self._buflen = 65536
154-
self._entry = None
155130
self.family = int(family)
156131
self.type = int(type)
157132
self.proto = int(proto)
133+
134+
self._kwargs = kwargs
135+
self._true_socket = true_socket(family, type, proto)
158136
self._truesocket_recording_dir = None
159-
self.kwargs = kwargs
137+
138+
self._timeout: float | None = None
139+
self._buflen = 65536
140+
self._entry = None
141+
142+
# TODO remove host and port with address everywhere
143+
self._host = None
144+
self._port = None
145+
self._address = None
160146

161147
def __str__(self) -> str:
162148
return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
@@ -187,27 +173,24 @@ def fileno(self) -> int:
187173
return r_fd
188174

189175
def gettimeout(self) -> float | None:
190-
return self.timeout
176+
return self._timeout
191177

192178
# FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
193179
def setsockopt(self, family: int, type: int, proto: int) -> None:
194180
self.family = family
195181
self.type = type
196182
self.proto = proto
197183

198-
if self.true_socket:
199-
self.true_socket.setsockopt(family, type, proto)
184+
if self._true_socket:
185+
self._true_socket.setsockopt(family, type, proto)
200186

201187
def settimeout(self, timeout: float | None) -> None:
202-
self.timeout = timeout
188+
self._timeout = timeout
203189

204190
@staticmethod
205191
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
206192
return socket.SOCK_STREAM
207193

208-
def do_handshake(self) -> None:
209-
self._did_handshake = True
210-
211194
def getpeername(self) -> _RetAddress:
212195
return self._address
213196

@@ -220,29 +203,6 @@ def getblocking(self) -> bool:
220203
def getsockname(self) -> _RetAddress:
221204
return socket.gethostbyname(self._address[0]), self._address[1]
222205

223-
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
224-
if not (self._host and self._port):
225-
self._address = self._host, self._port = mocket.state.state._address
226-
227-
now = datetime.now()
228-
shift = now + timedelta(days=30 * 12)
229-
return {
230-
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
231-
"subjectAltName": (
232-
("DNS", f"*.{self._host}"),
233-
("DNS", self._host),
234-
("DNS", "*"),
235-
),
236-
"subject": (
237-
(("organizationName", f"*.{self._host}"),),
238-
(("organizationalUnitName", "Domain Control Validated"),),
239-
(("commonName", f"*.{self._host}"),),
240-
),
241-
}
242-
243-
def unwrap(self) -> Self:
244-
return self
245-
246206
def write(self, data: ReadableBuffer) -> int | None:
247207
return self.send(encode_to_bytes(data))
248208

@@ -255,6 +215,7 @@ def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore:
255215
self._bufsize = bufsize
256216
return self.io
257217

218+
# TODO
258219
def get_entry(self, data):
259220
return mocket.state.state.get_entry(self._host, self._port, data)
260221

@@ -274,14 +235,6 @@ def sendall(self, data, entry=None, *args, **kwargs):
274235
self.io.truncate()
275236
self.io.seek(0)
276237

277-
def read(self, buffersize: int | None = None) -> bytes:
278-
rv = self.io.read(buffersize)
279-
if rv:
280-
self._sent_non_empty_bytes = True
281-
if self._did_handshake and not self._sent_non_empty_bytes:
282-
raise ssl.SSLWantReadError("The operation did not complete (read)")
283-
return rv
284-
285238
def recv_into(
286239
self,
287240
buffer: WriteableBuffer,
@@ -309,6 +262,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes:
309262
exc.args = (0,)
310263
raise exc
311264

265+
# TODO
312266
def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
313267
if not MocketMode().is_allowed((self._host, self._port)):
314268
MocketMode.raise_not_allowed()
@@ -360,23 +314,17 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
360314
host, port = self._host, self._port
361315
host = true_gethostbyname(host)
362316

363-
if isinstance(self.true_socket, true_socket) and self._secure_socket:
364-
self.true_socket = true_urllib3_ssl_wrap_socket(
365-
self.true_socket,
366-
**self.kwargs,
367-
)
368-
369317
with contextlib.suppress(OSError, ValueError):
370318
# already connected
371-
self.true_socket.connect((host, port))
372-
self.true_socket.sendall(data, *args, **kwargs)
319+
self._true_socket.connect((host, port))
320+
self._true_socket.sendall(data, *args, **kwargs)
373321
encoded_response = b""
374322
# https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
375323
while True:
376-
more_to_read = select.select([self.true_socket], [], [], 0.1)[0]
324+
more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
377325
if not more_to_read and encoded_response:
378326
break
379-
new_content = self.true_socket.recv(self._buflen)
327+
new_content = self._true_socket.recv(self._buflen)
380328
if not new_content:
381329
break
382330
encoded_response += new_content
@@ -415,10 +363,9 @@ def send(
415363
return len(data)
416364

417365
def close(self) -> None:
418-
# TODO might be better to use self.true_socket.fileno() instead of internal api.
419-
if self.true_socket and not self.true_socket._closed:
420-
self.true_socket.close()
421-
self._fd = None
366+
# TODO might be better to use self._true_socket.fileno() instead of internal api.
367+
if self._true_socket and not self._true_socket._closed:
368+
self._true_socket.close()
422369

423370
def __getattr__(self, name: str) -> Any:
424371
"""Do nothing catchall function, for methods like shutdown()"""

mocket/ssl/__init__.py

Whitespace-only changes.

mocket/ssl.py renamed to mocket/ssl/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
from mocket.socket import MocketSocket
8+
from mocket.ssl.socket import MocketSSLSocket
89
from mocket.types import ReadableBuffer, StrOrBytesPath
910

1011
true_ssl_wrap_socket = getattr(
@@ -84,11 +85,10 @@ def check_hostname(self, _: bool) -> None:
8485
self._check_hostname = False
8586

8687
@staticmethod
87-
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket:
88-
sock.kwargs = kwargs
89-
sock._secure_socket = True
90-
return sock
88+
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
89+
return MocketSSLSocket._create(sock=sock)
9190

91+
# TODO this should actually return a SSLObject, not a socket
9292
@staticmethod
9393
def wrap_bio(
9494
incoming: Any, # _ssl.MemoryBIO

mocket/ssl/socket.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from __future__ import annotations
2+
3+
import ssl
4+
from datetime import datetime, timedelta
5+
from typing import Any
6+
7+
from devtools import debug
8+
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
9+
10+
import mocket.state
11+
from mocket.socket import MocketSocket
12+
from mocket.types import _PeerCertRetDictType
13+
14+
try:
15+
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
16+
except ImportError:
17+
urllib3_wrap_socket = None
18+
19+
true_urllib3_wrap_socket = urllib3_wrap_socket
20+
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
21+
22+
23+
class MocketSSLSocket(MocketSocket):
24+
def __init__(self, *args: Any, **kwargs: Any) -> None:
25+
super().__init__(*args, **kwargs)
26+
27+
self._did_handshake = False
28+
29+
def read(self, buffersize: int | None = None) -> bytes:
30+
rv = self.io.read(buffersize)
31+
if rv:
32+
self._sent_non_empty_bytes = True
33+
if self._did_handshake and not self._sent_non_empty_bytes:
34+
raise ssl.SSLWantReadError("The operation did not complete (read)")
35+
return rv
36+
37+
def do_handshake(self) -> None:
38+
self._did_handshake = True
39+
40+
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
41+
if not (self._host and self._port):
42+
self._address = self._host, self._port = mocket.state.state._address
43+
44+
now = datetime.now()
45+
shift = now + timedelta(days=30 * 12)
46+
return {
47+
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
48+
"subjectAltName": (
49+
("DNS", f"*.{self._host}"),
50+
("DNS", self._host),
51+
("DNS", "*"),
52+
),
53+
"subject": (
54+
(("organizationName", f"*.{self._host}"),),
55+
(("organizationalUnitName", "Domain Control Validated"),),
56+
(("commonName", f"*.{self._host}"),),
57+
),
58+
}
59+
60+
def ciper(self) -> tuple[str, str, str]:
61+
return ("ADH", "AES256", "SHA")
62+
63+
def compression(self) -> str | None:
64+
return ssl.OP_NO_COMPRESSION
65+
66+
def unwrap(self) -> MocketSocket:
67+
return self
68+
69+
@classmethod
70+
def _create(cls, sock: MocketSocket) -> MocketSSLSocket:
71+
kwargs = dict(
72+
family=sock.family,
73+
type=sock.type,
74+
proto=sock.proto,
75+
)
76+
77+
ssl_socket = MocketSSLSocket(**kwargs)
78+
79+
ssl_socket._kwargs = sock._kwargs
80+
ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
81+
sock._true_socket,
82+
**sock._kwargs,
83+
)
84+
85+
ssl_socket._host = sock._host
86+
ssl_socket._port = sock._port
87+
ssl_socket._address = sock._address
88+
89+
ssl_socket._timeout = sock._timeout
90+
91+
# ssl_socket._entry = sock._entry
92+
93+
return ssl_socket

0 commit comments

Comments
 (0)