diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index b1f7f92be9..4049f90363 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -20,7 +20,7 @@ transformation as dace_transformation, ) from dace.codegen.targets import cpp as dace_cpp -from dace.sdfg import nodes as dace_nodes +from dace.sdfg import memlet_utils as dace_mutils, nodes as dace_nodes from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @@ -262,7 +262,7 @@ def _gt_expand_non_standard_memlets_sdfg( new_maps: set[dace_nodes.MapEntry] = set() # The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py` # in the function `preprocess()` - # NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976 + # NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/2033 for state in sdfg.states(): for e in state.edges(): # We are only interested in edges that connects two access nodes of GPU memory. 
@@ -289,6 +289,20 @@ def _gt_expand_non_standard_memlets_sdfg( is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1 if is_c_order or is_fortran_order: continue + + # NOTE: Special case of continuous copy + # Example: dcol[0:I, 0:J, k] -> datacol[0:I, 0:J] + # with copy shape [I, J] and strides [J*K, K], [J, 1] + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + elif dims > 2: if not (src_strides[-1] != 1 or dst_strides[-1] != 1): continue @@ -298,27 +312,27 @@ def _gt_expand_non_standard_memlets_sdfg( edge.dst for edge in state.out_edges(a) ] - # Turn unsupported copy to a map - try: - dace_transformation.dataflow.CopyToMap.apply_to( - sdfg, - save=False, - annotate=False, - a=a, - b=b, - options={ - "ignore_strides": True - }, # apply 'CopyToMap' even if src/dst strides are different - ) - except ValueError: # If transformation doesn't match, continue normally - continue + if not dace_mutils.can_memlet_be_turned_into_a_map( + edge=e, state=state, sdfg=sdfg, ignore_strides=True + ): + # NOTE: In DaCe, they simply ignore that case and continue to the + # code generator. In GT4Py we generate an error. + raise RuntimeError(f"Unable to turn the not supported edge '{e}' into a copy Map.") + + # Turn the not supported Memlet into a copy Map. We have to do it here, + # such that we can then set their iteration order correctly. + dace_mutils.memlet_to_map( + edge=e, + state=state, + sdfg=sdfg, + ignore_strides=True, + ) # We find the new map by comparing the new neighborhood of `a` with the old one. 
new_nodes: set[dace_nodes.MapEntry] = { edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a } assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) - assert len(new_nodes) == 1 new_maps.update(new_nodes) return new_maps diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/common.py b/src/gt4py/next/program_processors/runners/dace/workflow/common.py index 95d77a508a..a7de1a1594 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/common.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/common.py @@ -34,6 +34,10 @@ def set_dace_config( dace.config.Config.set("default_build_folder", value=str(config.BUILD_CACHE_DIR / "dace_cache")) dace.config.Config.set("compiler.use_cache", value=True) + # Prevents the implicit change of Memlets to Maps. Instead they should be handled by + # `gt4py.next.program_processors.runners.dace.transformations.gpu_utils.gt_gpu_transform_non_standard_memlet()`. + dace.Config.set("compiler.cuda.allow_implicit_memlet_to_map", value=False) + if cmake_build_type is not None: dace.config.Config.set("compiler.build_type", value=cmake_build_type.value) diff --git a/uv.lock b/uv.lock index 51c7efc1aa..94b0f4f507 100644 --- a/uv.lock +++ b/uv.lock @@ -663,7 +663,7 @@ wheels = [ [[package]] name = "dace" version = "1.0.0" -source = { git = "https://github.com/GridTools/dace?branch=gt4py-next-integration#1c56acc04b9d544182cc79b4778fb0c0a201b77d" } +source = { git = "https://github.com/GridTools/dace?branch=gt4py-next-integration#d779cd1e91e6b519426f463184e3ffd36e7ceaf5" } resolution-markers = [ "python_full_version >= '3.11'", "python_full_version < '3.11'", @@ -1105,7 +1105,7 @@ dace = [ { name = "dace", version = "1.0.2", source = { registry = "https://pypi.org/simple" } }, ] dace-next = [ - { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=gt4py-next-integration#1c56acc04b9d544182cc79b4778fb0c0a201b77d" } }, +
{ name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=gt4py-next-integration#d779cd1e91e6b519426f463184e3ffd36e7ceaf5" } }, ] formatting = [ { name = "clang-format" },