Skip to content

Commit

Permalink
Update kernel_interface.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 3, 2022
1 parent 0286937 commit 1247e5a
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
39 changes: 24 additions & 15 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from kernex.src.scan import kernelScan, offsetKernelScan


@pytc.treeclass(op=False, field_only=True)
@pytc.treeclass
class kernelInterface:

kernel_size: tuple[int, ...] | int = pytc.static_field()
Expand All @@ -37,14 +37,11 @@ class kernelInterface:

def __post_init__(self):
"""resolve the border values and the kernel operation"""

if self.use_offset:
self.border = _resolve_offset_argument(self.border, self.kernel_size)
self.kernel_op = offsetKernelScan if self.inplace else offsetKernelMap

else:
self.border = _resolve_padding_argument(self.border, self.kernel_size)
self.kernel_op = kernelScan if self.inplace else kernelMap
self.border = (
_resolve_offset_argument(self.border, self.kernel_size)
if self.use_offset
else _resolve_padding_argument(self.border, self.kernel_size)
)

def __setitem__(self, index, func):

Expand Down Expand Up @@ -73,7 +70,13 @@ def _wrap_mesh(self, array, *args, **kwargs):
else:
self.resolved_container[func] = index

return self.kernel_op(
kernel_op = (
(offsetKernelScan if self.inplace else offsetKernelMap)
if self.use_offset
else (kernelScan if self.inplace else kernelMap)
)

return kernel_op(
self.resolved_container,
self.shape,
self.kernel_size,
Expand All @@ -95,7 +98,13 @@ def call(array, *args, **kwargs):
else func: ()
}

return self.kernel_op(
kernel_op = (
(offsetKernelScan if self.inplace else offsetKernelMap)
if self.use_offset
else (kernelScan if self.inplace else kernelMap)
)

return kernel_op(
self.resolved_container,
self.shape,
self.kernel_size,
Expand All @@ -122,7 +131,7 @@ def __call__(self, *args, **kwargs):
)


@pytc.treeclass(op=False)
@pytc.treeclass
class sscan(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None
Expand All @@ -139,7 +148,7 @@ def __init__(
)


@pytc.treeclass(op=False)
@pytc.treeclass
class smap(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, offset=0, relative=False, named_axis=None
Expand All @@ -156,7 +165,7 @@ def __init__(
)


@pytc.treeclass(op=False)
@pytc.treeclass
class kscan(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None
Expand All @@ -173,7 +182,7 @@ def __init__(
)


@pytc.treeclass(op=False)
@pytc.treeclass
class kmap(kernelInterface):
def __init__(
self, kernel_size=1, strides=1, padding=0, relative=False, named_axis=None
Expand Down
2 changes: 1 addition & 1 deletion kernex/src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from kernex.src.utils import ZIP, _key_search, general_arange, general_product


@pytc.treeclass(op=False)
@pytc.treeclass
class kernelOperation:
"""base class for all kernel operations"""

Expand Down
2 changes: 1 addition & 1 deletion kernex/src/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from kernex.src.utils import ZIP, _offset_to_padding, ix_, roll_view


@pytc.treeclass(op=False)
@pytc.treeclass
class baseKernelMap(kernelOperation):
def __post_init__(self):
self.__call__ = (
Expand Down
2 changes: 1 addition & 1 deletion kernex/src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from kernex.src.utils import ZIP, _offset_to_padding, ix_, roll_view


@pytc.treeclass(op=False)
@pytc.treeclass
class baseKernelScan(kernelOperation):
def __post_init__(self):
self.__call__ = (
Expand Down
2 changes: 1 addition & 1 deletion tests_and_benchmarks/test_benchmark_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_and_time_conv2d():
print()
print("backend name = ", jax.devices())

iters = 1000
iters = 50

dims = list(
sorted(itertools.product([4, 8, 16, 32, 64], [16, 32, 64, 128, 256, 512, 1024]))
Expand Down

0 comments on commit 1247e5a

Please sign in to comment.