66
77from ...logger import logger
88from ..ipc import ZeroMqQueue
9- from .rpc_common import (RPCCancelled , RPCRequest , RPCResponse ,
9+ from .rpc_common import (RPCCancelled , RPCParams , RPCRequest , RPCResponse ,
1010 RPCStreamingError , RPCTimeout )
1111
1212
@@ -164,33 +164,32 @@ def _start_response_reader_lazily(self):
164164 # Store the concurrent.futures.Future
165165 self ._reader_task = future
166166
167- async def _call_async (self , __rpc_method_name , * args , ** kwargs ):
167+ async def _call_async (self , method_name , * args , ** kwargs ):
168168 """Async version of RPC call.
169169 Args:
170- __rpc_method_name : Method name to call
170+ method_name : Method name to call
171171 *args: Positional arguments
172172 **kwargs: Keyword arguments
173- __rpc_timeout: The timeout (seconds) for the RPC call.
174- __rpc_need_response: Whether the RPC call needs a response.
175- If set to False, the remote call will return immediately.
173+ __rpc_params: RPCParams object containing RPC parameters.
176174
177175 Returns:
178176 The result of the remote method call
179177 """
180178 logger .debug (
181- f"RPC client calling method: { __rpc_method_name } with args: { args } and kwargs: { kwargs } "
179+ f"RPC client calling method: { method_name } with args: { args } and kwargs: { kwargs } "
182180 )
183181 if self ._server_stopped :
184182 raise RPCCancelled ("Server is shutting down, request cancelled" )
185183
186184 self ._start_response_reader_lazily ()
187- need_response = kwargs .pop ("__rpc_need_response" , True )
188- timeout = kwargs .pop ("__rpc_timeout" , self ._timeout )
185+ rpc_params = kwargs .pop ("__rpc_params" , RPCParams ())
186+ need_response = rpc_params .need_response
187+ timeout = rpc_params .timeout if rpc_params .timeout is not None else self ._timeout
189188
190189 request_id = uuid .uuid4 ().hex
191190 logger .debug (f"RPC client sending request: { request_id } " )
192191 request = RPCRequest (request_id ,
193- __rpc_method_name ,
192+ method_name ,
194193 args ,
195194 kwargs ,
196195 need_response ,
@@ -216,7 +215,7 @@ async def _call_async(self, __rpc_method_name, *args, **kwargs):
216215 raise
217216 except asyncio .TimeoutError :
218217 raise RPCTimeout (
219- f"Request '{ __rpc_method_name } ' timed out after { timeout } s" )
218+ f"Request '{ method_name } ' timed out after { timeout } s" )
220219 except Exception as e :
221220 raise e
222221 finally :
@@ -241,11 +240,11 @@ def run_loop():
241240 import time
242241 time .sleep (0.1 )
243242
244- def _call_sync (self , __rpc_method_name , * args , ** kwargs ):
243+ def _call_sync (self , method_name , * args , ** kwargs ):
245244 """Synchronous version of RPC call."""
246245 self ._ensure_event_loop ()
247246 future = asyncio .run_coroutine_threadsafe (
248- self ._call_async (__rpc_method_name , * args , ** kwargs ), self ._loop )
247+ self ._call_async (method_name , * args , ** kwargs ), self ._loop )
249248 return future .result ()
250249
251250 def call_async (self , name : str , * args , ** kwargs ):
@@ -263,7 +262,9 @@ def call_async(self, name: str, *args, **kwargs):
263262 Example:
264263 result = await client.call_async('remote_method', arg1, arg2, key=value)
265264 """
266- return self ._call_async (name , * args , ** kwargs , __rpc_need_response = True )
265+ if "__rpc_params" not in kwargs :
266+ kwargs ["__rpc_params" ] = RPCParams (need_response = True )
267+ return self ._call_async (name , * args , ** kwargs )
267268
268269 def call_future (self , name : str , * args ,
269270 ** kwargs ) -> concurrent .futures .Future :
@@ -331,7 +332,8 @@ async def call_streaming(self, name: str, *args,
331332 raise RPCCancelled ("Server is shutting down, request cancelled" )
332333
333334 self ._start_response_reader_lazily ()
334- timeout = kwargs .pop ("__rpc_timeout" , self ._timeout )
335+ rpc_params = kwargs .pop ("__rpc_params" , RPCParams ())
336+ timeout = rpc_params .timeout if rpc_params .timeout is not None else self ._timeout
335337
336338 request_id = uuid .uuid4 ().hex
337339 queue = asyncio .Queue ()
@@ -379,7 +381,9 @@ async def call_streaming(self, name: str, *args,
379381 def get_server_attr (self , name : str ):
380382 """ Get the attribute of the RPC server.
381383 This is mainly used for testing. """
382- return self ._call_sync ("__rpc_get_attr" , name , __rpc_timeout = 10 )
384+ return self ._call_sync ("__rpc_get_attr" ,
385+ name ,
386+ __rpc_params = RPCParams (timeout = 10 ))
383387
384388 def __getattr__ (self , name ):
385389 """
@@ -395,7 +399,8 @@ def __init__(self, client, method_name):
395399
396400 def __call__ (self , * args , ** kwargs ):
397401 """Default synchronous call"""
398- mode = kwargs .pop ("__rpc_mode" , "sync" )
402+ rpc_params = kwargs .get ("__rpc_params" , RPCParams ())
403+ mode = rpc_params .mode
399404 if mode == "sync" :
400405 return self .client ._call_sync (self .method_name , * args ,
401406 ** kwargs )
0 commit comments