Skip to content

Commit

Permalink
Update kernel_interface.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 29, 2022
1 parent 2c2317d commit a5fc60b
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __setitem__(self, index, func):
# append slice/index to func key list
self.container[func] = [*self.container.get(func, []), index]

def __mesh_call__(self, array, *args, **kwargs):
def _wrap_mesh(self, array, *args, **kwargs):
# TODO : run once resolve_kernel_size/resolve_strides

self.shape = array.shape
Expand All @@ -67,10 +67,9 @@ def __mesh_call__(self, array, *args, **kwargs):
for (func, index) in self.container.items():

if func is not None and self.named_axis is not None:
resolved_func = named_axis_wrapper(self.kernel_size, self.named_axis)(
func
)
self.resolved_container[resolved_func] = index
self.resolved_container[
named_axis_wrapper(self.kernel_size, self.named_axis)(func)
] = index

else:
self.resolved_container[func] = index
Expand All @@ -84,7 +83,7 @@ def __mesh_call__(self, array, *args, **kwargs):
self.relative,
)(array, *args, **kwargs)

def __decorator_call__(self, func):
def _wrap_decorator(self, func):
def call(array, *args, **kwargs):

# TODO : run once resolve_kernel_size/resolve_strides
Expand Down Expand Up @@ -112,16 +111,15 @@ def call(array, *args, **kwargs):
def __call__(self, *args, **kwargs):

if len(args) == 1 and callable(args[0]) and len(kwargs) == 0:
return functools.wraps(args[0])(self.__decorator_call__(args[0]))
return functools.wraps(args[0])(self._wrap_decorator(args[0]))

elif len(args) > 0 and isinstance(args[0], jnp.ndarray):
return self.__mesh_call__(*args, **kwargs)
return self._wrap_mesh(*args, **kwargs)

else:
raise ValueError(
(
"Expected `jnp.ndarray` or `Callable` for the first argument.",
f" Found {tuple(*args,**kwargs)}",
f"Expected `jnp.ndarray` or `Callable` for the first argument. Found {tuple(*args,**kwargs)}"
)
)

Expand Down

0 comments on commit a5fc60b

Please sign in to comment.