11import asyncio
22import json
33import logging
4+ from contextlib import suppress
45from ssl import SSLContext
56from typing import Any , AsyncGenerator , Dict , Optional , Tuple , Union , cast
67
@@ -94,6 +95,7 @@ def __init__(
9495 connect_timeout : int = 10 ,
9596 close_timeout : int = 10 ,
9697 ack_timeout : int = 10 ,
98+ keep_alive_timeout : Optional [int ] = None ,
9799 connect_args : Dict [str , Any ] = {},
98100 ) -> None :
99101 """Initialize the transport with the given parameters.
@@ -107,6 +109,8 @@ def __init__(
107109 :param close_timeout: Timeout in seconds for the close.
108110 :param ack_timeout: Timeout in seconds to wait for the connection_ack message
109111 from the server.
112+ :param keep_alive_timeout: Optional Timeout in seconds to receive
113+ a sign of liveness from the server.
110114 :param connect_args: Other parameters forwarded to websockets.connect
111115 """
112116 self .url : str = url
@@ -117,6 +121,7 @@ def __init__(
117121 self .connect_timeout : int = connect_timeout
118122 self .close_timeout : int = close_timeout
119123 self .ack_timeout : int = ack_timeout
124+ self .keep_alive_timeout : Optional [int ] = keep_alive_timeout
120125
121126 self .connect_args = connect_args
122127
@@ -125,6 +130,7 @@ def __init__(
125130 self .listeners : Dict [int , ListenerQueue ] = {}
126131
127132 self .receive_data_task : Optional [asyncio .Future ] = None
133+ self .check_keep_alive_task : Optional [asyncio .Future ] = None
128134 self .close_task : Optional [asyncio .Future ] = None
129135
130136 # We need to set an event loop here if there is none
@@ -141,6 +147,10 @@ def __init__(
141147 self ._no_more_listeners : asyncio .Event = asyncio .Event ()
142148 self ._no_more_listeners .set ()
143149
150+ if self .keep_alive_timeout is not None :
151+ self ._next_keep_alive_message : asyncio .Event = asyncio .Event ()
152+ self ._next_keep_alive_message .set ()
153+
144154 self ._connecting : bool = False
145155
146156 self .close_exception : Optional [Exception ] = None
@@ -315,8 +325,9 @@ def _parse_answer(
315325 )
316326
317327 elif answer_type == "ka" :
318- # KeepAlive message
319- pass
328+ # Keep-alive message
329+ if self .check_keep_alive_task is not None :
330+ self ._next_keep_alive_message .set ()
320331 elif answer_type == "connection_ack" :
321332 pass
322333 elif answer_type == "connection_error" :
@@ -332,8 +343,41 @@ def _parse_answer(
332343
333344 return answer_type , answer_id , execution_result
334345
335- async def _receive_data_loop (self ) -> None :
346+ async def _check_ws_liveness (self ) -> None :
347+ """Coroutine which will periodically check the liveness of the connection
348+ through keep-alive messages
349+ """
350+
351+ try :
352+ while True :
353+ await asyncio .wait_for (
354+ self ._next_keep_alive_message .wait (), self .keep_alive_timeout
355+ )
336356
357+ # Reset for the next iteration
358+ self ._next_keep_alive_message .clear ()
359+
360+ except asyncio .TimeoutError :
361+ # No keep-alive message in the appriopriate interval, close with error
362+ # while trying to notify the server of a proper close (in case
363+ # the keep-alive interval of the client or server was not aligned
364+ # the connection still remains)
365+
366+ # If the timeout happens during a close already in progress, do nothing
367+ if self .close_task is None :
368+ await self ._fail (
369+ TransportServerError (
370+ "No keep-alive message has been received within "
371+ "the expected interval ('keep_alive_timeout' parameter)"
372+ ),
373+ clean_close = False ,
374+ )
375+
376+ except asyncio .CancelledError :
377+ # The client is probably closing, handle it properly
378+ pass
379+
380+ async def _receive_data_loop (self ) -> None :
337381 try :
338382 while True :
339383
@@ -549,6 +593,13 @@ async def connect(self) -> None:
549593 await self ._fail (e , clean_close = False )
550594 raise e
551595
596+ # If specified, create a task to check liveness of the connection
597+ # through keep-alive messages
598+ if self .keep_alive_timeout is not None :
599+ self .check_keep_alive_task = asyncio .ensure_future (
600+ self ._check_ws_liveness ()
601+ )
602+
552603 # Create a task to listen to the incoming websocket messages
553604 self .receive_data_task = asyncio .ensure_future (self ._receive_data_loop ())
554605
@@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
597648 # We should always have an active websocket connection here
598649 assert self .websocket is not None
599650
651+ # Properly shut down liveness checker if enabled
652+ if self .check_keep_alive_task is not None :
653+ # More info: https://stackoverflow.com/a/43810272/1113207
654+ self .check_keep_alive_task .cancel ()
655+ with suppress (asyncio .CancelledError ):
656+ await self .check_keep_alive_task
657+
600658 # Saving exception to raise it later if trying to use the transport
601659 # after it has already closed.
602660 self .close_exception = e
@@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
629687
630688 self .websocket = None
631689 self .close_task = None
690+ self .check_keep_alive_task = None
632691
633692 self ._wait_closed .set ()
634693
0 commit comments