Skip to content

Commit 6b7143f

Browse files
committed
refactor: split ssl-functionality of MocketSocket into MocketSSLSocket
1 parent 3d8241e commit 6b7143f

File tree

3 files changed

+84
-58
lines changed

3 files changed

+84
-58
lines changed

mocket/socket.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ def _hash_request(h, req):
9696

9797

9898
class MocketSocket:
99-
cipher = lambda s: ("ADH", "AES256", "SHA")
100-
compression = lambda s: ssl.OP_NO_COMPRESSION
101-
10299
def __init__(
103100
self,
104101
family: socket.AddressFamily | int = socket.AF_INET,
@@ -117,10 +114,6 @@ def __init__(
117114
self._buflen = 65536
118115
self._timeout: float | None = None
119116

120-
self._secure_socket = False
121-
self._did_handshake = False
122-
self._sent_non_empty_bytes = False
123-
124117
self._host = None
125118
self._port = None
126119
self._address = None
@@ -187,9 +180,6 @@ def settimeout(self, timeout: float | None) -> None:
187180
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
188181
return socket.SOCK_STREAM
189182

190-
def do_handshake(self) -> None:
191-
self._did_handshake = True
192-
193183
def getpeername(self) -> _RetAddress:
194184
return self._address
195185

@@ -202,32 +192,6 @@ def getblocking(self) -> bool:
202192
def getsockname(self) -> _RetAddress:
203193
return true_gethostbyname(self._address[0]), self._address[1]
204194

205-
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
206-
if not (self._host and self._port):
207-
self._address = self._host, self._port = Mocket._address
208-
209-
now = datetime.now()
210-
shift = now + timedelta(days=30 * 12)
211-
return {
212-
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
213-
"subjectAltName": (
214-
("DNS", f"*.{self._host}"),
215-
("DNS", self._host),
216-
("DNS", "*"),
217-
),
218-
"subject": (
219-
(("organizationName", f"*.{self._host}"),),
220-
(("organizationalUnitName", "Domain Control Validated"),),
221-
(("commonName", f"*.{self._host}"),),
222-
),
223-
}
224-
225-
def unwrap(self) -> MocketSocket:
226-
return self
227-
228-
def write(self, data: bytes) -> int | None:
229-
return self.send(encode_to_bytes(data))
230-
231195
def connect(self, address: Address) -> None:
232196
self._address = self._host, self._port = address
233197
Mocket._address = address
@@ -254,14 +218,6 @@ def sendall(self, data, entry=None, *args, **kwargs):
254218
self.io.truncate()
255219
self.io.seek(0)
256220

257-
def read(self, buffersize: int | None = None) -> bytes:
258-
rv = self.io.read(buffersize)
259-
if rv:
260-
self._sent_non_empty_bytes = True
261-
if self._did_handshake and not self._sent_non_empty_bytes:
262-
raise ssl.SSLWantReadError("The operation did not complete (read)")
263-
return rv
264-
265221
def recv_into(
266222
self,
267223
buffer: WriteableBuffer,
@@ -344,14 +300,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
344300
host, port = self._host, self._port
345301
host = true_gethostbyname(host)
346302

347-
if isinstance(self._true_socket, true_socket) and self._secure_socket:
348-
from mocket.ssl.context import true_urllib3_ssl_wrap_socket
349-
350-
self._true_socket = true_urllib3_ssl_wrap_socket(
351-
self._true_socket,
352-
**self._kwargs,
353-
)
354-
355303
with contextlib.suppress(OSError, ValueError):
356304
# already connected
357305
self._true_socket.connect((host, port))

mocket/ssl/context.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import urllib3.util.ssl_
88

99
from mocket.socket import MocketSocket
10+
from mocket.ssl.socket import MocketSSLSocket
1011

1112
true_ssl_context = ssl.SSLContext
1213

@@ -70,18 +71,34 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any:
7071
setattr(self, m, dummy_method)
7172

7273
@staticmethod
73-
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket:
74-
sock._kwargs = kwargs
75-
sock._secure_socket = True
76-
return sock
74+
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
75+
ssl_socket = MocketSSLSocket()
76+
ssl_socket._original_socket = sock
77+
78+
ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
79+
sock._true_socket,
80+
**kwargs,
81+
)
82+
ssl_socket._kwargs = kwargs
83+
84+
ssl_socket._timeout = sock._timeout
85+
86+
ssl_socket._host = sock._host
87+
ssl_socket._port = sock._port
88+
ssl_socket._address = sock._address
89+
90+
ssl_socket._io = sock._io
91+
ssl_socket._entry = sock._entry
92+
93+
return ssl_socket
7794

7895
@staticmethod
7996
def wrap_bio(
8097
incoming: Any, # _ssl.MemoryBIO
8198
outgoing: Any, # _ssl.MemoryBIO
8299
server_side: bool = False,
83100
server_hostname: str | bytes | None = None,
84-
) -> MocketSocket:
85-
ssl_obj = MocketSocket()
101+
) -> MocketSSLSocket:
102+
ssl_obj = MocketSSLSocket()
86103
ssl_obj._host = server_hostname
87104
return ssl_obj

mocket/ssl/socket.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import ssl
4+
from datetime import datetime, timedelta
5+
from typing import Any
6+
7+
from mocket.mocket import Mocket
8+
from mocket.socket import MocketSocket
9+
from mocket.types import _PeerCertRetDictType
10+
11+
12+
class MocketSSLSocket(MocketSocket):
13+
def __init__(self, *args: Any, **kwargs: Any) -> None:
14+
super().__init__(*args, **kwargs)
15+
16+
self._original_socket: MocketSocket = self
17+
self._did_handshake = False
18+
self._sent_non_empty_bytes = False
19+
20+
def read(self, buffersize: int | None = None) -> bytes:
21+
rv = self.io.read(buffersize)
22+
if rv:
23+
self._sent_non_empty_bytes = True
24+
if self._did_handshake and not self._sent_non_empty_bytes:
25+
raise ssl.SSLWantReadError("The operation did not complete (read)")
26+
return rv
27+
28+
def write(self, data: bytes) -> int | None:
29+
return self.send(data)
30+
31+
def do_handshake(self) -> None:
32+
self._did_handshake = True
33+
34+
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
35+
if not (self._host and self._port):
36+
self._address = self._host, self._port = Mocket._address
37+
38+
now = datetime.now()
39+
shift = now + timedelta(days=30 * 12)
40+
return {
41+
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
42+
"subjectAltName": (
43+
("DNS", f"*.{self._host}"),
44+
("DNS", self._host),
45+
("DNS", "*"),
46+
),
47+
"subject": (
48+
(("organizationName", f"*.{self._host}"),),
49+
(("organizationalUnitName", "Domain Control Validated"),),
50+
(("commonName", f"*.{self._host}"),),
51+
),
52+
}
53+
54+
def ciper(self) -> tuple[str, str, str]:
55+
return ("ADH", "AES256", "SHA")
56+
57+
def compression(self) -> str | None:
58+
return ssl.OP_NO_COMPRESSION
59+
60+
def unwrap(self) -> MocketSocket:
61+
return self._original_socket

0 commit comments

Comments
 (0)