diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index e07c3d9492..dc208abe1f 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -587,16 +587,16 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: ) deref_node = self._add_tasklet( "deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", + {"__field"} | set(index_connectors), + {"__val"}, + code=f"__val = __field[{index_internals}]", ) # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, dace_subsets.Range.from_array(field_desc), deref_node, - "field", + "__field", src_offset=[offset for (_, offset) in arg_expr.field_domain], ) @@ -622,7 +622,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: else: assert isinstance(index_expr, SymbolExpr) - return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + return self._construct_tasklet_result(field_desc.dtype, deref_node, "__val") def _visit_if_branch_arg( self, @@ -1090,24 +1090,27 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: ) elif isinstance(index_arg, ValueExpr): tasklet_node = self._add_tasklet( - "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]" + "list_get", + inputs={"__index", "__data"}, + outputs={"__val"}, + code="__val = __data[__index]", ) self._add_edge( index_arg.dc_node, None, tasklet_node, - "index", + "__index", dace.Memlet(data=index_arg.dc_node.data, subset="0"), ) self._add_edge( list_arg.dc_node, None, tasklet_node, - "list", + "__data", self.sdfg.make_array_memlet(list_arg.dc_node.data), ) self._add_edge( - tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0") + tasklet_node, "__val", result_node, None, dace.Memlet(data=result, subset="0") ) else: raise 
TypeError(f"Unexpected value {index_arg} as index argument.") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 199783d893..1219262f51 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -11,6 +11,7 @@ from typing import Any, Callable import numpy as np +import sympy from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt @@ -143,6 +144,11 @@ def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> s return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_InfinityLiteral( + self, node: gtir.InfinityLiteral, args_map: dict[str, gtir.Node] + ) -> str: + return str(sympy.oo) if node == gtir.InfinityLiteral.POSITIVE else str(-sympy.oo) + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: symbol = str(node.id) if symbol in args_map: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index d1433764e2..26234b04c5 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -42,6 +42,7 @@ from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace import ( gtir_domain, + gtir_to_sdfg_concat_where, gtir_to_sdfg_primitives, gtir_to_sdfg_types, gtir_to_sdfg_utils, @@ -745,7 +746,9 @@ def visit_FunCall( head_state: dace.SDFGState, ) -> gtir_to_sdfg_types.FieldopResult: # use specialized dataflow builder classes for each builtin function - if cpm.is_call_to(node, "if_"): + if cpm.is_call_to(node, "concat_where"): + return gtir_to_sdfg_concat_where.translate_concat_where(node, 
sdfg, head_state, self) + elif cpm.is_call_to(node, "if_"): return gtir_to_sdfg_primitives.translate_if(node, sdfg, head_state, self) elif cpm.is_call_to(node, "index"): return gtir_to_sdfg_primitives.translate_index(node, sdfg, head_state, self) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py new file mode 100644 index 0000000000..3053165003 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py @@ -0,0 +1,495 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the lowering of concat_where operator. + +This builtin translator implements the `PrimitiveTranslator` protocol as other +translators in `gtir_to_sdfg_primitives` module. +""" + +from __future__ import annotations + +import dace +from dace import subsets as dace_subsets + +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.program_processors.runners.dace import ( + gtir_domain, + gtir_to_sdfg, + gtir_to_sdfg_types, + gtir_to_sdfg_utils, +) +from gt4py.next.type_system import type_specifications as ts + + +def _make_concat_field_slice( + sdfg: dace.SDFG, + state: dace.SDFGState, + field: gtir_to_sdfg_types.FieldopData, + field_desc: dace.data.Array, + concat_dim: gtx_common.Dimension, + concat_dim_index: int, + concat_dim_origin: dace.symbolic.SymbolicType, +) -> tuple[gtir_to_sdfg_types.FieldopData, dace.data.Array]: + """ + Helper function called by `_translate_concat_where_impl` to create a slice along + the concat dimension, that is a new array with an extra dimension and a single + level. 
This allows to concatenate the input fields along the concat dimension. + """ + assert isinstance(field.gt_type, ts.FieldType) + assert concat_dim not in field.gt_type.dims + dims = [ + *field.gt_type.dims[:concat_dim_index], + concat_dim, + *field.gt_type.dims[concat_dim_index:], + ] + origin = tuple( + [*field.origin[:concat_dim_index], concat_dim_origin, *field.origin[concat_dim_index:]] + ) + shape = tuple([*field_desc.shape[:concat_dim_index], 1, *field_desc.shape[concat_dim_index:]]) + extended_field_data, extended_field_desc = sdfg.add_temp_transient(shape, field_desc.dtype) + extended_field_node = state.add_access(extended_field_data) + state.add_nedge( + field.dc_node, + extended_field_node, + dace.Memlet( + data=field.dc_node.data, + subset=dace_subsets.Range.from_array(field_desc), + other_subset=dace_subsets.Range.from_array(extended_field_desc), + ), + ) + extended_field = gtir_to_sdfg_types.FieldopData( + extended_field_node, ts.FieldType(dims=dims, dtype=field.gt_type.dtype), origin + ) + return extended_field, extended_field_desc + + +def _make_concat_scalar_broadcast( + sdfg: dace.SDFG, + state: dace.SDFGState, + inp: gtir_to_sdfg_types.FieldopData, + inp_desc: dace.data.Array, + domain: gtir_domain.FieldopDomain, + concat_dim_index: int, +) -> tuple[gtir_to_sdfg_types.FieldopData, dace.data.Array]: + """ + Helper function called by `_translate_concat_where_impl` to create a mapped + tasklet that broadcasts one scalar value on the given domain. + + The scalar value can come from either a scalar node or from a 1D-array (assuming + the array represents a field in the concat dimension). 
+ """ + assert isinstance(inp.gt_type, ts.FieldType) + assert len(inp.gt_type.dims) == 1 + out_dims, out_origin, out_shape = _get_concat_where_field_layout(domain, concat_dim_index) + out_type = ts.FieldType(dims=out_dims, dtype=inp.gt_type.dtype) + + out_name, out_desc = sdfg.add_temp_transient(out_shape, inp_desc.dtype) + out_node = state.add_access(out_name) + + map_variables = [gtir_to_sdfg_utils.get_map_variable(dim) for dim in out_dims] + inp_index = ( + "0" + if isinstance(inp.dc_node.desc(sdfg), dace.data.Scalar) + else ( + f"({map_variables[concat_dim_index]} + {out_origin[concat_dim_index] - inp.origin[0]})" + ) + ) + state.add_mapped_tasklet( + "broadcast", + map_ranges=dict(zip(map_variables, dace_subsets.Range.from_array(out_desc), strict=True)), + code="__out = __inp", + inputs={"__inp": dace.Memlet(data=inp.dc_node.data, subset=inp_index)}, + outputs={"__out": dace.Memlet(data=out_name, subset=",".join(map_variables))}, + input_nodes={inp.dc_node}, + output_nodes={out_node}, + external_edges=True, + ) + + out_field = gtir_to_sdfg_types.FieldopData(out_node, out_type, tuple(out_origin)) + return out_field, out_desc + + +def _get_concat_where_field_layout( + domain: gtir_domain.FieldopDomain, concat_dim: gtx_common.Dimension | int +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: + """ + Helper function that wraps `gtir_domain.get_field_layout()` and adds a check + on the array shape. + + The concat_where domain expressions require special handling, because the lower + bound is not necessarily smaller than the upper. When the upper bound is smaller, + it indicates an empty range to copy. This can only be resolved at runtime, which + is why we use dynamic memlets. 
+ """ + out_dims, out_origin, out_shape = gtir_domain.get_field_layout(domain) + concat_dim_index = ( + out_dims.index(concat_dim) if isinstance(concat_dim, gtx_common.Dimension) else concat_dim + ) + out_shape[concat_dim_index] = dace.symbolic.pystr_to_symbolic( + f"max(0, {out_shape[concat_dim_index]})" + ) + return out_dims, out_origin, out_shape + + +def _translate_concat_where_impl( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, + mask_domain: gtir_domain.FieldopDomain, + node_domain: gtir.Expr, + tb_node_domain: gtir.Expr, + fb_node_domain: gtir.Expr, + tb_field: gtir_to_sdfg_types.FieldopData, + fb_field: gtir_to_sdfg_types.FieldopData, +) -> gtir_to_sdfg_types.FieldopData: + """ + Helper function called by `translate_concat_where()` to lower 'concat_where' + on a single output field. + + In case of tuples, this function is called on all fields by means of `tree_map`. + + It builds the output field by concatenating the two input fields on the lower + and upper domain. These two domains are computed from the intersection of the + mask and the input domains. + + Note that 'tb' and 'fb' stand for true/false branch and refer to the two branches + of the concat_where expression, passed as second and third argument, respectively. + + Args: + sdfg: The SDFG where the primitive subgraph should be instantiated + state: The SDFG state where the result of the primitive function should be made available + sdfg_builder: The object responsible for visiting child nodes of the primitive node. + mask_domain: Domain (only for concat dimension) of the true branch, infinite + on lower or upper boundary. + node_domain: Domain (all dimensions) of output field. + tb_node_domain: Domain of the field passed on the true branch. + fb_node_domain: Domain of the field passed on the false branch. + tb_field: Input field on the true branch. + fb_field: Input field on the false branch. 
+ + Returns: + The field resulting from concatenating the input fields on the lower and upper domain. + """ + tb_data_desc, fb_data_desc = (inp.dc_node.desc(sdfg) for inp in [tb_field, fb_field]) + assert tb_data_desc.dtype == fb_data_desc.dtype + + tb_domain, fb_domain = ( + gtir_domain.extract_domain(domain) for domain in [tb_node_domain, fb_node_domain] + ) + assert len(mask_domain) == 1 + concat_dim, mask_lower_bound, mask_upper_bound = mask_domain[0] + + # Expect unbound range in the concat domain expression on lower or upper range: + # - if the domain expression is unbound on lower side (negative infinite), + # the expression on the true branch is to be considered the input for the + # lower domain. + # - vice versa, if the domain expression is unbound on upper side (positive + # infinite), the true expression represents the input for the upper domain. + if mask_lower_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE): + concat_dim_bound = mask_upper_bound + lower, lower_desc, lower_domain = (tb_field, tb_data_desc, tb_domain) + upper, upper_desc, upper_domain = (fb_field, fb_data_desc, fb_domain) + elif mask_upper_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE): + concat_dim_bound = mask_lower_bound + lower, lower_desc, lower_domain = (fb_field, fb_data_desc, fb_domain) + upper, upper_desc, upper_domain = (tb_field, tb_data_desc, tb_domain) + else: + raise ValueError(f"Unexpected concat mask {mask_domain[0]}.") + + # we use the concat domain, stored in the annex, as the domain of output field + output_domain = gtir_domain.extract_domain(node_domain) + output_dims, output_origin, output_shape = _get_concat_where_field_layout( + output_domain, concat_dim + ) + concat_dim_index = output_dims.index(concat_dim) + + """ + In case one of the arguments is a scalar value, for example: + ```python + @gtx.field_operator + def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim < 1, a, b) + ``` 
+ we convert it to a single-element 1D field with the dimension of the concat expression. + """ + if isinstance(lower.gt_type, ts.ScalarType): + assert len(lower_domain) == 0 + assert isinstance(upper.gt_type, ts.FieldType) + lower = gtir_to_sdfg_types.FieldopData( + lower.dc_node, + ts.FieldType(dims=[concat_dim], dtype=lower.gt_type), + origin=(concat_dim_bound - 1,), + ) + lower_bound = output_domain[concat_dim_index][1] + lower_domain = [(concat_dim, lower_bound, concat_dim_bound)] + elif isinstance(upper.gt_type, ts.ScalarType): + assert len(upper_domain) == 0 + assert isinstance(lower.gt_type, ts.FieldType) + upper = gtir_to_sdfg_types.FieldopData( + upper.dc_node, + ts.FieldType(dims=[concat_dim], dtype=upper.gt_type), + origin=(concat_dim_bound,), + ) + upper_bound = output_domain[concat_dim_index][2] + upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] + + if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr] + """ + The field on the lower domain is to be treated as a slice to add as one + level in the concat dimension, on the lower bound. 
+ Consider for example the following IR, where a horizontal field is added + as level zero in K-dimension: + ```python + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) + ``` + """ + assert ( + lower.gt_type.dims # type: ignore[union-attr] + == [ + *upper.gt_type.dims[0:concat_dim_index], # type: ignore[union-attr] + *upper.gt_type.dims[concat_dim_index + 1 :], # type: ignore[union-attr] + ] + ) + lower, lower_desc = _make_concat_field_slice( + sdfg, state, lower, lower_desc, concat_dim, concat_dim_index, concat_dim_bound - 1 + ) + lower_bound = dace.symbolic.pystr_to_symbolic( + f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})" + ) + lower_domain.insert(concat_dim_index, (concat_dim, lower_bound, concat_dim_bound)) + elif concat_dim not in upper.gt_type.dims: # type: ignore[union-attr] + # Same as previous case, but the field slice is added on the upper bound. + assert ( + upper.gt_type.dims # type: ignore[union-attr] + == [ + *lower.gt_type.dims[0:concat_dim_index], # type: ignore[union-attr] + *lower.gt_type.dims[concat_dim_index + 1 :], # type: ignore[union-attr] + ] + ) + upper, upper_desc = _make_concat_field_slice( + sdfg, state, upper, upper_desc, concat_dim, concat_dim_index, concat_dim_bound + ) + upper_bound = dace.symbolic.pystr_to_symbolic( + f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})" + ) + upper_domain.insert(concat_dim_index, (concat_dim, concat_dim_bound, upper_bound)) + elif isinstance(lower_desc, dace.data.Scalar) or ( + len(lower.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr] + ): + """ + The input on the lower domain is either a scalar or a 1d field, representing + the value(s) to be added as one level in the concat dimension below the upper domain. 
+ Consider for example the following IR, where the scalar value is one level + (`KDim == 0`) taken from lower input 'a': + ```python + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + ``` + """ + assert len(lower_domain) == 1 and lower_domain[0][0] == concat_dim + lower_domain = [ + *output_domain[:concat_dim_index], + lower_domain[0], + *output_domain[concat_dim_index + 1 :], + ] + lower, lower_desc = _make_concat_scalar_broadcast( + sdfg, state, lower, lower_desc, lower_domain, concat_dim_index + ) + elif isinstance(upper_desc, dace.data.Scalar) or ( + len(upper.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr] + ): + # Same as previous case, but the scalar value is taken from `upper` input. + assert len(upper_domain) == 1 and upper_domain[0][0] == concat_dim + upper_domain = [ + *output_domain[:concat_dim_index], + upper_domain[0], + *output_domain[concat_dim_index + 1 :], + ] + upper, upper_desc = _make_concat_scalar_broadcast( + sdfg, state, upper, upper_desc, upper_domain, concat_dim_index + ) + else: + """ + Handle here the _regular_ case, that is concat_where applied to two fields + with same domain: + ```python + @gtx.field_operator + def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim <=10 , a, b) + ``` + """ + assert isinstance(lower.gt_type, ts.FieldType) + assert isinstance(lower_desc, dace.data.Array) + assert isinstance(upper.gt_type, ts.FieldType) + assert isinstance(upper_desc, dace.data.Array) + if lower.gt_type.dims != upper.gt_type.dims: + raise NotImplementedError( + "Lowering concat_where on fields with different domain is not supported." 
+ ) + + # ensure that the arguments have the same domain as the concat result + assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) # type: ignore[union-attr] + + lower_range_0 = output_domain[concat_dim_index][1] + lower_range_1 = dace.symbolic.pystr_to_symbolic( + f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})" + ) + lower_range_size = lower_range_1 - lower_range_0 + + upper_range_1 = output_domain[concat_dim_index][2] + upper_range_0 = dace.symbolic.pystr_to_symbolic( + f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" + ) + upper_range_size = upper_range_1 - upper_range_0 + + output, output_desc = sdfg_builder.add_temp_array(sdfg, output_shape, lower_desc.dtype) + output_node = state.add_access(output) + + lower_subset = [] + lower_output_subset = [] + upper_subset = [] + upper_output_subset = [] + for dim_index, size in enumerate(output_desc.shape): + if dim_index == concat_dim_index: + lower_subset.append( + ( + lower_range_0 - lower.origin[dim_index], + lower_range_1 - lower.origin[dim_index] - 1, + 1, + ) + ) + upper_subset.append( + ( + upper_range_0 - upper.origin[dim_index], + upper_range_1 - upper.origin[dim_index] - 1, + 1, + ) + ) + # we write the data of the lower range into the output array starting + # from the index zero + lower_output_subset.append((0, lower_range_size - 1, 1)) + # the upper range should be written next to the lower range, so the + # destination subset does not start from index zero + upper_output_subset.append( + ( + lower_range_size, + lower_range_size + upper_range_size - 1, + 1, + ) + ) + else: + lower_subset.append( + ( + output_domain[dim_index][1] - lower.origin[dim_index], + output_domain[dim_index][1] - lower.origin[dim_index] + size - 1, + 1, + ) + ) + upper_subset.append( + ( + output_domain[dim_index][1] - upper.origin[dim_index], + output_domain[dim_index][1] - upper.origin[dim_index] + size - 1, + 1, + ) + ) + + lower_output_subset.append((0, size - 1, 1)) + 
upper_output_subset.append((0, size - 1, 1)) + + # use dynamic memlets because the subset range could be empty, but this is known only at runtime + state.add_nedge( + lower.dc_node, + output_node, + dace.Memlet( + data=lower.dc_node.data, + subset=dace_subsets.Range(lower_subset), + other_subset=dace_subsets.Range(lower_output_subset), + dynamic=True, + ), + ) + state.add_nedge( + upper.dc_node, + output_node, + dace.Memlet( + data=upper.dc_node.data, + subset=dace_subsets.Range(upper_subset), + other_subset=dace_subsets.Range(upper_output_subset), + dynamic=True, + ), + ) + + return gtir_to_sdfg_types.FieldopData(output_node, lower.gt_type, origin=tuple(output_origin)) + + +def translate_concat_where( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> gtir_to_sdfg_types.FieldopResult: + """ + Lowers a `concat_where` expression to a dataflow where two memlets write + disjoint subsets, for the lower and upper domain, on one data access node. + + Implements the `PrimitiveTranslator` protocol. + """ + assert cpm.is_call_to(node, "concat_where") + assert len(node.args) == 3 + + # First argument is a domain expression that defines the mask of the true branch: + # we extract the dimension along which we need to concatenate the field arguments, + # and determine whether the true branch argument should be on the lower or upper + # range with respect to the boundary value. 
+ mask_domain = gtir_domain.extract_domain(node.args[0]) + if len(mask_domain) != 1: + raise NotImplementedError("Expected `concat_where` along single axis.") + + # we visit the field arguments for the true and false branch + tb, fb = (sdfg_builder.visit(node.args[i], sdfg=sdfg, head_state=state) for i in [1, 2]) + + return ( + _translate_concat_where_impl( + sdfg, + state, + sdfg_builder, + mask_domain, + node.annex.domain, + node.args[1].annex.domain, + node.args[2].annex.domain, + tb, + fb, + ) + if isinstance(node.type, ts.FieldType) + else gtx_utils.tree_map( + lambda _node_domain, + _tb_node_domain, + _fb_node_domain, + _tb_field, + _fb_field, + _sdfg_builder=sdfg_builder, + _sdfg=sdfg, + _state=state, + _mask_domain=mask_domain: _translate_concat_where_impl( + _sdfg, + _state, + _sdfg_builder, + _mask_domain, + _node_domain, + _tb_node_domain, + _fb_node_domain, + _tb_field, + _fb_field, + ) + )(node.annex.domain, node.args[1].annex.domain, node.args[2].annex.domain, tb, fb) + ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index 7f9e13561e..fe75b3d0ad 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -26,6 +26,9 @@ gtir_to_sdfg_utils, utils as gtx_dace_utils, ) +from gt4py.next.program_processors.runners.dace.gtir_to_sdfg_concat_where import ( + translate_concat_where, +) from gt4py.next.program_processors.runners.dace.gtir_to_sdfg_scan import translate_scan from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -766,6 +769,7 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, + translate_concat_where, translate_if, translate_index, 
translate_literal, diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index ad3ff4bbfc..0bff0b0aa7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -160,7 +160,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ) EMBEDDED_SKIP_LIST = [ diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 8e56967a3a..7a46296ce7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -28,6 +28,7 @@ Cell, Edge, IDim, + JDim, MeshDescriptor, V2EDim, Vertex, @@ -2092,3 +2093,135 @@ def test_gtir_index(): sdfg(v, **FSYMBOLS) np.allclose(v, ref) + + +def test_gtir_concat_where(): + SUBSET_SIZE = 5 + assert SUBSET_SIZE < N + domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, N)}) + domain_cond_lhs = im.domain( + gtx_common.GridType.CARTESIAN, {IDim: (gtir.InfinityLiteral.NEGATIVE, N - SUBSET_SIZE)} + ) + domain_cond_rhs = im.domain( + gtx_common.GridType.CARTESIAN, {IDim: (SUBSET_SIZE, gtir.InfinityLiteral.POSITIVE)} + ) + domain_lhs = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, N - SUBSET_SIZE)}) + domain_rhs = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (N - SUBSET_SIZE, N)}) + + concat_expr_lhs = im.concat_where( + domain_cond_lhs, + im.as_fieldop("deref", domain_lhs)("x"), + im.as_fieldop("deref", domain_rhs)("y"), + ) + concat_expr_rhs = im.concat_where( + domain_cond_rhs, + im.as_fieldop("deref", domain_rhs)("y"), + im.as_fieldop("deref", 
domain_lhs)("x"), + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((a[:SUBSET_SIZE], b[SUBSET_SIZE:])) + + for concat_expr, suffix in [(concat_expr_lhs, "lhs"), (concat_expr_rhs, "rhs")]: + testee = gtir.Program( + id=f"gtir_concat_where_{suffix}", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="y", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="z", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=concat_expr, + domain=domain, + target=gtir.SymRef(id="z"), + ) + ], + ) + + # run domain inference in order to add the domain annex information to the concat_where node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + c = np.empty_like(a) + + sdfg(a, b, c, **FSYMBOLS) + np.allclose(c, ref) + + +def test_gtir_concat_where_two_dimensions(): + M, N = (30, 20) + domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain( + gtx_common.GridType.CARTESIAN, {JDim: (10, gtir.InfinityLiteral.POSITIVE)} + ) + domain_cond2 = im.domain( + gtx_common.GridType.CARTESIAN, {IDim: (gtir.InfinityLiteral.NEGATIVE, 20)} + ) + domain1 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + + testee = gtir.Program( + id=f"gtir_concat_where_two_dimensions", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)), + gtir.Sym(id="y", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)), + gtir.Sym(id="w", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)), + gtir.Sym(id="z", type=ts.FieldType(dims=[IDim, JDim], dtype=FLOAT_TYPE)), + ], 
+ declarations=[], + body=[ + gtir.SetAt( + expr=im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("x"), + im.as_fieldop("deref", domain2)("y"), + ), + im.as_fieldop("deref", domain3)("w"), + ), + domain=domain, + target=gtir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(M, N) + b = np.random.rand(M, N) + c = np.random.rand(M, N) + d = np.empty_like(a) + ref = np.concatenate( + (c[:, :10], np.concatenate((a[:20, :], b[20:, :]), axis=0)[:, 10:]), axis=1 + ) + + field_symbols = { + "__x_0_range_1": a.shape[0], + "__x_1_range_1": a.shape[1], + "__x_stride_0": a.strides[0] // a.itemsize, + "__x_stride_1": a.strides[1] // a.itemsize, + "__y_0_range_1": b.shape[0], + "__y_1_range_1": b.shape[1], + "__y_stride_0": b.strides[0] // b.itemsize, + "__y_stride_1": b.strides[1] // b.itemsize, + "__w_0_range_1": c.shape[0], + "__w_1_range_1": c.shape[1], + "__w_stride_0": c.strides[0] // c.itemsize, + "__w_stride_1": c.strides[1] // c.itemsize, + "__z_0_range_1": d.shape[0], + "__z_1_range_1": d.shape[1], + "__z_stride_0": d.strides[0] // d.itemsize, + "__z_stride_1": d.strides[1] // d.itemsize, + } + + # run domain inference in order to add the domain annex information to the concat_where node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, c, d, **field_symbols) + + np.allclose(d, ref)