9
9
import os
10
10
import select
11
11
import socket
12
- import ssl
13
- from datetime import datetime , timedelta
14
12
from json .decoder import JSONDecodeError
15
13
from types import TracebackType
16
14
from typing import Any , Type
17
15
18
16
from typing_extensions import Buffer , Self
19
17
from urllib3 .connection import match_hostname as urllib3_match_hostname
20
- from urllib3 .util .ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
21
18
22
19
import mocket .state
23
20
from mocket .compat import decode_from_bytes , encode_to_bytes
27
24
ReadableBuffer ,
28
25
WriteableBuffer ,
29
26
_Address ,
30
- _PeerCertRetDictType ,
31
27
_RetAddress ,
32
28
)
33
29
from mocket .utils import hexdump , hexload
34
30
35
- try :
36
- from urllib3 .util .ssl_ import wrap_socket as urllib3_wrap_socket
37
- except ImportError :
38
- urllib3_wrap_socket = None
39
-
40
31
xxh32 = None
41
32
try :
42
33
from xxhash import xxh32
55
46
true_socketpair = socket .socketpair
56
47
57
48
true_urllib3_match_hostname = urllib3_match_hostname
58
- true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
59
- true_urllib3_wrap_socket = urllib3_wrap_socket
60
49
61
50
62
51
def create_connection (
@@ -108,6 +97,7 @@ def _hash_request(h, req):
108
97
return h (encode_to_bytes ("" .join (sorted (req .split ("\r \n " ))))).hexdigest ()
109
98
110
99
100
+ # TODO rename to MocketSocketIO
111
101
class MocketSocketCore (io .BytesIO ):
112
102
def __init__ (self , address : Address ) -> None :
113
103
self ._address = address
@@ -124,20 +114,8 @@ def write(self, content: Buffer) -> int:
124
114
125
115
126
116
class MocketSocket :
127
- timeout = None
128
- _fd = None
129
- family = None
130
- type = None
131
- proto = None
132
- _host = None
133
- _port = None
134
- _address = None
135
- cipher = lambda s : ("ADH" , "AES256" , "SHA" )
136
- compression = lambda s : ssl .OP_NO_COMPRESSION
137
117
_mode = None
138
118
_bufsize = None
139
- _secure_socket = False
140
- _did_handshake = False
141
119
_sent_non_empty_bytes = False
142
120
_io = None
143
121
@@ -149,14 +127,22 @@ def __init__(
149
127
fileno : int | None = None ,
150
128
** kwargs : Any ,
151
129
):
152
- self .true_socket = true_socket (family , type , proto )
153
- self ._buflen = 65536
154
- self ._entry = None
155
130
self .family = int (family )
156
131
self .type = int (type )
157
132
self .proto = int (proto )
133
+
134
+ self ._kwargs = kwargs
135
+ self ._true_socket = true_socket (family , type , proto )
158
136
self ._truesocket_recording_dir = None
159
- self .kwargs = kwargs
137
+
138
+ self ._timeout : float | None = None
139
+ self ._buflen = 65536
140
+ self ._entry = None
141
+
142
+ # TODO remove host and port with address everywhere
143
+ self ._host = None
144
+ self ._port = None
145
+ self ._address = None
160
146
161
147
def __str__ (self ) -> str :
162
148
return f"({ self .__class__ .__name__ } )(family={ self .family } type={ self .type } protocol={ self .proto } )"
@@ -187,27 +173,24 @@ def fileno(self) -> int:
187
173
return r_fd
188
174
189
175
def gettimeout (self ) -> float | None :
190
- return self .timeout
176
+ return self ._timeout
191
177
192
178
# FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
193
179
def setsockopt (self , family : int , type : int , proto : int ) -> None :
194
180
self .family = family
195
181
self .type = type
196
182
self .proto = proto
197
183
198
- if self .true_socket :
199
- self .true_socket .setsockopt (family , type , proto )
184
+ if self ._true_socket :
185
+ self ._true_socket .setsockopt (family , type , proto )
200
186
201
187
def settimeout (self , timeout : float | None ) -> None :
202
- self .timeout = timeout
188
+ self ._timeout = timeout
203
189
204
190
@staticmethod
205
191
def getsockopt (level : int , optname : int , buflen : int | None = None ) -> int :
206
192
return socket .SOCK_STREAM
207
193
208
- def do_handshake (self ) -> None :
209
- self ._did_handshake = True
210
-
211
194
def getpeername (self ) -> _RetAddress :
212
195
return self ._address
213
196
@@ -220,29 +203,6 @@ def getblocking(self) -> bool:
220
203
def getsockname (self ) -> _RetAddress :
221
204
return socket .gethostbyname (self ._address [0 ]), self ._address [1 ]
222
205
223
- def getpeercert (self , binary_form : bool = False ) -> _PeerCertRetDictType :
224
- if not (self ._host and self ._port ):
225
- self ._address = self ._host , self ._port = mocket .state .state ._address
226
-
227
- now = datetime .now ()
228
- shift = now + timedelta (days = 30 * 12 )
229
- return {
230
- "notAfter" : shift .strftime ("%b %d %H:%M:%S GMT" ),
231
- "subjectAltName" : (
232
- ("DNS" , f"*.{ self ._host } " ),
233
- ("DNS" , self ._host ),
234
- ("DNS" , "*" ),
235
- ),
236
- "subject" : (
237
- (("organizationName" , f"*.{ self ._host } " ),),
238
- (("organizationalUnitName" , "Domain Control Validated" ),),
239
- (("commonName" , f"*.{ self ._host } " ),),
240
- ),
241
- }
242
-
243
- def unwrap (self ) -> Self :
244
- return self
245
-
246
206
def write (self , data : ReadableBuffer ) -> int | None :
247
207
return self .send (encode_to_bytes (data ))
248
208
@@ -255,6 +215,7 @@ def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore:
255
215
self ._bufsize = bufsize
256
216
return self .io
257
217
218
+ # TODO
258
219
def get_entry (self , data ):
259
220
return mocket .state .state .get_entry (self ._host , self ._port , data )
260
221
@@ -274,14 +235,6 @@ def sendall(self, data, entry=None, *args, **kwargs):
274
235
self .io .truncate ()
275
236
self .io .seek (0 )
276
237
277
- def read (self , buffersize : int | None = None ) -> bytes :
278
- rv = self .io .read (buffersize )
279
- if rv :
280
- self ._sent_non_empty_bytes = True
281
- if self ._did_handshake and not self ._sent_non_empty_bytes :
282
- raise ssl .SSLWantReadError ("The operation did not complete (read)" )
283
- return rv
284
-
285
238
def recv_into (
286
239
self ,
287
240
buffer : WriteableBuffer ,
@@ -309,6 +262,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes:
309
262
exc .args = (0 ,)
310
263
raise exc
311
264
265
+ # TODO
312
266
def true_sendall (self , data : ReadableBuffer , * args : Any , ** kwargs : Any ) -> int :
313
267
if not MocketMode ().is_allowed ((self ._host , self ._port )):
314
268
MocketMode .raise_not_allowed ()
@@ -360,23 +314,17 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
360
314
host , port = self ._host , self ._port
361
315
host = true_gethostbyname (host )
362
316
363
- if isinstance (self .true_socket , true_socket ) and self ._secure_socket :
364
- self .true_socket = true_urllib3_ssl_wrap_socket (
365
- self .true_socket ,
366
- ** self .kwargs ,
367
- )
368
-
369
317
with contextlib .suppress (OSError , ValueError ):
370
318
# already connected
371
- self .true_socket .connect ((host , port ))
372
- self .true_socket .sendall (data , * args , ** kwargs )
319
+ self ._true_socket .connect ((host , port ))
320
+ self ._true_socket .sendall (data , * args , ** kwargs )
373
321
encoded_response = b""
374
322
# https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
375
323
while True :
376
- more_to_read = select .select ([self .true_socket ], [], [], 0.1 )[0 ]
324
+ more_to_read = select .select ([self ._true_socket ], [], [], 0.1 )[0 ]
377
325
if not more_to_read and encoded_response :
378
326
break
379
- new_content = self .true_socket .recv (self ._buflen )
327
+ new_content = self ._true_socket .recv (self ._buflen )
380
328
if not new_content :
381
329
break
382
330
encoded_response += new_content
@@ -415,10 +363,9 @@ def send(
415
363
return len (data )
416
364
417
365
def close (self ) -> None :
418
- # TODO might be better to use self.true_socket.fileno() instead of internal api.
419
- if self .true_socket and not self .true_socket ._closed :
420
- self .true_socket .close ()
421
- self ._fd = None
366
+ # TODO might be better to use self._true_socket.fileno() instead of internal api.
367
+ if self ._true_socket and not self ._true_socket ._closed :
368
+ self ._true_socket .close ()
422
369
423
370
def __getattr__ (self , name : str ) -> Any :
424
371
"""Do nothing catchall function, for methods like shutdown()"""
0 commit comments