From 8b12e2a8b94465c4c242bb3a66e84063f11d71f3 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 12 Sep 2022 00:37:21 +0900 Subject: [PATCH] comments and minor edits --- kernex/__init__.py | 2 +- kernex/_src/base.py | 16 +++++- kernex/_src/map.py | 45 ++++++++++++++--- kernex/_src/scan.py | 6 ++- kernex/_src/utils.py | 8 ++- kernex/interface/named_axis.py | 5 ++ kernex/interface/resolve_utils.py | 84 +++++++++++++++---------------- 7 files changed, 112 insertions(+), 54 deletions(-) diff --git a/kernex/__init__.py b/kernex/__init__.py index 3b42cfa..236a5c2 100644 --- a/kernex/__init__.py +++ b/kernex/__init__.py @@ -17,4 +17,4 @@ "offsetKernelScan", ) -__version__ = "0.0.8" +__version__ = "0.1.0" diff --git a/kernex/_src/base.py b/kernex/_src/base.py index a06fd2a..1b94f31 100644 --- a/kernex/_src/base.py +++ b/kernex/_src/base.py @@ -35,6 +35,10 @@ def pad_width(self): Returns: padding value passed to `pad_width` in `jnp.pad` """ + # this function is cached because it is called multiple times + # and it is expensive to calculate + # if the border is negative, the padding is 0 + # if the border is positive, the padding is the border value return tuple([0, max(0, pi[0]) + max(0, pi[1])] for pi in self.border) @cached_property @@ -45,6 +49,10 @@ def output_shape(self) -> tuple[int, ...]: Returns: tuple[int, ...]: resulting shape of the kernel operation """ + # this function is cached because it is called multiple times + # and it is expensive to calculate + # the output shape is the shape of the array after the kernel operation + # is applied to the input array return tuple( (xi + (li + ri) - ki) // si + 1 for xi, ki, si, (li, ri) in ZIP( @@ -55,13 +63,16 @@ def output_shape(self) -> tuple[int, ...]: @cached_property def views(self) -> tuple[jnp.ndarray, ...]: """Generate absolute sampling matrix""" + # this function is cached because it is called multiple times + # and it is expensive to calculate + # the view is the indices of the array that is used to calculate + # the output value dim_range = tuple( general_arange(di, ki, si, x0, xf) for (di, ki, si, (x0, xf)) in zip( self.shape, self.kernel_size, self.strides, self.border ) ) - matrix = general_product(*dim_range) return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, self.kernel_size)) @@ -86,6 +97,8 @@ def funcs(self) -> tuple[Callable[[Any], jnp.ndarray]]: @property def slices(self): + # this function returns a tuple of slices + # the slices are used to slice the array return tuple(self.func_index_map.values()) def index_from_view(self, view: tuple[jnp.ndarray, ...]) -> tuple[int, ...]: @@ -97,6 +110,7 @@ def index_from_view(self, view: tuple[jnp.ndarray, ...]) -> tuple[int, ...]: Returns: tuple[int, ...]: index as a tuple of int for each dimension """ + # this function returns a tuple of int return tuple( view[i][wi // 2] if wi % 2 == 1 else view[i][(wi - 1) // 2] for i, wi in enumerate(self.kernel_size) diff --git a/kernex/_src/map.py b/kernex/_src/map.py index 6307949..3d4822e 100644 --- a/kernex/_src/map.py +++ b/kernex/_src/map.py @@ -16,11 +16,17 @@ class baseKernelMap(kernelOperation): def __post_init__(self): self.__call__ = ( - self.__single_call__ if len(self.funcs) == 1 else self.__multi_call__ + # if there is only one function, use the single call method + # this is faster than the multi call method + # this is because the multi call method uses lax.switch + self.__single_call__ + if len(self.funcs) == 1 + else self.__multi_call__ ) def reduce_map_func(self, func, *args, **kwargs) -> Callable: if self.relative: + # if the function is relative, the function is applied to the view return lambda view, array: func( roll_view(array[ix_(*view)]), *args, **kwargs ) @@ -28,22 +34,37 @@ def reduce_map_func(self, func, *args, **kwargs) -> Callable: else: return lambda view, array: func(array[ix_(*view)], *args, **kwargs) - def __single_call__(self, array, *args, **kwargs): - + def __single_call__(self, array: jnp.ndarray, *args, **kwargs): padded_array = jnp.pad(array, self.pad_width) + + # convert the function to a callable that takes a view and an array + # and returns the result of the function applied to the view reduced_func = self.reduce_map_func(self.funcs[0], *args, **kwargs) + + # apply the function to each view using vmap + # the result is a 1D array of the same length as the number of views result = vmap(lambda view: reduced_func(view, padded_array))(self.views) - func_shape = result.shape[1:] - return result.reshape(*self.output_shape, *func_shape) + + # reshape the result to the output shape + # for example if the input shape is (3, 3) and the kernel shape is (2, 2) + # and the stride is 1 , and the padding is 0, the output shape is (2, 2) + return result.reshape(*self.output_shape, *result.shape[1:]) def __multi_call__(self, array, *args, **kwargs): padded_array = jnp.pad(array, self.pad_width) - + # convert the functions to a callable that takes a view and an array + # and returns the result of the function applied to the view + # the result is a 1D array of the same length as the number of views reduced_funcs = tuple( self.reduce_map_func(func, *args, **kwargs) for func in self.funcs[::-1] ) + # apply the functions to each view using vmap + # the result is a 1D array of the same length as the number of views + # here, lax.switch is used to apply the functions in order + # the first function is applied to the first view, the second function + # is applied to the second view, and so on result = vmap( lambda view: lax.switch( self.func_index_from_view(view), reduced_funcs, view, padded_array @@ -56,6 +77,8 @@ def __multi_call__(self, array, *args, **kwargs): @pytc.treeclass class kernelMap(baseKernelMap): + """A class for applying a function to a kernel map of an array""" + def __init__(self, func_dict, shape, kernel_size, strides, padding, relative): super().__init__(func_dict, shape, kernel_size, strides, padding, relative) @@ -65,7 +88,11 @@ def __call__(self, array, *args, **kwargs): @pytc.treeclass class offsetKernelMap(kernelMap): + """A class for applying a function to a kernel map of an array""" + def __init__(self, func_dict, shape, kernel_size, strides, offset, relative): + # the offset is converted to padding and the padding is used to pad the array + # the padding is then used to calculate the views self.offset = offset @@ -80,6 +107,9 @@ def __init__(self, func_dict, shape, kernel_size, strides, offset, relative): @cached_property def set_indices(self): + # the indices of the array that are set by the kernel operation + # this is used to set the values of the array after the kernel operation + # is applied return tuple( jnp.arange(x0, di - xf, si) for di, ki, si, (x0, xf) in ZIP( @@ -88,6 +118,9 @@ def set_indices(self): ) def __call__(self, array, *args, **kwargs): + # apply the kernel operation + # the result is a 1D array of the same length as the number of views + # the result is reshaped to the output shape result = self.__call__(array, *args, **kwargs) assert ( result.shape <= array.shape diff --git a/kernex/_src/scan.py b/kernex/_src/scan.py index 44628ba..777e951 100644 --- a/kernex/_src/scan.py +++ b/kernex/_src/scan.py @@ -13,12 +13,17 @@ @pytc.treeclass class baseKernelScan(kernelOperation): def __post_init__(self): + # if there is only one function, use the single call method + # this is faster than the multi call method + # this is because the multi call method uses lax.switch self.__call__ = ( self.__single_call__ if len(self.funcs) == 1 else self.__multi_call__ ) def reduce_scan_func(self, func, *args, **kwargs) -> Callable: if self.relative: + # if the function is relative, the function is applied to the view + # the result is a 1D array of the same length as the number of views return lambda view, array: array.at[self.index_from_view(view)].set( func(roll_view(array[ix_(*view)]), *args, **kwargs) ) @@ -64,7 +69,6 @@ def scan_body(padded_array, view): @pytc.treeclass class kernelScan(baseKernelScan): def __init__(self, func_dict, shape, kernel_size, strides, padding, relative): - super().__init__(func_dict, shape, kernel_size, strides, padding, relative) def __call__(self, array, *args, **kwargs): diff --git a/kernex/_src/utils.py b/kernex/_src/utils.py index b376271..83f695f 100644 --- a/kernex/_src/utils.py +++ b/kernex/_src/utils.py @@ -12,6 +12,8 @@ class cached_property: + """this function is a decorator that caches the result of the function""" + def __init__(self, func): self.name = func.__name__ self.func = func @@ -35,7 +37,8 @@ def ZIP(*args): def _offset_to_padding(input_argument, kernel_size): """convert offset argument to negative border values""" - + # for example for a kernel_size = (3,3) and offset = (1,1) + # the padding will be (-1,-1) for each dimension padding = [[]] * len(kernel_size) # offset = 1 ==> padding= 0 for kernel_size =3 @@ -70,6 +73,7 @@ def roll_view(array: jnp.ndarray) -> jnp.ndarray: [ 3 4 5 1 2] [ 8 9 10 6 7]] """ + # this function is used to roll the view along all axes shape = jnp.array(array.shape) axes = tuple(range(len(shape))) # list all axes shift = tuple( @@ -114,6 +118,7 @@ def general_arange(di: int, ki: int, si: int, x0: int, xf: int) -> jnp.ndarray: [1 2 3] [2 3 4]] """ + # this function is used to calculate the windows indices for a given dimension start, end = -x0 + ((ki - 1) // 2), di + xf - (ki // 2) size = end - start lhs = jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(size, ki), dimension=0) + (start) # fmt: skip @@ -170,7 +175,6 @@ def _index_from_view( Returns: tuple[int, ...]: index as a tuple of int for each dimension """ - return tuple( view[i][wi // 2] if wi % 2 == 1 else view[i][(wi - 1) // 2] for i, wi in enumerate(kernel_size) diff --git a/kernex/interface/named_axis.py b/kernex/interface/named_axis.py index 38b9b45..5dbd5e3 100644 --- a/kernex/interface/named_axis.py +++ b/kernex/interface/named_axis.py @@ -15,6 +15,11 @@ class sortedDict(dict): """a class that sort a key before setting or getting an item""" + # this dict is used to store the kernel values + # the key is a tuple of the axis names + # the value is the kernel values + # for example if the kernel is 3x3 and the axis names are ['x', 'y'] + # the key will be ('x', 'y') and the value will be the kernel values def __getitem__(self, key: tuple[str, ...]): key = (key,) if isinstance(key, str) else tuple(sorted(key)) return super().__getitem__(key) diff --git a/kernex/interface/resolve_utils.py b/kernex/interface/resolve_utils.py index cef968f..34300f1 100644 --- a/kernex/interface/resolve_utils.py +++ b/kernex/interface/resolve_utils.py @@ -99,66 +99,64 @@ def _resolve_dict_argument( return tuple(temp) -@dispatch(argnum=0) def _resolve_offset_argument(input_argument, kernel_size): - raise NotImplementedError( - "input_argument type={} is not implemented".format(type(input_argument)) - ) - - -@_resolve_offset_argument.register(int) -def _(input_argument, kernel_size): - return [(input_argument, input_argument)] * len(kernel_size) - - -@_resolve_offset_argument.register(list) -@_resolve_offset_argument.register(tuple) -def _(input_argument, kernel_size): - offset = [[]] * len(kernel_size) + @dispatch(argnum=0) + def __resolve_offset_argument(input_argument, kernel_size): + raise NotImplementedError( + "input_argument type={} is not implemented".format(type(input_argument)) + ) - for i, item in enumerate(input_argument): - offset[i] = (item, item) if isinstance(item, int) else item + @__resolve_offset_argument.register(int) + def _(input_argument, kernel_size): + return [(input_argument, input_argument)] * len(kernel_size) - return offset + @__resolve_offset_argument.register(list) + @__resolve_offset_argument.register(tuple) + def _(input_argument, kernel_size): + offset = [[]] * len(kernel_size) + for i, item in enumerate(input_argument): + offset[i] = (item, item) if isinstance(item, int) else item -@dispatch(argnum=0) -def __resolve_index_step(index, shape): - raise NotImplementedError(f"index type={type(index)} is not implemented") + return offset + return __resolve_offset_argument(input_argument, kernel_size) -@__resolve_index_step.register(int) -def _(index, shape): - index += shape if index < 0 else 0 - return index +def _resolve_index(index, shape): + """Resolve index to a tuple of int""" -@__resolve_index_step.register(slice) -def _(index, shape): - start, end, step = index.start, index.stop, index.step + @dispatch(argnum=0) + def __resolve_index_step(index, shape): + raise NotImplementedError(f"index type={type(index)} is not implemented") - start = start or 0 - start += shape if start < 0 else 0 + @__resolve_index_step.register(int) + def _(index, shape): + index += shape if index < 0 else 0 + return index - end = end or shape - end += shape if end < 0 else 0 + @__resolve_index_step.register(slice) + def _(index, shape): + start, end, step = index.start, index.stop, index.step - step = step or 1 + start = start or 0 + start += shape if start < 0 else 0 - return (start, end, step) + end = end or shape + end += shape if end < 0 else 0 + step = step or 1 -@__resolve_index_step.register(list) -@__resolve_index_step.register(tuple) -def _(index, shape): - assert all( - isinstance(i, int) for i in jax.tree_util.tree_leaves(index) - ), "All items in tuple must be int" - return index + return (start, end, step) + @__resolve_index_step.register(list) + @__resolve_index_step.register(tuple) + def _(index, shape): + assert all( + isinstance(i, int) for i in jax.tree_util.tree_leaves(index) + ), "All items in tuple must be int" + return index -def _resolve_index(index, shape): - """Resolve index to a tuple of int""" index = [index] if not isinstance(index, tuple) else index resolved_index = [[]] * len(index)