Skip to content
5 changes: 4 additions & 1 deletion src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def compile(
| common.OffsetProvider
| list[common.OffsetProviderType | common.OffsetProvider]
| None = None,
static_domains: Optional[dict[common.Domain, int] | None] = None,
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
) -> Self:
"""
Expand Down Expand Up @@ -194,7 +195,9 @@ def compile(
for op in offset_provider
)

self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
self._compiled_programs.compile(
offset_providers=offset_provider, static_domains=static_domains, **static_args
)
return self


Expand Down
49 changes: 38 additions & 11 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import itertools
import warnings
from collections.abc import Callable, Hashable, Sequence
from typing import Any, Generic, TypeAlias, TypeVar
from typing import Any, Generic, Optional, TypeAlias, TypeVar

from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping, utils as eve_utils
Expand All @@ -28,7 +28,7 @@
)
from gt4py.next.instrumentation import hook_machinery, metrics
from gt4py.next.otf import arguments, stages
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.utils import tree_map


Expand Down Expand Up @@ -605,6 +605,7 @@ def _compile_variant(
def compile(
self,
offset_providers: list[common.OffsetProvider | common.OffsetProviderType],
static_domains: Optional[dict[common.Domain, int] | None] = None,
**static_args: list[ScalarOrTupleOfScalars],
) -> None:
"""
Expand All @@ -619,17 +620,43 @@ def compile(
pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3])
will compile for (0,2), (1,3)
"""

def _build_field_domain_descriptors(program_type, static_domains):
def _create_field_descriptor(field_type):
domain_ranges = {
dim: static_domains[dim] for dim in field_type.dims
} # TODO: improve error message
return arguments.FieldDomainDescriptor(common.domain(domain_ranges))

field_domain_descriptors = {}
for arg_name, arg_type_ in program_type.definition.pos_or_kw_args.items():
for el_type_, path in type_info.primitive_constituents(
arg_type_, with_path_arg=True
):
if isinstance(el_type_, ts.FieldType):
path_as_expr = "".join(map(lambda idx: f"[{idx}]", path))
field_domain_descriptors[f"{arg_name}{path_as_expr}"] = (
_create_field_descriptor(el_type_)
)

return field_domain_descriptors

for offset_provider in offset_providers: # not included in product for better type checking
for static_values in itertools.product(*static_args.values()):
argument_descriptor_dict = {
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
}
if static_domains:
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_domains is checked via truthiness (if static_domains:), so passing an empty dict (which is still a non-None value per the public API) will silently skip building FieldDomainDescriptors. This can lead to compiling a variant without any domain descriptors even though the caller explicitly provided static_domains. Consider switching this condition to if static_domains is not None: (and let _build_field_domain_descriptors raise for missing dims as needed).

Suggested change
if static_domains:
if static_domains is not None:

Copilot uses AI. Check for mistakes.
argument_descriptor_dict[arguments.FieldDomainDescriptor] = (
_build_field_domain_descriptors(self.program_type, static_domains)
)
Comment on lines +700 to +702
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a # type: ignore[assignment] on the FieldDomainDescriptor insertion because _build_field_domain_descriptors returns dict[str, FieldDomainDescriptor] while argument_descriptor_dict is typed as dict[str, ArgStaticDescriptor]. Consider widening the helper’s return type (e.g., to dict[str, arguments.ArgStaticDescriptor]) or casting at the assignment site to avoid relying on type: ignore here.

Copilot uses AI. Check for mistakes.
self._compile_variant(
argument_descriptors={
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
},
argument_descriptor_dict,
offset_provider=offset_provider,
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
skip_value_mesh,
)

from gt4py.next.otf import arguments

_raise_on_compile = mock.Mock()
_raise_on_compile.compile.side_effect = AssertionError("This function should never be called.")
Expand All @@ -49,7 +48,7 @@ class NamedTupleNamedCollection(NamedTuple):
@pytest.fixture(
params=[
pytest.param(True, id="program"),
pytest.param(False, id="field-operator"),
# pytest.param(False, id="field-operator"),
]
)
def compile_testee(request, cartesian_case):
Expand All @@ -62,6 +61,7 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField):
testee_op(a, b, out=out)

wrap_in_program = request.param
print(f"HUHU{id(testee)}", flush=True)
if wrap_in_program:
return testee
else:
Expand Down Expand Up @@ -991,3 +991,56 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo
arguments.FieldDomainDescriptor(out[1].domain),
),
}


def test_compile_with_static_domains(compile_variants_field_operator, cartesian_case):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

captured_cargs: Optional[arguments.CompileTimeArgs] = None

class CaptureCompileTimeArgsBackend:
def __getattr__(self, name):
return getattr(cartesian_case.backend, name)

def compile(self, program, compile_time_args):
nonlocal captured_cargs
captured_cargs = compile_time_args

return cartesian_case.backend.compile(program, compile_time_args)

@gtx.field_operator
def identity_like(inp: tuple[cases.IField, cases.IField, float]):
return inp[0], inp[1]

# the float argument here is merely to test that static domains work for tuple arguments
# of inhomogeneous types
@gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True)
def testee(
inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField]
):
identity_like(inp, out=out)

inp = cases.allocate(cartesian_case, testee, "inp")()
out = cases.allocate(cartesian_case, testee, "out")()

testee.compile(
offset_provider=cartesian_case.offset_provider,
static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()},
)

assert testee._compiled_programs.argument_descriptor_mapping[
arguments.FieldDomainDescriptor
] == ["inp[0]", "inp[1]", "out[0]", "out[1]"]

assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == {
"inp": (
arguments.FieldDomainDescriptor(inp[0].domain),
arguments.FieldDomainDescriptor(inp[1].domain),
None,
),
"out": (
arguments.FieldDomainDescriptor(out[0].domain),
arguments.FieldDomainDescriptor(out[1].domain),
),
}
Loading