Skip to content

Commit

Permalink
refactor and test
Browse files Browse the repository at this point in the history
  • Loading branch information
latentvector committed Jun 8, 2024
1 parent e01253c commit bd02930
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 302 deletions.
48 changes: 23 additions & 25 deletions commune/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,44 @@ def __init__(self,
args = None,
module = 'module',
verbose = True,
history_module = 'history',
path = 'history',
save: bool = True):
self.verbose = verbose
self.save = save
self.history_module = c.module(history_module)(folder_path=self.resolve_path(path))
self.base_module = c.module(module)
self.base_module_attributes = list(set(self.base_module.functions() + self.base_module.get_attributes()))
args = args or self.argv()
self.input_str = 'c ' + ' '.join(args)
output = self.get_output(args)
self.process_output(output)

def process_output(self, output):
if c.is_generator(output):
for output_item in output:
if isinstance(c, Munch):
output_item = output_item.toDict()
c.print(output_item, verbose=verbose)
c.print(output_item, verbose=self.verbose)
else:
if isinstance(output, Munch):
output = output.toDict()
c.print(output, verbose=verbose)
c.print(output, verbose=self.verbose)

if save and c.jsonable(output):
self.history_module().add({'input': 'c ' + ' '.join(args), 'output': output})
if self.save and c.jsonable(output):
self.history_module.add({'input': self.input_str, 'output': output})
return output

def get_output(self, args):

args, kwargs = self.parse_args(args)


base_module_attributes = list(set(self.base_module.functions() + self.base_module.get_attributes()))
# is it a fucntion, assume it is for the module
# handle module/function
is_fn = args[0] in base_module_attributes
def get_output(self, args):


is_fn = args[0] in self.base_module_attributes
if '/' in args[0]:
args = args[0].split('/') + args[1:]
is_fn = False

if is_fn:
# is a function
module = self.base_module
Expand All @@ -54,21 +59,19 @@ def get_output(self, args):
if isinstance(module, str):
module = c.module(module)
fn = args.pop(0)


if module.classify_fn(fn) == 'self':
module = module()

module = module()
fn_obj = getattr(module, fn)

args, kwargs = self.parse_args(args)


if callable(fn_obj):
output = fn_obj(*args, **kwargs)
elif c.is_property(fn_obj):
output = getattr(module(), fn)
else:
output = fn_obj
if callable(fn):
output = fn(*args, **kwargs)

return output

Expand All @@ -78,7 +81,6 @@ def get_output(self, args):
def parse_args(cls, argv = None):
if argv is None:
argv = cls.argv()

args = []
kwargs = {}
parsing_kwargs = False
Expand All @@ -89,13 +91,13 @@ def parse_args(cls, argv = None):
# args.append(cls.determine_type(arg))
if '=' in arg:
parsing_kwargs = True
key, value = arg.split('=', 1)
key, value = arg.split('=')
# use determine_type to convert the value to its actual type
kwargs[key] = cls.determine_type(value)

else:
assert parsing_kwargs is False, 'Cannot mix positional and keyword arguments'
args.append(cls.determine_type(arg))

return args, kwargs

@classmethod
Expand Down Expand Up @@ -144,10 +146,6 @@ def determine_type(cls, x):
return x


@classmethod
def history_module(cls, path='history'):
return c.m('history')(folder_path=cls.resolve_path(path))

@classmethod
def history(cls,**kwargs):
history = cls.history_module().history(**kwargs)
Expand Down
109 changes: 37 additions & 72 deletions commune/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,7 @@ def rm(cls, path, extension=None, mode = 'json'):
return {'success':False, 'message':f'{path} does not exist'}
if os.path.isdir(path):
c.rmdir(path)
else:
if os.path.isfile(path):
os.remove(path)
assert not os.path.exists(path), f'{path} was not removed'

Expand Down Expand Up @@ -3001,8 +3001,9 @@ def memory_usage(fmt='gb'):
return (process.memory_info().rss // 1024) / scale

@classmethod
def argparse(cls, verbose: bool = False, version=1):
if version == 1:
def argparse(cls, verbose: bool = False, **kwargs):
argv = ' '.join(c.argv())
if ' --' in argv or ' -' in argv:
parser = argparse.ArgumentParser(description='Argparse for the module')
parser.add_argument('-fn', '--fn', dest='function', help='The function of the key', type=str, default="__init__")
parser.add_argument('-kwargs', '--kwargs', dest='kwargs', help='key word arguments to the function', type=str, default="{}")
Expand All @@ -3020,13 +3021,17 @@ def argparse(cls, verbose: bool = False, version=1):
if len(args.params) > len(args.kwargs):
args.kwargs = args.params
args.args = json.loads(args.args.replace("'",'"'))
elif version == 2:
args = c.parseargs()

else:
args = c.parse_args()
return args

@classmethod
def run(cls, name:str = None, verbose:bool = False, version=1) -> Any:
def parse_args(cls, argv = None, **kwargs):
return c.module('cli').parse_args(argv=argv)

@classmethod
def run(cls, name:str = None,
version=1) -> Any:
is_main = name == '__main__' or name == None or name == cls.__name__
if not is_main:
return {'success':False, 'message':f'Not main module {name}'}
Expand Down Expand Up @@ -3605,11 +3610,13 @@ def test_fns(cls, *args, **kwargs):
@classmethod
def test(cls,
module=None,
timeout=60,
timeout=70,
trials=3,
parallel=True,
):
module = module or cls.module_path()
if module == 'module':
return c.cmd('pytest commune', verbose=True)
if c.module_exists(module + '.test'):
c.print('FOUND TEST MODULE', color='yellow')
module = module + '.test'
Expand Down Expand Up @@ -4700,74 +4707,32 @@ def find_lines(self, text:str, search:str) -> List[str]:
@classmethod
def new_module( cls,
module : str ,
repo : str = None,
base_module : str = 'demo',
tree : bool = 'commune',
overwrite : bool = True,
**kwargs):
base_module : str = 'demo',
folder_module : bool = False,
update=1
):

""" Makes directories for path.
"""
if module == None:
assert repo != None, 'repo must be specified if module is not specified'
module = os.path.basename(repo).replace('.git','').replace(' ','_').replace('-','_').lower()
tree_path = c.tree2path().get(tree)

class_name = ''
for m in module.split('.'):
class_name += m[0].upper() + m[1:] # capitalize first letter

if c.module_exists(module):
if overwrite:
module_path = c.module(module).dirpath() if c.is_file_module(module) else c.module(module).filepath()
c.rm(module_path)
else:
return {'success': False,
'path': module_path,
'msg': f' module {module} already exists, set overwrite=True to overwrite'}

# get the code ready from the base module
c.print(f'Getting {base_module}')
base_module = c.module(base_module)
is_folder_module = base_module.is_folder_module()

base_module_class = base_module.class_name()
module_class_name = ''.join([m[0].upper() + m[1:] for m in module.split('.')])

# build the path2text dictionary
if is_folder_module:
dirpath = tree_path + '/'+ module.replace('.','/') + '/'
base_dirpath = base_module.dirpath()
path2text = c.path2text( base_module.dirpath())
path2text = {k.replace(base_dirpath +'/',dirpath ):v for k,v in path2text.items()}
else:
module_path = tree_path + '/'+ module.replace('.','/') + '.py'
code = base_module.code()
path2text = {module_path: code}

og_path2text = c.copy(path2text)
for path, text in og_path2text.items():
file_type = path.split('.')[-1]
is_module_python_file = (file_type == 'py' and 'class ' + base_module_class in text)

if is_folder_module:
if file_type == 'yaml' or is_module_python_file:
path_filename = path.split('/')[-1]
new_filename = module.replace('.', '_') + '.'+ file_type
path = path[:-len(path_filename)] + new_filename


if is_module_python_file:
text = text.replace(base_module_class, module_class_name)

path2text[path] = text
c.put_text(path, text)
c.print(f'Created {path} :: {module}')

assert c.module_exists(module), f'Failed to create module {module}'
module_class_name = ''.join([m[0].capitalize() + m[1:] for m in module.split('.')])
base_module_class_name = base_module.class_name()
base_module_code = base_module.code().replace(base_module_class_name, module_class_name)
pwd = c.pwd()
path = os.path.join(pwd, module.replace('.', '/'))
if folder_module:
dirpath = path
filename = module.replace('.', '_')
path = os.path.join(path, filename)

path = path + '.py'
dirpath = os.path.dirname(path)
if os.path.exists(path) and not update:
return {'success': True, 'msg': f'Module {module} already exists', 'path': path}
if not os.path.exists(dirpath):
os.makedirs(dirpath, exist_ok=True)

return {'success': True, 'msg': f'Created module {module}', 'path': path, 'paths': list(c.path2text(c.module(module).dirpath()).keys())}
c.put_text(path, base_module_code)

return {'success': True, 'msg': f'Created module {module}', 'path': path}

add_module = new_module

Expand Down
38 changes: 21 additions & 17 deletions commune/server/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def __init__(self,
refresh: bool = False,
stake_from_weight = 1.0, # the weight of the staker
max_age = 30, # max age of the state in seconds
sync_interval: int = 60, # 1000 seconds per sync with the network
max_staleness: int = 60, # 1000 seconds per sync with the network

**kwargs):

self.set_config(locals())
self.user_module = c.module("user")()
self.state_path = state_path
self.state_path = self.resolve_path(state_path)
if refresh:
self.rm_state()
self.last_time_synced = c.time()
Expand All @@ -35,7 +35,6 @@ def __init__(self,
'fn_info': {}}

self.set_module(module)

c.thread(self.run_loop)

def set_module(self, module):
Expand All @@ -52,30 +51,35 @@ def run_loop(self):
except Exception as e:
r = c.detailed_error(e)
c.print(r)
c.sleep(self.config.sync_interval)
c.sleep(self.config.max_staleness)


def sync_network(self, update=False, max_age=None):
state = self.get(self.state_path, {}, max_age=self.config.sync_interval)
time_since_sync = c.time() - state.get('sync_time', 0)
def sync_network(self, update=False, max_age=None, netuid=None, network=None):
state = self.get(self.state_path, {}, max_age=self.config.max_staleness)
netuid = netuid or self.config.netuid
network = network or self.config.network
staleness = c.time() - state.get('sync_time', 0)
self.key2address = c.key2address()
self.address2key = c.address2key()
response = {'msg': f'synced {self.state_path}',
'until_sync': int(self.config.sync_interval - time_since_sync),
'time_since_sync': int(time_since_sync)}
response = {
'path': self.state_path,
'max_staleness': self.config.max_staleness,
'network': network,
'netuid': netuid,
'staleness': int(staleness),
'datetime': c.datetime()}

if time_since_sync < self.config.sync_interval:
if staleness < self.config.max_staleness:
response['msg'] = 'synced too earlly'
return response

self.subspace = c.module('subspace')(network=self.config.network)
else:
response['msg'] = 'Synced with the network'
response['staleness'] = 0
self.subspace = c.module('subspace')(network=network)
max_age = max_age or self.config.max_age
state['stakes'] = self.subspace.stakes(fmt='j', netuid=self.config.netuid, update=update, max_age=max_age)
state['stakes'] = self.subspace.stakes(fmt='j', netuid=netuid, update=update, max_age=max_age)
self.state = state
self.put(self.state_path, self.state)
c.print(f'🔄 Synced {self.state_path} at {c.datetime()} 🔄\033', color='yellow')


return response

def forward(self, fn: str = 'info' , input:dict = None, address=None) -> dict:
Expand Down
8 changes: 4 additions & 4 deletions commune/server/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ def test_basics(cls) -> dict:


@classmethod
def test_serving(cls):
server_name = 'module::test'
def test_serving(cls, server_name = 'module::test'):
if server_name in c.servers():
c.kill(server_name)
module = c.serve(server_name)
c.wait_for_server(server_name)
module = c.connect(server_name)

module.put("hey",1)
r = module.put("hey",1)
v = module.get("hey")
assert v == 1, f"get failed {v}"
c.kill(server_name)
Expand Down
Loading

0 comments on commit bd02930

Please sign in to comment.