Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions src/gt4py/next/program_processors/runners/dace/gtir_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import dataclasses
from typing import Optional, Sequence, TypeAlias

import dace
Expand All @@ -19,15 +20,24 @@
from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils


FieldopDomain: TypeAlias = list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
]
"""
Domain of a field operator represented as a list of tuples with 3 elements:
- dimension definition
- symbolic expression for lower bound (inclusive)
- symbolic expression for upper bound (exclusive)
"""
@dataclasses.dataclass(frozen=True)
class FieldopDomainRange:
"""
Represents the range of a field operator domain in one dimension.

It contains 3 elements:
dim: dimension definition
start: symbolic expression for lower bound (inclusive)
stop: symbolic expression for upper bound (exclusive)
"""

dim: gtx_common.Dimension
start: dace.symbolic.SymbolicType
stop: dace.symbolic.SymbolicType


FieldopDomain: TypeAlias = list[FieldopDomainRange]
"""Domain of a field operator represented as a list of `FieldopDomainRange` for each dimension."""


def extract_domain(node: gtir.Expr) -> FieldopDomain:
Expand All @@ -49,12 +59,12 @@ def extract_domain(node: gtir.Expr) -> FieldopDomain:
gtir_to_sdfg_utils.get_symbolic(arg) for arg in named_range.args[1:3]
)
dim = gtx_common.Dimension(axis.value, axis.kind)
domain.append((dim, lower_bound, upper_bound))
domain.append(FieldopDomainRange(dim, lower_bound, upper_bound))

elif isinstance(node, domain_utils.SymbolicDomain):
for dim, drange in node.ranges.items():
domain.append(
(
FieldopDomainRange(
dim,
gtir_to_sdfg_utils.get_symbolic(drange.start),
gtir_to_sdfg_utils.get_symbolic(drange.stop),
Expand Down Expand Up @@ -119,6 +129,7 @@ def get_field_layout(
"""
if len(domain) == 0:
return [], [], []
domain_dims, domain_lbs, domain_ubs = zip(*domain)
domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)]
return list(domain_dims), list(domain_lbs), domain_sizes
domain_dims = [domain_range.dim for domain_range in domain]
domain_origin = [domain_range.start for domain_range in domain]
domain_shape = [(domain_range.stop - domain_range.start) for domain_range in domain]
return domain_dims, domain_origin, domain_shape
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def _make_access_index_for_field(
# since the access indices have to follow the order of dimensions in field domain
if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0:
assert data.origin is not None
domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain}
domain_ranges = {
domain_range.dim: (domain_range.start, domain_range.stop) for domain_range in domain
}
return dace.subsets.Range(
(domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1)
for dim, origin in zip(data.gt_type.dims, data.origin, strict=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,20 @@ def _translate_concat_where_impl(
gtir_domain.extract_domain(domain) for domain in [tb_node_domain, fb_node_domain]
)
assert len(mask_domain) == 1
concat_dim, mask_lower_bound, mask_upper_bound = mask_domain[0]
concat_domain = mask_domain[0]

# Expect unbound range in the concat domain expression on lower or upper range:
# - if the domain expression is unbound on lower side (negative infinite),
# the expression on the true branch is to be considered the input for the
# lower domain.
# - viceversa, if the domain expression is unbound on upper side (positive
# infinite), the true expression represents the input for the upper domain.
if mask_lower_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE):
concat_dim_bound = mask_upper_bound
if concat_domain.start == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.NEGATIVE):
concat_dim_bound = concat_domain.stop
lower, lower_desc, lower_domain = (tb_field, tb_data_desc, tb_domain)
upper, upper_desc, upper_domain = (fb_field, fb_data_desc, fb_domain)
elif mask_upper_bound == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE):
concat_dim_bound = mask_lower_bound
elif concat_domain.stop == gtir_to_sdfg_utils.get_symbolic(gtir.InfinityLiteral.POSITIVE):
concat_dim_bound = concat_domain.start
lower, lower_desc, lower_domain = (fb_field, fb_data_desc, fb_domain)
upper, upper_desc, upper_domain = (tb_field, tb_data_desc, tb_domain)
else:
Expand All @@ -207,9 +207,9 @@ def _translate_concat_where_impl(
# we use the concat domain, stored in the annex, as the domain of output field
output_domain = gtir_domain.extract_domain(node_domain)
output_dims, output_origin, output_shape = _get_concat_where_field_layout(
output_domain, concat_dim
output_domain, concat_domain.dim
)
concat_dim_index = output_dims.index(concat_dim)
concat_dim_index = output_dims.index(concat_domain.dim)

"""
In case one of the arguments is a scalar value, for example:
Expand All @@ -225,23 +225,27 @@ def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField:
assert isinstance(upper.gt_type, ts.FieldType)
lower = gtir_to_sdfg_types.FieldopData(
lower.dc_node,
ts.FieldType(dims=[concat_dim], dtype=lower.gt_type),
ts.FieldType(dims=[concat_domain.dim], dtype=lower.gt_type),
origin=(concat_dim_bound - 1,),
)
lower_bound = output_domain[concat_dim_index][1]
lower_domain = [(concat_dim, lower_bound, concat_dim_bound)]
lower_bound = output_domain[concat_dim_index].start
lower_domain = [
gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound)
]
elif isinstance(upper.gt_type, ts.ScalarType):
assert len(upper_domain) == 0
assert isinstance(lower.gt_type, ts.FieldType)
upper = gtir_to_sdfg_types.FieldopData(
upper.dc_node,
ts.FieldType(dims=[concat_dim], dtype=upper.gt_type),
ts.FieldType(dims=[concat_domain.dim], dtype=upper.gt_type),
origin=(concat_dim_bound,),
)
upper_bound = output_domain[concat_dim_index][2]
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]
upper_bound = output_domain[concat_dim_index].stop
upper_domain = [
gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound)
]

if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
if concat_domain.dim not in lower.gt_type.dims: # type: ignore[union-attr]
"""
The field on the lower domain is to be treated as a slice to add as one
level in the concat dimension, on the lower bound.
Expand All @@ -261,13 +265,22 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
]
)
lower, lower_desc = _make_concat_field_slice(
sdfg, state, lower, lower_desc, concat_dim, concat_dim_index, concat_dim_bound - 1
sdfg=sdfg,
state=state,
field=lower,
field_desc=lower_desc,
concat_dim=concat_domain.dim,
concat_dim_index=concat_dim_index,
concat_dim_origin=concat_dim_bound - 1,
)
lower_bound = dace.symbolic.pystr_to_symbolic(
f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index][1]})"
f"max({concat_dim_bound - 1}, {output_domain[concat_dim_index].start})"
)
lower_domain.insert(
concat_dim_index,
gtir_domain.FieldopDomainRange(concat_domain.dim, lower_bound, concat_dim_bound),
)
lower_domain.insert(concat_dim_index, (concat_dim, lower_bound, concat_dim_bound))
elif concat_dim not in upper.gt_type.dims: # type: ignore[union-attr]
elif concat_domain.dim not in upper.gt_type.dims: # type: ignore[union-attr]
# Same as previous case, but the field slice is added on the upper bound.
assert (
upper.gt_type.dims # type: ignore[union-attr]
Expand All @@ -277,12 +290,21 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
]
)
upper, upper_desc = _make_concat_field_slice(
sdfg, state, upper, upper_desc, concat_dim, concat_dim_index, concat_dim_bound
sdfg=sdfg,
state=state,
field=upper,
field_desc=upper_desc,
concat_dim=concat_domain.dim,
concat_dim_index=concat_dim_index,
concat_dim_origin=concat_dim_bound,
)
upper_bound = dace.symbolic.pystr_to_symbolic(
f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index][2]})"
f"min({concat_dim_bound + 1}, {output_domain[concat_dim_index].stop})"
)
upper_domain.insert(
concat_dim_index,
gtir_domain.FieldopDomainRange(concat_domain.dim, concat_dim_bound, upper_bound),
)
upper_domain.insert(concat_dim_index, (concat_dim, concat_dim_bound, upper_bound))
elif isinstance(lower_desc, dace.data.Scalar) or (
len(lower.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr]
):
Expand All @@ -297,27 +319,37 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim == 0, a, b)
```
"""
assert len(lower_domain) == 1 and lower_domain[0][0] == concat_dim
assert len(lower_domain) == 1 and lower_domain[0].dim == concat_domain.dim
lower_domain = [
*output_domain[:concat_dim_index],
lower_domain[0],
*output_domain[concat_dim_index + 1 :],
]
lower, lower_desc = _make_concat_scalar_broadcast(
sdfg, state, lower, lower_desc, lower_domain, concat_dim_index
sdfg=sdfg,
state=state,
inp=lower,
inp_desc=lower_desc,
domain=lower_domain,
concat_dim_index=concat_dim_index,
)
elif isinstance(upper_desc, dace.data.Scalar) or (
len(upper.gt_type.dims) == 1 and len(output_domain) > 1 # type: ignore[union-attr]
):
# Same as previous case, but the scalar value is taken from `upper` input.
assert len(upper_domain) == 1 and upper_domain[0][0] == concat_dim
assert len(upper_domain) == 1 and upper_domain[0].dim == concat_domain.dim
upper_domain = [
*output_domain[:concat_dim_index],
upper_domain[0],
*output_domain[concat_dim_index + 1 :],
]
upper, upper_desc = _make_concat_scalar_broadcast(
sdfg, state, upper, upper_desc, upper_domain, concat_dim_index
sdfg=sdfg,
state=state,
inp=upper,
inp_desc=upper_desc,
domain=upper_domain,
concat_dim_index=concat_dim_index,
)
else:
"""
Expand All @@ -341,15 +373,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
# ensure that the arguments have the same domain as the concat result
assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) # type: ignore[union-attr]

lower_range_0 = output_domain[concat_dim_index][1]
lower_range_0 = output_domain[concat_dim_index].start
lower_range_1 = dace.symbolic.pystr_to_symbolic(
f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})"
f"max({lower_range_0}, {lower_domain[concat_dim_index].stop})"
)
lower_range_size = lower_range_1 - lower_range_0

upper_range_1 = output_domain[concat_dim_index][2]
upper_range_1 = output_domain[concat_dim_index].stop
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
f"min({upper_range_1}, {upper_domain[concat_dim_index].start})"
)
upper_range_size = upper_range_1 - upper_range_0

Expand Down Expand Up @@ -391,15 +423,15 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
else:
lower_subset.append(
(
output_domain[dim_index][1] - lower.origin[dim_index],
output_domain[dim_index][1] - lower.origin[dim_index] + size - 1,
output_domain[dim_index].start - lower.origin[dim_index],
output_domain[dim_index].start - lower.origin[dim_index] + size - 1,
1,
)
)
upper_subset.append(
(
output_domain[dim_index][1] - upper.origin[dim_index],
output_domain[dim_index][1] - upper.origin[dim_index] + size - 1,
output_domain[dim_index].start - upper.origin[dim_index],
output_domain[dim_index].start - upper.origin[dim_index] + size - 1,
1,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ def _create_field_operator(
else:
# create map range corresponding to the field operator domain
map_range = {
gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
gtir_to_sdfg_utils.get_map_variable(
domain_range.dim
): f"{domain_range.start}:{domain_range.stop}"
for domain_range in domain
}
map_entry, map_exit = sdfg_builder.add_map("fieldop", state, map_range)

Expand Down Expand Up @@ -511,8 +513,7 @@ def translate_index(
assert "domain" in node.annex
domain = gtir_domain.extract_domain(node.annex.domain)
assert len(domain) == 1
dim, _, _ = domain[0]
dim_index = gtir_to_sdfg_utils.get_map_variable(dim)
dim_index = gtir_to_sdfg_utils.get_map_variable(domain[0].dim)

index_data, _ = sdfg_builder.add_temp_scalar(sdfg, gtir_to_sdfg_types.INDEX_DTYPE)
index_node = state.add_access(index_data)
Expand Down
36 changes: 17 additions & 19 deletions src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,11 @@ def _create_scan_field_operator(
"fieldop",
state,
ndrange={
gtir_to_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
if not sdfg_builder.is_column_axis(dim)
gtir_to_sdfg_utils.get_map_variable(
domain_range.dim
): f"{domain_range.start}:{domain_range.stop}"
for domain_range in domain
if not sdfg_builder.is_column_axis(domain_range.dim)
},
)

Expand Down Expand Up @@ -329,22 +331,18 @@ def _lower_lambda_to_nested_sdfg(
)

# use the vertical dimension in the domain as scan dimension
scan_domain = [
(dim, lower_bound, upper_bound)
for dim, lower_bound, upper_bound in domain
if sdfg_builder.is_column_axis(dim)
]
assert len(scan_domain) == 1
scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0]
scan_domain = next(
domain_range for domain_range in domain if sdfg_builder.is_column_axis(domain_range.dim)
)

# extract the scan loop range
scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_dim)
scan_loop_var = gtir_to_sdfg_utils.get_map_variable(scan_domain.dim)

# in case the scan operator computes a list (not a scalar), we need to add an extra dimension
def get_scan_output_shape(
scan_init_data: gtir_to_sdfg_types.FieldopData,
) -> list[dace.symbolic.SymExpr]:
scan_column_size = scan_upper_bound - scan_lower_bound
scan_column_size = scan_domain.stop - scan_domain.start
if isinstance(scan_init_data.gt_type, ts.ScalarType):
return [scan_column_size]
assert isinstance(scan_init_data.gt_type, ts.ListType)
Expand Down Expand Up @@ -391,18 +389,18 @@ def init_scan_carry(sym: gtir.Sym) -> None:
if scan_forward:
scan_loop = dace.sdfg.state.LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} < {scan_upper_bound}",
condition_expr=f"{scan_loop_var} < {scan_domain.stop}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_lower_bound}",
initialize_expr=f"{scan_loop_var} = {scan_domain.start}",
update_expr=f"{scan_loop_var} = {scan_loop_var} + 1",
inverted=False,
)
else:
scan_loop = dace.sdfg.state.LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} >= {scan_lower_bound}",
condition_expr=f"{scan_loop_var} >= {scan_domain.start}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1",
initialize_expr=f"{scan_loop_var} = {scan_domain.stop} - 1",
update_expr=f"{scan_loop_var} = {scan_loop_var} - 1",
inverted=False,
)
Expand Down Expand Up @@ -431,7 +429,7 @@ def init_scan_carry(sym: gtir.Sym) -> None:
for edge in lambda_input_edges:
edge.connect(map_entry=None)
# connect the dataflow output nodes, called 'scan_result' below, to a global field called 'output'
output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound
output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_domain.start

def connect_scan_output(
scan_output_edge: gtir_dataflow.DataflowOutputEdge,
Expand Down Expand Up @@ -475,8 +473,8 @@ def connect_scan_output(
dace.Memlet.from_array(scan_result_data, scan_result_desc),
)

output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype)
return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_lower_bound,))
output_type = ts.FieldType(dims=[scan_domain.dim], dtype=scan_result.gt_dtype)
return gtir_to_sdfg_types.FieldopData(output_node, output_type, origin=(scan_domain.start,))

# write the stencil result (value on one vertical level) into a 1D field
# with full vertical shape representing one column
Expand Down
Loading