diff --git a/AMD_INTRODUCTION.md b/AMD_INTRODUCTION.md new file mode 100644 index 0000000000..b9f08f476a --- /dev/null +++ b/AMD_INTRODUCTION.md @@ -0,0 +1,44 @@ +# Icon4py performance on MI300 + +## Quickstart + +``` +# Connect to Beverin (CSCS system with MI300A) +ssh beverin.cscs.ch +``` + +In Beverin: + +``` +# Enter scratch directory +cd $SCRATCH + +# Clone icon4py and checkout the correct branch +git clone git@github.com:C2SM/icon4py.git +cd icon4py +git checkout amd_profiling + +# Pull the correct `uenv` image. *!* NECESSARY ONLY ONCE *!* +uenv image pull build::prgenv-gnu/25.12:2333839235 + +# Start the uenv and mount the ROCm 7.1.0 environment. *!* This needs to be executed before running anything every time *!* +uenv start --view default prgenv-gnu/25.12:2333839235 + +# Install the necessary venv +bash amd_scripts/install_icon4py_venv.sh + +# Source venv +source .venv/bin/activate + +# Source other necessary environment variables +source amd_scripts/setup_env.sh + +# Set GT4Py related environment variables +export GT4PY_UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE="1" +export GT4PY_BUILD_CACHE_LIFETIME=persistent +export GT4PY_BUILD_CACHE_DIR=amd_profiling_granule +export GT4PY_COLLECT_METRICS_LEVEL=10 +export GT4PY_DYCORE_ENABLE_METRICS="1" +export GT4PY_ADD_GPU_TRACE_MARKERS="1" +export HIPFLAGS="-std=c++17 -fPIC -O3 -march=native -Wno-unused-parameter -save-temps -Rpass-analysis=kernel-resource-usage" +``` diff --git a/amd_scripts/install_icon4py_venv.sh b/amd_scripts/install_icon4py_venv.sh new file mode 100755 index 0000000000..4c86fc481d --- /dev/null +++ b/amd_scripts/install_icon4py_venv.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e + +date + +# Go to the root of the icon4py repository to run the installation from there +ICON4PY_GIT_ROOT=$(git rev-parse --show-toplevel) +cd $ICON4PY_GIT_ROOT + +# Set necessary flags for compilation +source $ICON4PY_GIT_ROOT/amd_scripts/setup_env.sh + +# Install uv locally +export PATH="$PWD/bin:$PATH" +if [ ! 
-x "$PWD/bin/uv" ]; then + curl -LsSf https://astral.sh/uv/install.sh | UV_UNMANAGED_INSTALL="$PWD/bin" sh +else + echo "# uv already installed at $PWD/bin/uv" +fi + +# Install icon4py, gt4py, DaCe and other basic dependencies using uv +uv sync --extra rocm7_0 --python $(which python3.12) + +# Activate virtual environment +source .venv/bin/activate + +# Install the requirements for rocprofiler-compute so we can run the profiler from the same environment +uv pip install -r /user-environment/linux-zen3/rocprofiler-compute-7.1.0-rjjjgkz67w66bp46jw7bvlfyduzr6vhv/libexec/rocprofiler-compute/requirements.txt + +echo "# install done" +date diff --git a/amd_scripts/setup_env.sh b/amd_scripts/setup_env.sh new file mode 100644 index 0000000000..d8eeb99692 --- /dev/null +++ b/amd_scripts/setup_env.sh @@ -0,0 +1,13 @@ +export CC="$(which gcc)" +export MPICH_CC="$(which gcc)" +export CXX="$(which g++)" +export MPICH_CXX="$(which g++)" +export HUGETLB_ELFMAP="no" +export HUGETLB_MORECORE="no" +export PYTHONOPTIMIZE="2" +export HCC_AMDGPU_TARGET="gfx942" +export ROCM_HOME="/user-environment/env/default" +export HIPCC=$(which hipcc) +export ROCM_VERSION="7.1.0" +export LD_LIBRARY_PATH=/user-environment/linux-zen3/rocprofiler-dev-7.1.0-i7wbbbgrx7jjp4o2xroyj5j263dkzplv/lib:$LD_LIBRARY_PATH +export LD_PRELOAD=/user-environment/env/default/lib/libomp.so:$LD_PRELOAD diff --git a/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/driver/run_graupel_only.py b/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/driver/run_graupel_only.py index e4818f9a23..8a769cfae1 100755 --- a/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/driver/run_graupel_only.py +++ b/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/driver/run_graupel_only.py @@ -10,15 +10,20 @@ from __future__ import 
annotations import argparse +import copy import pathlib import time from gt4py import next as gtx from gt4py.next import config as gtx_config from gt4py.next.instrumentation import metrics as gtx_metrics +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations from icon4py.model.atmosphere.subgrid_scale_physics.muphys.driver import common, utils -from icon4py.model.atmosphere.subgrid_scale_physics.muphys.implementations import graupel +from icon4py.model.atmosphere.subgrid_scale_physics.muphys.implementations import ( + graupel, + graupel_dace_hooks, +) from icon4py.model.common import dimension as dims, model_backends, model_options, type_alias as ta from icon4py.model.common.utils import device_utils @@ -60,7 +65,17 @@ def setup_graupel( vertical_start: int, vertical_end: int, enable_masking: bool = True, + enable_dace_hooks: bool = True, ): + if enable_dace_hooks: + assert model_backends.is_backend_descriptor(backend) + backend = copy.deepcopy(backend) + if "optimization_args" not in backend: + backend["optimization_args"] = {} + backend["optimization_args"]["optimization_hooks"] = { + gtx_transformations.GT4PyAutoOptHook.TopLevelDataFlowPre: graupel_dace_hooks.remove_self_copy_inside_scan, + gtx_transformations.GT4PyAutoOptHook.TopLevelDataFlowPost: graupel_dace_hooks.rename_intermediate_access_nodes, + } with utils.recursion_limit(10**4): # TODO(havogt): make an option in gt4py? 
graupel_run_program = model_options.setup_program( backend=backend, diff --git a/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/implementations/graupel_dace_hooks.py b/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/implementations/graupel_dace_hooks.py new file mode 100644 index 0000000000..6d01465ec2 --- /dev/null +++ b/model/atmosphere/subgrid_scale_physics/muphys/src/icon4py/model/atmosphere/subgrid_scale_physics/muphys/implementations/graupel_dace_hooks.py @@ -0,0 +1,650 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import copy +from collections.abc import Sequence + +import dace +from dace import ( + nodes as dace_nodes, + sdfg as dace_sdfg, + symbolic as dace_sym, + transformation as dace_transformation, +) +from gt4py.next import config as gtx_config +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations +from gt4py.next.program_processors.runners.dace.transformations import ( + local_double_buffering as gtx_local_double_buffering, +) + + +def _cleanup_local_self_update( + scan_sdfg: dace.SDFG, + if_stmt_node: dace.sdfg.state.ConditionalBlock, + if_stmt_conn: str, + compute_src_node: dace_nodes.AccessNode, + compute_dst_node: dace_nodes.AccessNode, + update_src_node: dace_nodes.AccessNode, + update_dst_node: dace_nodes.AccessNode, + scan_compute_st: dace.SDFGState, + scan_update_st: dace.SDFGState, +) -> None: + temp_data_name = compute_dst_node.data + assert isinstance(scan_sdfg.arrays[temp_data_name], dace.data.Scalar) + assert scan_sdfg.arrays[temp_data_name] == compute_src_node.desc(scan_sdfg) + assert scan_sdfg.arrays[temp_data_name] == update_dst_node.desc(scan_sdfg) + + # reroute the write edge in the 
compute state + new_compute_dst_node = scan_compute_st.add_access(compute_src_node.data) + scan_compute_st.add_edge( + if_stmt_node, + if_stmt_conn, + new_compute_dst_node, + None, + dace.Memlet(data=new_compute_dst_node.data, subset="0"), + ) + for edge in scan_compute_st.out_edges(compute_dst_node): + scan_compute_st.add_edge( + new_compute_dst_node, + None, + edge.dst, + edge.dst_conn, + dace.Memlet( + data=new_compute_dst_node.data, + subset=edge.data.get_src_subset(edge, scan_compute_st), + other_subset=edge.data.get_dst_subset(edge, scan_compute_st), + ), + ) + scan_compute_st.remove_node(compute_dst_node) + + # reroute the write edge in the update state + scan_update_st.add_nedge( + scan_update_st.add_access(compute_src_node.data), + update_dst_node, + dace.Memlet( + data=compute_src_node.data, + subset="0", + other_subset="0", + ), + ) + scan_update_st.remove_node(update_src_node) + + # now it is safe to remove the data descriptor + scan_sdfg.remove_data(temp_data_name, validate=gtx_config.DEBUG) + print( + f"Removed self-copy in {if_stmt_node.label}: {compute_src_node.data} -> {compute_dst_node.data}" + ) + + +def _replace_scan_input( + sdfg: dace.SDFG, + state: dace.SDFGState, + old_node: dace_nodes.AccessNode, + new_node: dace_nodes.AccessNode, + new_node_offsets: Sequence[dace_sym.SymbolicType], +) -> None: + reconfigured_neighbour: set[tuple[dace_nodes.Node, str | None]] = set() + + for producer_edge in list(state.in_edges(old_node)): + producer: dace_nodes.Node = producer_edge.src + producer_conn = producer_edge.src_conn + new_producer_edge = gtx_transformations.utils.reroute_edge( + is_producer_edge=True, + current_edge=producer_edge, + ss_offset=new_node_offsets, + state=state, + sdfg=sdfg, + old_node=old_node, + new_node=new_node, + ) + if (producer, producer_conn) not in reconfigured_neighbour: + gtx_transformations.utils.reconfigure_dataflow_after_rerouting( + is_producer_edge=True, + new_edge=new_producer_edge, + sdfg=sdfg, + state=state, + 
ss_offset=new_node_offsets, + old_node=old_node, + new_node=new_node, + ) + reconfigured_neighbour.add((producer, producer_conn)) + + for consumer_edge in list(state.out_edges(old_node)): + consumer: dace_nodes.Node = consumer_edge.dst + consumer_conn = consumer_edge.dst_conn + new_consumer_edge = gtx_transformations.utils.reroute_edge( + is_producer_edge=False, + current_edge=consumer_edge, + ss_offset=new_node_offsets, + state=state, + sdfg=sdfg, + old_node=old_node, + new_node=new_node, + ) + if (consumer, consumer_conn) not in reconfigured_neighbour: + gtx_transformations.utils.reconfigure_dataflow_after_rerouting( + is_producer_edge=False, + new_edge=new_consumer_edge, + sdfg=sdfg, + state=state, + ss_offset=new_node_offsets, + old_node=old_node, + new_node=new_node, + ) + reconfigured_neighbour.add((consumer, consumer_conn)) + + state.remove_node(old_node) + sdfg.remove_data(old_node.data, validate=gtx_config.DEBUG) + + gtx_transformations.gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=new_node, + ) + + +def _cleanup_global_self_update( + sdfg: dace.SDFG, + state: dace.SDFGState, + if_stmt_node: dace_nodes.NestedSDFG, + if_stmt_else_state: dace.SDFGState, + if_stmt_output: str, + scan_node: dace_nodes.NestedSDFG, + scan_compute_st: dace.SDFGState, + compute_src_node: dace_nodes.AccessNode, + compute_dst_node: dace_nodes.AccessNode, +): + scan_sdfg = scan_node.sdfg + assert isinstance(compute_dst_node.desc(scan_sdfg), dace.data.Scalar) + + # retrieve the source data outside the scan map scope + assert len(list(state.in_edges_by_connector(scan_node, compute_src_node.data))) == 1 + top_level_input_edge = next(state.in_edges_by_connector(scan_node, compute_src_node.data)) + assert isinstance(top_level_input_edge.src, dace_nodes.MapEntry) + map_entry_in_edge_conn = "IN_" + top_level_input_edge.src_conn[4:] + assert ( + len(list(state.in_edges_by_connector(top_level_input_edge.src, map_entry_in_edge_conn))) + == 1 + ) + 
top_level_src_node = next( + state.in_edges_by_connector(top_level_input_edge.src, map_entry_in_edge_conn) + ).src + assert isinstance(top_level_src_node, dace_nodes.AccessNode) + assert top_level_src_node.desc(sdfg).transient + + # retrieve the outer write to destination buffer in the compute state + assert scan_compute_st.out_degree(compute_dst_node) == 1 + scan_output_edge = scan_compute_st.out_edges(compute_dst_node)[0] + assert isinstance(scan_output_edge.dst, dace_nodes.AccessNode) + output_node = scan_output_edge.dst + output_desc = output_node.desc(scan_sdfg) + assert not output_desc.transient + output_data = output_node.data + assert output_data.startswith("__gtir_scan_output") + output_subset = scan_output_edge.data.get_dst_subset(scan_output_edge, scan_compute_st) + assert output_subset.num_elements() == 1 + + if_stmt_sdfg = if_stmt_node.sdfg + if_stmt_global_output, _ = if_stmt_sdfg.add_scalar( + output_data, output_desc.dtype, find_new_name=True + ) + if_stmt_node.add_out_connector(if_stmt_global_output) + + for if_stmt_state in if_stmt_sdfg.states(): + if if_stmt_state != if_stmt_else_state: + sink_nodes = [ + node for node in if_stmt_state.sink_nodes() if node.data == if_stmt_output + ] + assert len(sink_nodes) <= 1 + if sink_nodes: + local_output_node = sink_nodes[0] + assert if_stmt_state.in_degree(local_output_node) == 1 + if_stmt_output_edge = if_stmt_state.in_edges(local_output_node)[0] + src_subset = if_stmt_output_edge.data.get_src_subset( + if_stmt_output_edge, if_stmt_state + ) + if_stmt_state.add_edge( + if_stmt_output_edge.src, + if_stmt_output_edge.src_conn, + if_stmt_state.add_access(if_stmt_global_output), + None, + dace.Memlet(data=if_stmt_global_output, subset="0", other_subset=src_subset), + ) + + scan_compute_st.add_edge( + if_stmt_node, + if_stmt_global_output, + output_node, + None, + dace.Memlet(data=output_data, subset=output_subset), + ) + scan_compute_st.remove_edge(scan_output_edge) + + # retrieve the destination data outside 
the scan map scope + assert len(list(state.out_edges_by_connector(scan_node, output_data))) == 1 + map_exit_in_edge = next(state.out_edges_by_connector(scan_node, output_data)) + assert isinstance(map_exit_in_edge.dst, dace_nodes.MapExit) + map_exit_out_edge_conn = "OUT_" + map_exit_in_edge.dst_conn[3:] + assert ( + len(list(state.out_edges_by_connector(map_exit_in_edge.dst, map_exit_out_edge_conn))) == 1 + ) + map_exit_out_edge = next( + state.out_edges_by_connector(map_exit_in_edge.dst, map_exit_out_edge_conn) + ) + top_level_dst_node = map_exit_out_edge.dst + assert isinstance(top_level_dst_node, dace_nodes.AccessNode) + top_level_dst_node_subset = map_exit_out_edge.data.get_dst_subset(map_exit_out_edge, state) + + # replace the source node outside the scan map with the output node + new_top_level_src_node = state.add_access(top_level_dst_node.data) + _replace_scan_input( + sdfg=sdfg, + state=state, + old_node=top_level_src_node, + new_node=new_top_level_src_node, + new_node_offsets=top_level_dst_node_subset.min_element(), + ) + print( + f"Removed self-copy in {if_stmt_node.label}: {top_level_src_node.data} -> {compute_src_node.data} -> {compute_dst_node.data} -> {output_data} -> {top_level_dst_node.data}" + ) + + +def _graupel_run_self_copy_removal_inside_if_stmt( # noqa: PLR0912, PLR0915 + sdfg: dace.SDFG, + state: dace.SDFGState, + scan_node: dace_nodes.NestedSDFG, + scan_compute_st: dace.SDFGState, + scan_update_st: dace.SDFGState, + if_stmt_node: dace_nodes.NestedSDFG, +) -> None: + scan_sdfg = scan_node.sdfg + nsdfg = if_stmt_node.sdfg + assert len(nsdfg.nodes()) == 1 and isinstance( + nsdfg.nodes()[0], dace_sdfg.state.ConditionalBlock + ) + if_region = nsdfg.nodes()[0] + assert len(list(br[1] for br in if_region.branches if br[0] is None)) == 1 + else_br = next(br[1] for br in if_region.branches if br[0] is None) + assert isinstance(else_br.start_block, dace.SDFGState) + assert len(if_region.out_degree(else_br.start_block)) == 0 + else_st = 
else_br.start_block + src_nodes = [node for node in else_st.source_nodes() if isinstance(node, dace_nodes.AccessNode)] + + for src_node in src_nodes: + assert not src_node.desc(nsdfg).transient + if else_st.out_degree(src_node) != 1: + continue + self_copy_edge = else_st.out_edges(src_node)[0] + dst_node = self_copy_edge.dst + if else_st.out_degree(dst_node) != 0: + continue + assert not dst_node.desc(nsdfg).transient + + # retrieve the source data to copy in the compute state + assert len(list(scan_compute_st.in_edges_by_connector(if_stmt_node, src_node.data))) == 1 + compute_read_edge = next(scan_compute_st.in_edges_by_connector(if_stmt_node, src_node.data)) + compute_src_node = compute_read_edge.src + assert isinstance(compute_src_node, dace_nodes.AccessNode) + + # retrieve the destination node in the compute state, where the data is written + assert len(list(scan_compute_st.out_edges_by_connector(if_stmt_node, dst_node.data))) == 1 + compute_write_edge = next( + scan_compute_st.out_edges_by_connector(if_stmt_node, dst_node.data) + ) + compute_dst_node = compute_write_edge.dst + assert ( + isinstance(compute_dst_node, dace_nodes.AccessNode) + and scan_compute_st.in_degree(compute_dst_node) == 1 + ) + output_data_name = compute_dst_node.data + scan_update_last_level = scan_node.sdfg.nodes()[-1] + assert "scan_last_level" in scan_update_last_level.label + # If only the last level of the output data is updated then don't consider it as a self-copy as we only write in the end only its last level + if any( + node + for node in scan_update_last_level.nodes() + if isinstance(node, dace_nodes.AccessNode) and node.data == compute_dst_node.data + ): + continue + + # retrieve the data access inside the scan update state + update_src_nodes = [ + node for node in scan_update_st.source_nodes() if node.data == output_data_name + ] + update_dst_nodes = [ + node + for node in scan_compute_st.sink_nodes() + if node.data == output_data_name and not 
node.desc(scan_compute_st).transient + ] + assert (len(update_src_nodes) <= 1 and len(update_dst_nodes) == 0) or ( + len(update_src_nodes) == 0 and len(update_dst_nodes) <= 1 + ) + if not update_src_nodes and not update_dst_nodes: + continue + if update_src_nodes: + update_src_node = update_src_nodes[0] + assert scan_update_st.out_degree(update_src_node) == 1 + update_write_edge = scan_update_st.out_edges(update_src_node)[0] + update_dst_node = update_write_edge.dst + assert isinstance(update_dst_node, dace_nodes.AccessNode) + assert ( + scan_update_st.in_degree(update_dst_node) == 1 + and scan_update_st.out_degree(update_dst_node) == 0 + ) + else: + update_dst_node = update_dst_nodes[0] + assert scan_compute_st.in_degree(update_dst_node) == 1 + assert scan_compute_st.out_degree(update_dst_node) == 0 + + if compute_src_node.desc(scan_sdfg).transient: # Handles some scalar variables + _cleanup_local_self_update( + scan_sdfg=scan_sdfg, + if_stmt_node=if_stmt_node, + if_stmt_conn=dst_node.data, + compute_src_node=compute_src_node, + compute_dst_node=compute_dst_node, + update_src_node=update_src_node, + update_dst_node=update_dst_node, + scan_compute_st=scan_compute_st, + scan_update_st=scan_update_st, + ) + else_st.remove_nodes_from([src_node, dst_node]) + elif update_src_nodes: # handles `q_out_*` AccessNodes + _cleanup_global_self_update( + sdfg=sdfg, + state=state, + if_stmt_node=if_stmt_node, + if_stmt_else_state=else_st, + scan_node=scan_node, + scan_compute_st=scan_compute_st, + if_stmt_output=dst_node.data, + compute_src_node=compute_src_node, + compute_dst_node=compute_dst_node, + ) + else: # Handles `t_out`. 
`update_dst_nodes`` is not empty + # replace the input and propagate the changes to all the edges + # remove the copy in the false branch + if update_dst_node.desc(scan_node.sdfg).transient: + continue + self_copy_edge_src_data = self_copy_edge.src.data + map_entry_in_edge = next( + iter(state.in_edges_by_connector(scan_node, self_copy_edge_src_data)) + ) + # Means that there is no other computation before + if not isinstance(map_entry_in_edge.src, dace_nodes.MapEntry): + continue + outer_data_map_out_connector = map_entry_in_edge.src_conn + outer_access_node = next( + iter( + state.in_edges_by_connector( + map_entry_in_edge.src, "IN_" + outer_data_map_out_connector[4:] + ) + ) + ).src + map_exit_in_edge = next(iter(state.out_edges_by_connector(scan_node, output_data_name))) + # Means that there is no computation with this after the scan + if not isinstance(map_exit_in_edge.dst, dace_nodes.MapExit): + continue + outer_data_map_in_connector = map_exit_in_edge.dst_conn + outer_dst_node = next( + iter( + state.out_edges_by_connector( + map_exit_in_edge.dst, "OUT_" + outer_data_map_in_connector[3:] + ) + ) + ).dst + # We just output to the AccessNode + if not isinstance(outer_dst_node, dace_nodes.AccessNode): + continue + new_in_access_node = state.add_access(outer_dst_node.data) + map_exit_out_edge = next( + iter( + state.out_edges_by_connector( + map_exit_in_edge.dst, "OUT_" + map_exit_in_edge.dst_conn[3:] + ) + ) + ) + _replace_scan_input( + sdfg=sdfg, + state=state, + old_node=outer_access_node, + new_node=new_in_access_node, + new_node_offsets=map_exit_out_edge.data.get_dst_subset( + map_exit_out_edge, state + ).min_element(), + ) + else_st.remove_nodes_from([self_copy_edge.src, self_copy_edge.dst]) + print( + f"Removed self-copy in {if_stmt_node.label}: {outer_access_node.data} -> {src_node.data} -> {dst_node.data} -> {output_data_name} -> {outer_dst_node.data}" + ) + + if else_st.is_empty(): + if_region.remove_branch(else_br) + + +def 
remove_self_copy_inside_scan(sdfg: dace.SDFG) -> None: + assert len(sdfg.states()) == 1 + st = sdfg.states()[0] + assert ( + len( + list( + node + for node in st.nodes() + if isinstance(node, dace_nodes.NestedSDFG) and node.label.startswith("scan_") + ) + ) + == 1 + ) + scan_nsdfg_node = next( + node + for node in st.nodes() + if isinstance(node, dace_nodes.NestedSDFG) and node.label.startswith("scan_") + ) + scan_sdfg = scan_nsdfg_node.sdfg + assert len(scan_sdfg.nodes()) == 3 + assert isinstance(scan_sdfg.nodes()[1], dace_sdfg.state.LoopRegion) + loop_regions = [ + scan_sdfg_node + for scan_sdfg_node in scan_sdfg.nodes() + if isinstance(scan_sdfg_node, dace_sdfg.state.LoopRegion) + ] + assert len(loop_regions) == 1 + scan_loop = next(iter(loop_regions)) + assert len(scan_loop.nodes()) == 2 and all( + isinstance(node, dace.SDFGState) for node in scan_loop.nodes() + ) + if scan_loop.nodes()[0].label.startswith("scan_compute"): + assert scan_loop.nodes()[1].label.startswith("scan_update") + scan_compute_st, scan_update_st = scan_loop.nodes() + else: + assert scan_loop.nodes()[0].label.startswith("scan_update") + scan_update_st, scan_compute_st = scan_loop.nodes() + + if_stmt_nodes = [ + node + for node in scan_compute_st.nodes() + if isinstance(node, dace_nodes.NestedSDFG) and node.label.startswith("if_stmt_") + ] + for if_stmt_node in if_stmt_nodes: + _graupel_run_self_copy_removal_inside_if_stmt( + sdfg, st, scan_nsdfg_node, scan_compute_st, scan_update_st, if_stmt_node + ) + + sdfg.validate() + + for input_access_nodes in ["te", "q_in_2", "q_in_3", "q_in_4", "q_in_5"]: + all_maps_with_accessnode_input = [ + node + for node in st.nodes() + if isinstance(node, dace_nodes.MapEntry) + and f"IN_{input_access_nodes}" in node.in_connectors + ] + all_maps_with_accessnode_input_and_if_stmt = [ + map_with_if + for map_with_if in all_maps_with_accessnode_input + if any( + isinstance(map_node, dace_nodes.NestedSDFG) + and map_node.label.startswith("if_stmt_") + for map_node 
in st.scope_subgraph(map_with_if).nodes() + ) + ] + map_with_accessnode_input_and_if_stmt = next( + iter(all_maps_with_accessnode_input_and_if_stmt) + ) + nsdfg_if_stmt_with_accessnode = next( + iter( + [ + node + for node in st.scope_subgraph(map_with_accessnode_input_and_if_stmt).nodes() + if isinstance(node, dace_nodes.NestedSDFG) + ] + ) + ) + nsdfg_conditional_block = nsdfg_if_stmt_with_accessnode.sdfg.nodes()[0] + else_branch = nsdfg_conditional_block.branches[1][1] + else_branch_state = else_branch.nodes()[0] + output_edges = [ + edge + for edge in st.out_edges_by_connector( + nsdfg_if_stmt_with_accessnode, + next(iter(nsdfg_if_stmt_with_accessnode.out_connectors.keys())), + ) + ] + assert len(output_edges) == 1 + output_edge = next(iter(output_edges)) + intermediate_an = output_edge.dst + assert isinstance(intermediate_an, dace_nodes.AccessNode) + out_edge_of_inter_an = st.out_edges(intermediate_an)[0] + dst_out_edge_of_inter_an = out_edge_of_inter_an.dst + assert isinstance(dst_out_edge_of_inter_an, dace_nodes.MapExit) + out_edges_of_map_exit = [ + oedge_map_exit + for oedge_map_exit in st.out_edges_by_connector( + dst_out_edge_of_inter_an, "OUT_" + out_edge_of_inter_an.dst_conn[3:] + ) + ] + assert len(out_edges_of_map_exit) == 1 + out_edge_of_map_exit = next(iter(out_edges_of_map_exit)) + dst_out_edge_of_map_exit = out_edge_of_map_exit.dst + assert isinstance(dst_out_edge_of_map_exit, dace_nodes.AccessNode) + new_memlet = dace.Memlet( + data=out_edge_of_inter_an.data.data, + subset=copy.deepcopy(out_edge_of_inter_an.data.subset), + other_subset=copy.deepcopy(output_edge.data.subset), + ) + new_output_edge = dace_transformation.helpers.redirect_edge( + state=st, + edge=output_edge, + new_dst=dst_out_edge_of_inter_an, + new_dst_conn=out_edge_of_inter_an.dst_conn, + new_memlet=new_memlet, + ) + new_output_edge.data.allow_oob = True + st.remove_edge(out_edge_of_inter_an) + st.remove_node(intermediate_an) + sdfg.arrays.pop(intermediate_an.data) + if ( + 
"else_body" in else_branch.name + and len(else_branch_state.nodes()) == 2 + and all(isinstance(node, dace_nodes.AccessNode) for node in else_branch_state.nodes()) + ): + print( + f"Removed self-copy in {nsdfg_if_stmt_with_accessnode.label} for '{input_access_nodes}' by removing the else branch" + ) + else_branch_state.sdfg.remove_nodes_from( + [internal_node for internal_node in else_branch_state.nodes()] + ) + nsdfg_conditional_block.remove_branch(else_branch) + sdfg.validate() + + +def rename_intermediate_access_nodes(sdfg: dace.SDFG) -> None: + assert len(sdfg.states()) == 1 + st: dace.SDFGState = sdfg.states()[0] + access_node_renaming_dict = { + "q_out_2": "q_in_2", + "q_out_3": "q_in_3", + "q_out_4": "q_in_4", + "q_out_5": "q_in_5", + "t_out": "te", + } + for node in st.nodes(): + if isinstance(node, dace_nodes.MapEntry) and st.scope_dict()[node] is None: + map_entry = node + map_exit = st.exit_node(map_entry) + map_entry_input_data = [in_edge.src.data for in_edge in st.in_edges(map_entry)] + map_exit_output_data = [out_edge.dst.data for out_edge in st.out_edges(map_exit)] + if all(key in map_exit_output_data for key in access_node_renaming_dict) and all( + key in map_entry_input_data for key in access_node_renaming_dict.values() + ): + in_out_dict = {} + for out_edge in st.out_edges(map_exit): + if out_edge.dst.data in access_node_renaming_dict: + new_data_name = access_node_renaming_dict[out_edge.dst.data] + input_node = next( + in_edge.src + for in_edge in st.in_edges(map_entry) + if in_edge.src.data == new_data_name + ) + new_access_node = st.add_access(new_data_name) + old_node = out_edge.dst + old_data_name = old_node.data + # The subsets and strides of the renamed data is the same so no reason to have an offset + # Ideally we should update it with the new_subset.min_element() - old_subset.min_element() + new_offset = (0, 0) + for out_edge_of_old_node in st.out_edges(old_node): + new_consumer_edge = gtx_transformations.utils.reroute_edge( + 
is_producer_edge=False, + current_edge=out_edge_of_old_node, + ss_offset=new_offset, + state=st, + sdfg=sdfg, + old_node=old_node, + new_node=new_access_node, + ) + gtx_transformations.utils.reconfigure_dataflow_after_rerouting( + is_producer_edge=False, + new_edge=new_consumer_edge, + sdfg=sdfg, + state=st, + ss_offset=new_offset, + old_node=old_node, + new_node=new_access_node, + ) + new_producer_edge = gtx_transformations.utils.reroute_edge( + is_producer_edge=True, + current_edge=out_edge, + ss_offset=new_offset, + state=st, + sdfg=sdfg, + old_node=old_node, + new_node=new_access_node, + ) + gtx_transformations.utils.reconfigure_dataflow_after_rerouting( + is_producer_edge=True, + new_edge=new_producer_edge, + sdfg=sdfg, + state=st, + ss_offset=new_offset, + old_node=old_node, + new_node=new_access_node, + ) + st.remove_node(old_node) + in_out_dict[new_data_name] = (input_node, new_access_node) + print( + f"Renamed intermediate AccessNode from '{old_data_name}' to '{new_data_name}' in {map_exit.label}" + ) + gtx_local_double_buffering._add_local_double_buffering_to( + in_out_dict, + map_entry, + st, + sdfg, + ) + # Apply this only to the first map + return + sdfg.validate() diff --git a/model/atmosphere/subgrid_scale_physics/muphys/tests/muphys/integration_tests/test_graupel_only.py b/model/atmosphere/subgrid_scale_physics/muphys/tests/muphys/integration_tests/test_graupel_only.py index f5c42a07f9..a51dd56ac6 100644 --- a/model/atmosphere/subgrid_scale_physics/muphys/tests/muphys/integration_tests/test_graupel_only.py +++ b/model/atmosphere/subgrid_scale_physics/muphys/tests/muphys/integration_tests/test_graupel_only.py @@ -41,19 +41,28 @@ class Experiments: ) +_GRAUPEL_TEST_CASES = [ + (Experiments.MINI, True), + (Experiments.TINY, True), + (Experiments.R2B05, True), + (Experiments.R2B05, False), +] + + @pytest.mark.uses_concat_where @pytest.mark.datatest @pytest.mark.parametrize( - "experiment", - [ - Experiments.MINI, - Experiments.TINY, - Experiments.R2B05, 
+ ("experiment", "enable_dace_hooks"), + _GRAUPEL_TEST_CASES, + ids=[ + f"{exp.name}-dacehooks[{enable_dace_hooks}]" + for exp, enable_dace_hooks in _GRAUPEL_TEST_CASES ], - ids=lambda exp: exp.name, ) def test_graupel_only( - backend_like: model_backends.BackendLike, experiment: utils.MuphysExperiment + backend_like: model_backends.BackendLike, + experiment: utils.MuphysExperiment, + enable_dace_hooks: bool, ) -> None: assert experiment.type == utils.ExperimentType.GRAUPEL_ONLY inp = common.GraupelInput.load( @@ -68,6 +77,7 @@ def test_graupel_only( horizontal_end=inp.ncells, vertical_start=0, vertical_end=inp.nlev, + enable_dace_hooks=enable_dace_hooks, enable_masking=True, # `False` would require different reference data (or relaxing thresholds) ) diff --git a/model/common/pyproject.toml b/model/common/pyproject.toml index 770e91c423..6ac931182b 100644 --- a/model/common/pyproject.toml +++ b/model/common/pyproject.toml @@ -60,6 +60,7 @@ io = [ "uxarray==2024.3.0", "xarray[complete]>=2024.3.0" ] +rocm7_0 = ['amd-cupy>=13.0'] # TODO(havogt): add gt4py[rocm7_0] once available [project.urls] repository = "https://github.com/C2SM/icon4py" diff --git a/pyproject.toml b/pyproject.toml index 5d7c802741..e955df9342 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ distributed = ["icon4py-common[distributed]"] fortran = ["icon4py-tools~=0.1.0"] io = ["icon4py-common[io]"] profiling = ['viztracer>=1.1.0'] +rocm7_0 = ["icon4py-common[rocm7_0]"] testing = ["icon4py-testing"] [project.urls] @@ -376,7 +377,13 @@ explicit = true name = 'gridtools' url = 'https://gridtools.github.io/pypi/' +[[tool.uv.index]] +explicit = true +name = 'amd' +url = 'https://pypi.amd.com/rocm-7.0.2/simple' + [tool.uv.sources] +amd-cupy = {index = "amd"} dace = {index = "gridtools"} # gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} # gt4py = {index = "test.pypi"} diff --git a/tools/src/icon4py/tools/py2fgen/wrappers/muphys_wrapper.py 
b/tools/src/icon4py/tools/py2fgen/wrappers/muphys_wrapper.py index 03ef3134b6..0a72cc0391 100644 --- a/tools/src/icon4py/tools/py2fgen/wrappers/muphys_wrapper.py +++ b/tools/src/icon4py/tools/py2fgen/wrappers/muphys_wrapper.py @@ -44,6 +44,7 @@ def graupel_run( pflx: gtx.Field[gtx.Dims[dims.CellDim, dims.KDim], ta.wpfloat], pre_gsp: gtx.Field[gtx.Dims[dims.CellDim, dims.KDim], ta.wpfloat], enable_masking: bool, + enable_dace_hooks: bool, wait_for_completion: bool, ) -> None: global graupel_program # noqa: PLW0603 [global-statement] @@ -62,6 +63,7 @@ def graupel_run( vertical_start=kstart, vertical_end=ke, enable_masking=enable_masking, + enable_dace_hooks=enable_dace_hooks, ) q = graupel.Q(qv, qc, qr, qs, qi, qg) diff --git a/uv.lock b/uv.lock index 876ce4f82d..9930e57458 100644 --- a/uv.lock +++ b/uv.lock @@ -36,6 +36,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92", size = 13511, upload-time = "2024-01-10T00:56:08.388Z" }, ] +[[package]] +name = "amd-cupy" +version = "13.5.1" +source = { registry = "https://pypi.amd.com/rocm-7.0.2/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy" }, +] +wheels = [ + { url = "https://pypi.amd.com/rocm-7.0.2/packages/amd-cupy/amd_cupy-13.5.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:eca984c7b8176eecaff0dd84504b322828bedd40c177d736753295e8a4b672de" }, + { url = "https://pypi.amd.com/rocm-7.0.2/packages/amd-cupy/amd_cupy-13.5.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:468ca95416f57d5bbf6663ad8ca69a6ac46b4a34166833f01e5535068fa1b4e8" }, + { url = "https://pypi.amd.com/rocm-7.0.2/packages/amd-cupy/amd_cupy-13.5.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:de3138281e2711e06efaf49a31310d0d4824998e18d43e13e288a0e52ca75ec0" }, +] + [[package]] name = "annotated-types" version = 
"0.7.0" @@ -1589,6 +1603,9 @@ io = [ profiling = [ { name = "viztracer" }, ] +rocm7-0 = [ + { name = "icon4py-common", extra = ["rocm7-0"] }, +] testing = [ { name = "icon4py-testing" }, ] @@ -1690,13 +1707,14 @@ requires-dist = [ { name = "icon4py-common", extras = ["cuda12"], marker = "extra == 'cuda12'", editable = "model/common" }, { name = "icon4py-common", extras = ["distributed"], marker = "extra == 'distributed'", editable = "model/common" }, { name = "icon4py-common", extras = ["io"], marker = "extra == 'io'", editable = "model/common" }, + { name = "icon4py-common", extras = ["rocm7-0"], marker = "extra == 'rocm7-0'", editable = "model/common" }, { name = "icon4py-driver", editable = "model/driver" }, { name = "icon4py-standalone-driver", editable = "model/standalone_driver" }, { name = "icon4py-testing", marker = "extra == 'testing'", editable = "model/testing" }, { name = "icon4py-tools", marker = "extra == 'fortran'", editable = "tools" }, { name = "viztracer", marker = "extra == 'profiling'", specifier = ">=1.1.0" }, ] -provides-extras = ["all", "cuda11", "cuda12", "distributed", "fortran", "io", "profiling", "testing"] +provides-extras = ["all", "cuda11", "cuda12", "distributed", "fortran", "io", "profiling", "rocm7-0", "testing"] [package.metadata.requires-dev] build = [ @@ -1922,9 +1940,13 @@ io = [ { name = "uxarray" }, { name = "xarray", extra = ["complete"] }, ] +rocm7-0 = [ + { name = "amd-cupy" }, +] [package.metadata] requires-dist = [ + { name = "amd-cupy", marker = "extra == 'rocm7-0'", specifier = ">=13.0", index = "https://pypi.amd.com/rocm-7.0.2/simple" }, { name = "array-api-compat", specifier = ">=1.13.0" }, { name = "cartopy", marker = "extra == 'io'", specifier = ">=0.22.0" }, { name = "cftime", marker = "extra == 'io'", specifier = ">=1.6.3" }, @@ -1950,7 +1972,7 @@ requires-dist = [ { name = "uxarray", marker = "extra == 'io'", specifier = "==2024.3.0" }, { name = "xarray", extras = ["complete"], marker = "extra == 'io'", specifier 
= ">=2024.3.0" }, ] -provides-extras = ["all", "cuda11", "cuda12", "distributed", "io"] +provides-extras = ["all", "cuda11", "cuda12", "distributed", "io", "rocm7-0"] [[package]] name = "icon4py-driver"