Skip to content
23 changes: 23 additions & 0 deletions HeterogeneousCore/AlpakaCore/python/functions.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@makortel what do you think ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is much better than spreading around the direct modification copy._TypedParameterizable__type, and a reasonable way for immediate needs. For the longer term I'd still like to solve the problem in different way, with which I'll follow up in #43780

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
def makeSerialClone(module, **kwargs):
type = module._TypedParameterizable__type
if type.endswith('@alpaka'):
# alpaka module with automatic backend selection
base = type.removesuffix('@alpaka')
elif type.startswith('alpaka_serial_sync::'):
# alpaka module with explicit serial_sync backend
base = type.removeprefix('alpaka_serial_sync::')
Comment on lines +6 to +8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alpaka_serial_sync case could just return module.clone(**kwargs).

elif type.startswith('alpaka_cuda_async::'):
# alpaka module with explicit cuda_async backend
base = type.removeprefix('alpaka_cuda_async::')
elif type.startswith('alpaka_rocm_async::'):
# alpaka module with explicit rocm_async backend
base = type.removeprefix('alpaka_rocm_async::')
Comment on lines +6 to +14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if-else chain could be replaced with e.g. a regex to remove the alpaka_[^:]*:: part.

else:
# non-alpaka module
raise TypeError('%s is not an alpaka-based module, and cannot be used with makeSerialClone()' % str(module))

copy = module.clone(**kwargs)
copy._TypedParameterizable__type = 'alpaka_serial_sync::' + base
if 'alpaka' in copy.parameterNames_():
del copy.alpaka
return copy