4
4
import logging
5
5
import socket
6
6
import select
7
+ import struct
7
8
import threading
8
9
import subprocess
9
- import typing
10
+ from typing import Callable , ContextManager
10
11
11
12
from tests .containers .cancellation_token import CancellationToken
12
13
@@ -50,7 +51,7 @@ def stop(self):
50
51
class SocketProxy :
51
52
def __init__ (
52
53
self ,
53
- remote_socket_factory : typing . ContextManager [socket .socket ],
54
+ remote_socket_factory : Callable [..., ContextManager [socket .socket ] ],
54
55
local_host : str = "localhost" ,
55
56
local_port : int = 0 ,
56
57
buffer_size : int = 4096
@@ -81,9 +82,14 @@ def listen_and_serve_until_canceled(self):
81
82
Handles at most one client at a time. """
82
83
try :
83
84
while not self .cancellation_token .cancelled :
84
- client_socket , addr = self .server_socket .accept ()
85
- logging .info (f"Accepted connection from { addr [0 ]} :{ addr [1 ]} " )
86
- self ._handle_client (client_socket )
85
+ readable , _ , _ = select .select ([self .server_socket , self .cancellation_token ], [], [])
86
+
87
+ # ISSUE-922: socket.accept() blocks, so if cancel() did not come very fast, we'd loop over and block
88
+ if self .server_socket in readable :
89
+ client_socket , addr = self .server_socket .accept ()
90
+ logging .info (f"Accepted connection from { addr [0 ]} :{ addr [1 ]} " )
91
+ # handle client synchronously, which means that there can be at most one at a time
92
+ self ._handle_client (client_socket )
87
93
except Exception as e :
88
94
logging .exception (f"Proxying failed to listen" , exc_info = e )
89
95
raise
@@ -96,27 +102,39 @@ def get_actual_port(self) -> int:
96
102
return self .server_socket .getsockname ()[1 ]
97
103
98
104
def _handle_client (self , client_socket ):
99
- with client_socket as _ , self .remote_socket_factory as remote_socket :
100
- while True :
105
+ with client_socket as _ , self .remote_socket_factory () as remote_socket :
106
+ while not self . cancellation_token . cancelled :
101
107
readable , _ , _ = select .select ([client_socket , remote_socket , self .cancellation_token ], [], [])
102
108
103
- if self .cancellation_token .cancelled :
104
- break
105
-
106
109
if client_socket in readable :
107
110
data = client_socket .recv (self .buffer_size )
108
111
if not data :
109
112
break
110
113
remote_socket .send (data )
111
114
112
115
if remote_socket in readable :
113
- data = remote_socket .recv (self .buffer_size )
116
+ try :
117
+ data = remote_socket .recv (self .buffer_size )
118
+ except ConnectionResetError :
119
+ # ISSUE-922: it seems best to propagate the error and let the client retry
120
+ # alternatively it would be necessary to resend anything already received from client_socket
121
+ logging .info (f"Reading from remote socket failed, client { client_socket .getpeername ()} has been disconnected" )
122
+ _rst_socket (client_socket )
123
+ break
114
124
if not data :
115
125
break
116
126
client_socket .send (data )
117
127
118
128
119
- if __name__ == "__main__" :
129
+ def _rst_socket (s : socket ):
130
+ """Closing a SO_LINGER socket will RST it
131
+ https://stackoverflow.com/questions/46264404/how-can-i-reset-a-tcp-socket-in-python
132
+ """
133
+ s .setsockopt (socket .SOL_SOCKET , socket .SO_LINGER , struct .pack ('ii' , 1 , 0 ))
134
+ s .close ()
135
+
136
+
137
+ def main () -> None :
120
138
"""Sample application to show how this can work."""
121
139
122
140
@@ -161,13 +179,21 @@ def get_actual_port(self):
161
179
server .join ()
162
180
163
181
164
- proxy = SocketProxy (remote_socket_factory () , "localhost" , 0 )
182
+ proxy = SocketProxy (remote_socket_factory , "localhost" , 0 )
165
183
thread = threading .Thread (target = proxy .listen_and_serve_until_canceled )
166
184
thread .start ()
167
185
168
- client_socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
169
- client_socket .connect (("localhost" , proxy .get_actual_port ()))
186
+ for _ in range (2 ):
187
+ client_socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
188
+ client_socket .connect (("localhost" , proxy .get_actual_port ()))
170
189
171
- print (client_socket .recv (1024 )) # prints Hello World
190
+ print (client_socket .recv (1024 )) # prints Hello World
191
+ print (client_socket .recv (1024 )) # prints nothing
192
+ client_socket .close ()
193
+ proxy .cancellation_token .cancel ()
172
194
173
195
thread .join ()
196
+
197
+
198
+ if __name__ == "__main__" :
199
+ main ()
0 commit comments