Skip to content

Commit

Permalink
vali updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed May 24, 2024
1 parent 6c5631d commit e898bba
Show file tree
Hide file tree
Showing 13 changed files with 1,014 additions and 1,258 deletions.
77 changes: 42 additions & 35 deletions commune/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def prepare_request(self, args: list = None, kwargs: dict = None, params=None, m

if isinstance(args, dict):
kwargs = args
args = None
argsf = None

if params != None:
assert type(params) in [list, dict], f'params must be a list or dict, not {type(params)}'
Expand Down Expand Up @@ -86,15 +86,14 @@ def prepare_request(self, args: list = None, kwargs: dict = None, params=None, m
return request

def iter_over_async(self, ait):
loop = asyncio.get_event_loop()
# helper async fn that just gets the next element
# from the async iterator
def get_next():
try:
obj = loop.run_until_complete(ait.__anext__())
obj = self.loop.run_until_complete(ait.__anext__())
return obj
except StopAsyncIteration:
loop.run_until_complete(self.session.close())
self.loop.run_until_complete(self.session.close())
return 'done'
# actual sync iterator (implemented using a generator)
while True:
Expand Down Expand Up @@ -194,10 +193,10 @@ def prepare_url(self, address, fn):
return url


async def async_forward(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.loop.run_until_complete(self.aysnc_forward(*args, **kwargs))

def forward(self,
async def async_forward(self,
fn: str,
args: list = None,
kwargs: dict = None,
Expand All @@ -211,29 +210,35 @@ def forward(self,
stream = False,
**extra_kwargs
):
key = self.resolve_key(key)
url = self.prepare_url(address, fn)
# resolve the kwargs at least
kwargs =kwargs or {}
kwargs.update(extra_kwargs)
timestamp = c.time()
request = self.prepare_request(args=args, kwargs=kwargs, params=params, message_type=message_type)
future = asyncio.wait_for(self.send_request(url=url, request=request, headers=headers, verbose=verbose, stream=stream), timeout=timeout)
result = asyncio.run(future)

if type(result) in [str, dict, int, float, list, tuple]:
result = self.serializer.deserialize(result)
if isinstance(result, dict) and 'data' in result:
result = result['data']
latency = c.time() - timestamp
if self.save_history:
output = { 'input': request, 'output': result, 'latency': latency}
path = self.history_path+ '/' + self.key.ss58_address + '/' + self.address+ '/'+ str(timestamp)
self.put(path, output)
else:
result = self.iter_over_async(result)

try:
key = self.resolve_key(key)
url = self.prepare_url(address, fn)
# resolve the kwargs at least
kwargs =kwargs or {}
kwargs.update(extra_kwargs)
timestamp = c.time()
request = self.prepare_request(args=args, kwargs=kwargs, params=params, message_type=message_type)
future = await self.send_request(url=url, request=request, headers=headers, verbose=verbose, stream=stream)

if type(result) in [str, dict, int, float, list, tuple]:
result = self.serializer.deserialize(result)
if isinstance(result, dict) and 'data' in result:
result = result['data']
latency = c.time() - timestamp
if self.save_history:
output = { 'input': request, 'output': result, 'latency': latency}
path = self.history_path+ '/' + self.key.ss58_address + '/' + self.address+ '/'+ str(timestamp)
self.put(path, output)
else:
result = self.iter_over_async(result)

except Exception as e:
result = c.detailed_error(e)
return result


def __del__(self):
self.loop.run_until_complete(self.session.close())


def age(self):
Expand Down Expand Up @@ -270,7 +275,8 @@ def history(cls, key=None, history_path='history'):


@classmethod
def call(cls, module : str,
def call(cls,
module : str,
fn:str = None,
*args,
kwargs = None,
Expand All @@ -282,10 +288,11 @@ def call(cls, module : str,
timeout=40,
**extra_kwargs) -> None:

# if '
if '//' in module:
module = module.split('//')[-1]
mode = module.split('//')[0]
if '/' in module:
# adjust the split
if fn != None:
args = [fn] + list(args)
module , fn = module.split('/')
Expand All @@ -296,16 +303,16 @@ def call(cls, module : str,
virtual=False,
key=key)

# if isinstance(kwargs, str):
# kwargs = c.str2dict(kwargs)
if params != None:
kwargs = params

if kwargs == None:
kwargs = {}

kwargs.update(extra_kwargs)

return module.forward(fn=fn, args=args, kwargs=kwargs, stream=stream, timeout=timeout)



@classmethod
def call_search(cls,
search : str,
Expand Down
Loading

0 comments on commit e898bba

Please sign in to comment.