Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import operator
import warnings
from dataclasses import dataclass
from functools import reduce
from typing import Any, Final
Expand Down Expand Up @@ -101,9 +102,26 @@ def visit_FieldAccess(self, node: oir.FieldAccess, ctx: Context, is_target: bool
# Gather all parts of the variable name in this list
name_parts = [tasklet_name]

# Domain sizes of data dimensions
data_domains: list[int] = (
ctx.tree.containers[node.name].shape[-len(node.data_index) :] if node.data_index else []
)

# Data dimension subscript
data_indices: list[str] = []
for index in node.data_index:
for dimension, index in enumerate(node.data_index):
# special case: domain size of this data dimension is one
if data_domains[dimension] == 1:
if not isinstance(index, oir.Literal):
warnings.warn(
f"Data dimension {dimension} of field {node.name} has static size one. Accessing without a Literal is suspicious.",
stacklevel=2,
)

# no need for data dimension subscript because that access
# is entirely captured in the memlet.
continue

data_indices.append(self.visit(index, ctx=ctx, is_target=False))

if isinstance(node.offset, oir.AbsoluteKIndex):
Expand Down Expand Up @@ -139,9 +157,6 @@ def visit_FieldAccess(self, node: oir.FieldAccess, ctx: Context, is_target: bool
return "".join(filter(None, name_parts))

# Build Memlet and add it to inputs/outputs
data_domains: list[int] = (
ctx.tree.containers[node.name].shape[-len(node.data_index) :] if node.data_index else []
)
memlet = Memlet(
data=node.name,
subset=_memlet_subset(node, data_domains, ctx),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pytest

from gt4py.cartesian.frontend.exceptions import GTScriptSyntaxError
import gt4py.cartesian.gtscript as gtscript
import gt4py.storage as gt_storage
from gt4py.cartesian.gtscript import Field, K
Expand Down Expand Up @@ -404,6 +405,58 @@ def test_mismatch(self, sample_stencil):
)
)

@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_data_dimension_1d(self, backend: str):
@gtscript.stencil(backend=backend)
def data_dimension_1d(field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))]):
with computation(FORWARD), interval(...):
field_out[0, 0][0] = 42.0

ones = gt_storage.ones(
shape=(2, 3),
dimensions=["I", "J"],
dtype=(np.float64, (1,)),
backend=backend,
aligned_index=(0, 0),
)
data_dimension_1d(ones)
assert ones[0, 0][0] == 42.0

@pytest.mark.requires_dace
def test_data_dimension_1d_warning(self):
backend = "dace:cpu"

# Expect parsing to emit a warning
with pytest.warns(UserWarning, match="Accessing without a Literal is suspicious."):

@gtscript.stencil(backend=backend)
def data_dimension_1d_warning(
field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))], index: int
):
with computation(FORWARD), interval(...):
# 1. warn about accessing an array of static size 1 with a variable index
# 2. add a runtime check and error out in case index != 0
field_out[0, 0][index] = 42.0

ones = gt_storage.ones(
shape=(2, 3),
dimensions=["I", "J"],
dtype=(np.float64, (1,)),
backend=backend,
aligned_index=(0, 0),
)
data_dimension_1d_warning(ones, index=0)
assert ones[0, 0][0] == 42.0

def test_data_dimensions_1d_error(self):
# Expect out of bounds write to be caught at stencil parse time
with pytest.raises(GTScriptSyntaxError, match="Data index out of bounds."):

@gtscript.stencil(backend=self.backend)
def data_dimension_1d_error(field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))]):
with computation(FORWARD), interval(...):
field_out[0, 0][1] = 42.0


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_origin_unchanged(backend):
Expand Down
Loading