@@ -526,6 +526,7 @@ def __init__(
526526 options : Optional [dict ] = None ,
527527 _log_raw_websockets : bool = False ,
528528 retry_timeout : float = 60.0 ,
529+ max_retries : int = 5 ,
529530 ):
530531 """
531532 Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -536,6 +537,10 @@ def __init__(
536537 max_subscriptions: Maximum number of subscriptions per websocket connection
537538 max_connections: Maximum number of connections total
538539 shutdown_timer: Number of seconds to shut down websocket connection after last use
540+ options: Options to pass to the websocket connection
541+ _log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
542+ retry_timeout: Timeout in seconds to retry websocket connection
543+ max_retries: Maximum number of retries following a timeout
539544 """
540545 # TODO allow setting max concurrent connections and rpc subscriptions per connection
541546 self .ws_url = ws_url
@@ -555,6 +560,7 @@ def __init__(
555560 self ._options = options if options else {}
556561 self ._log_raw_websockets = _log_raw_websockets
557562 self ._in_use_ids = set ()
563+ self ._max_retries = max_retries
558564
559565 @property
560566 def state (self ):
@@ -575,7 +581,6 @@ async def loop_time() -> float:
575581 async def _cancel (self ):
576582 try :
577583 self ._send_recv_task .cancel ()
578- await self ._send_recv_task
579584 await self .ws .close ()
580585 except (
581586 AttributeError ,
@@ -616,19 +621,30 @@ async def _handler(self, ws: ClientConnection) -> None:
616621 )
617622 loop = asyncio .get_running_loop ()
618623 should_reconnect = False
624+ is_retry = False
619625 for task in pending :
620626 task .cancel ()
621627 for task in done :
622- if isinstance (task .result (), (asyncio .TimeoutError , ConnectionClosed )):
628+ task_res = task .result ()
629+ if isinstance (
630+ task_res , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
631+ ):
623632 should_reconnect = True
633+ if isinstance (task_res , (asyncio .TimeoutError , TimeoutError )):
634+ self ._attempts += 1
635+ is_retry = True
624636 if should_reconnect is True :
625637 for original_id , payload in list (self ._inflight .items ()):
626638 self ._received [original_id ] = loop .create_future ()
627639 to_send = json .loads (payload )
628640 await self ._sending .put (to_send )
629- logger .info ("Timeout occurred. Reconnecting." )
641+ if is_retry :
642+ # Otherwise the connection was just closed due to no activity, which should not count against retries
643+ logger .info (
644+ f"Timeout occurred. Reconnecting. Attempt { self ._attempts } of { self ._max_retries } "
645+ )
630646 await self .connect (True )
631- await self ._handler (ws = ws )
647+ await self ._handler (ws = self . ws )
632648 elif isinstance (e := recv_task .result (), Exception ):
633649 return e
634650 elif isinstance (e := send_task .result (), Exception ):
@@ -689,15 +705,22 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
689705 recd = await asyncio .wait_for (
690706 ws .recv (decode = False ), timeout = self .retry_timeout
691707 )
708+ # reset the counter once we successfully receive something back
709+ self ._attempts = 0
692710 await self ._recv (recd )
693711 except Exception as e :
694- logger .exception ("Start receiving exception" , exc_info = e )
695712 if isinstance (e , ssl .SSLError ):
696713 e = ConnectionClosed
697- for fut in self ._received .values ():
698- if not fut .done ():
699- fut .set_exception (e )
700- fut .cancel ()
714+ if not isinstance (
715+ e , (asyncio .TimeoutError , TimeoutError , ConnectionClosed )
716+ ):
717+ logger .exception ("Websocket receiving exception" , exc_info = e )
718+ for fut in self ._received .values ():
719+ if not fut .done ():
720+ fut .set_exception (e )
721+ fut .cancel ()
722+ else :
723+ logger .debug ("Timeout occurred. Reconnecting." )
701724 return e
702725
703726 async def _start_sending (self , ws ) -> Exception :
@@ -713,14 +736,21 @@ async def _start_sending(self, ws) -> Exception:
713736 raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
714737 await ws .send (to_send )
715738 except Exception as e :
716- logger .exception ("Start sending exception" , exc_info = e )
717- if to_send is not None :
718- self ._received [to_send ["id" ]].set_exception (e )
719- self ._received [to_send ["id" ]].cancel ()
739+ if isinstance (e , ssl .SSLError ):
740+ e = ConnectionClosed
741+ if not isinstance (
742+ e , (asyncio .TimeoutError , TimeoutError , ConnectionClosed )
743+ ):
744+ logger .exception ("Websocket sending exception" , exc_info = e )
745+ if to_send is not None :
746+ self ._received [to_send ["id" ]].set_exception (e )
747+ self ._received [to_send ["id" ]].cancel ()
748+ else :
749+ for i in self ._received .keys ():
750+ self ._received [i ].set_exception (e )
751+ self ._received [i ].cancel ()
720752 else :
721- for i in self ._received .keys ():
722- self ._received [i ].set_exception (e )
723- self ._received [i ].cancel ()
753+ logger .debug ("Timeout occurred. Reconnecting." )
724754 return e
725755
726756 async def send (self , payload : dict ) -> str :
@@ -784,9 +814,9 @@ async def retrieve(self, item_id: str) -> Optional[dict]:
784814 if item is not None :
785815 if item .done ():
786816 self .max_subscriptions .release ()
817+ res = item .result ()
787818 del self ._received [item_id ]
788-
789- return item .result ()
819+ return res
790820 else :
791821 try :
792822 return self ._received_subscriptions [item_id ].get_nowait ()
@@ -860,6 +890,7 @@ def __init__(
860890 },
861891 shutdown_timer = ws_shutdown_timer ,
862892 retry_timeout = self .retry_timeout ,
893+ max_retries = max_retries ,
863894 )
864895 else :
865896 self .ws = AsyncMock (spec = Websocket )
@@ -1165,7 +1196,7 @@ async def get_runtime_for_version(
11651196 async def _get_runtime_for_version (
11661197 self , runtime_version : int , block_hash : Optional [str ] = None
11671198 ) -> Runtime :
1168- runtime_config = RuntimeConfigurationObject ()
1199+ runtime_config = RuntimeConfigurationObject (ss58_format = self . ss58_format )
11691200 runtime_config .clear_type_registry ()
11701201 runtime_config .update_type_registry (load_type_registry_preset (name = "core" ))
11711202
@@ -2337,7 +2368,7 @@ async def _make_rpc_request(
23372368 request_manager .add_request (item_id , payload ["id" ])
23382369
23392370 while True :
2340- for item_id in list ( request_manager .response_map . keys () ):
2371+ for item_id in request_manager .unresponded ( ):
23412372 if (
23422373 item_id not in request_manager .responses
23432374 or asyncio .iscoroutinefunction (result_handler )
@@ -2368,7 +2399,6 @@ async def _make_rpc_request(
23682399 runtime = runtime ,
23692400 force_legacy_decode = force_legacy_decode ,
23702401 )
2371-
23722402 request_manager .add_response (
23732403 item_id , decoded_response , complete
23742404 )
0 commit comments