Commit 95cffcf

feat[dace][next]: Adjusted the handling of non standard GPU Memlets. (#1913)
Due to the update to DaCe in [PR#1976](spcl/dace#1976), the code generator can now handle more Memlets directly as CUDA `memcpy()` calls. `_gt_expand_non_standard_memlets_sdfg()` therefore has to be adjusted so that it no longer expands Memlets the code generator can handle itself. It even stands to reason that these functions can be removed altogether. Note that the merge became possible because of our switch to our DaCe fork in [PR#2012](#2012).
1 parent 7697015 commit 95cffcf

1 file changed, +6 -11 lines changed
src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py

Lines changed: 6 additions & 11 deletions
@@ -258,7 +258,9 @@ def _gt_expand_non_standard_memlets_sdfg(
 ) -> set[dace_nodes.MapEntry]:
     """Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG."""
     new_maps: set[dace_nodes.MapEntry] = set()
-    # The implementation is based on DaCe's code generator.
+    # The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py`
+    # in the function `preprocess()`
+    # NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976
     for state in sdfg.states():
         for e in state.edges():
             # We are only interested in edges that connects two access nodes of GPU memory.
@@ -279,16 +281,9 @@ def _gt_expand_non_standard_memlets_sdfg(
             if dims == 1:
                 continue
             elif dims == 2:
-                if src_strides[-1] != 1 or dst_strides[-1] != 1:
-                    try:
-                        is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
-                        is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
-                    except (TypeError, ValueError):
-                        is_src_cont = False
-                        is_dst_cont = False
-                    if is_src_cont and is_dst_cont:
-                        continue
-                else:
+                is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1
+                is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
+                if is_c_order or is_fortran_order:
                     continue
             elif dims > 2:
                 if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
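To make the effect of the 2D change easier to see, the old and new conditions can be pulled out into two standalone predicates. This is only a sketch: the helper names `old_skips_2d_copy()` and `new_skips_2d_copy()` are hypothetical and do not exist in `gpu_utils.py`, and plain integers stand in for the (possibly symbolic) strides the real function works with. A return value of `True` corresponds to hitting `continue` in the diff, i.e. the edge is left to the code generator and not expanded into a Map.

```python
def old_skips_2d_copy(src_strides, dst_strides, copy_shape):
    """Pre-commit 2D condition (hypothetical helper; mirrors the removed lines)."""
    if src_strides[-1] != 1 or dst_strides[-1] != 1:
        try:
            is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
            is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
        except (TypeError, ValueError):
            is_src_cont = False
            is_dst_cont = False
        return is_src_cont and is_dst_cont
    # Both innermost strides are 1, so the edge was skipped.
    return True


def new_skips_2d_copy(src_strides, dst_strides):
    """Post-commit 2D condition (hypothetical helper; mirrors the added lines)."""
    is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1
    is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
    return is_c_order or is_fortran_order


# Illustrative integer strides (in elements) and copy shapes; assumptions, not real data.
c_order = ((10, 1), (20, 1), (5, 10))        # unit stride in the last dimension
fortran_order = ((1, 5), (1, 8), (5, 3))     # unit stride in the first dimension
mixed_order = ((10, 1), (1, 5), (5, 10))     # source C order, destination Fortran order

results = {
    "c_order": (old_skips_2d_copy(*c_order), new_skips_2d_copy(*c_order[:2])),
    "fortran_order": (
        old_skips_2d_copy(*fortran_order),
        new_skips_2d_copy(*fortran_order[:2]),
    ),
    "mixed_order": (old_skips_2d_copy(*mixed_order), new_skips_2d_copy(*mixed_order[:2])),
}

for name, (old, new) in results.items():
    print(f"{name}: old skips={old}, new skips={new}")
```

With these inputs, the Fortran-order pair is the case that changes: the old condition fell through to the Map expansion, while the new one skips the edge. The mixed-order pair is still expanded under both conditions.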
