Skip to content

Commit 52c7b7f

Browse files
committed
feature: force annotated temporaries
1 parent ed71f11 commit 52c7b7f

3 files changed

Lines changed: 66 additions & 0 deletions

File tree

src/gt4py/cartesian/definitions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import functools
1111
import os
1212
import platform
13+
import warnings
1314
from dataclasses import dataclass
1415
from typing import Literal, Tuple, Union
1516

@@ -42,6 +43,28 @@
4243
"""Default literal precision used for unspecific `float` types and casts."""
4344

4445

46+
def _check_boolean_env_var(name: str, default: bool) -> bool:
47+
envvar = os.environ.get(name, default=default)
48+
if type(envvar) is bool:
49+
return envvar
50+
51+
if type(envvar) is str:
52+
if envvar.lower() in ["true", "1"]:
53+
return True
54+
if envvar in ["false", "0"]:
55+
return False
56+
57+
warnings.warn(
58+
f"Could not match `{name}={envvar}` into a boolean value. Falling back to the default `{default}`.",
59+
stacklevel=2,
60+
)
61+
return default
62+
63+
64+
FORCE_ANNOTATED_TEMPORARIES = _check_boolean_env_var("GT4PY_FORCE_ANNOTATED_TEMPORARIES", False)
65+
"""If True, forces all temporaries in stencils to have type annotations."""
66+
67+
4568
@enum.unique
4669
class AccessKind(enum.IntFlag):
4770
NONE = 0
@@ -123,6 +146,8 @@ class BuildOptions(AttributeClassLike):
123146
"Literal precision for `int` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_INT_PRECISION`."
124147
literal_float_precision = attribute(of=int, default=LITERAL_FLOAT_PRECISION)
125148
"Literal precision for `float` types and casts. Defaults to architecture precision unless overwritten by the environment variable `GT4PY_LITERAL_FLOAT_PRECISION`."
149+
force_annotated_temporaries = attribute(of=bool, default=FORCE_ANNOTATED_TEMPORARIES)
150+
"If True, enforce all temporaries to have type annotations. Defaults to False unless overwritten by the environment variable `GT4PY_FORCE_ANNOTATED_TEMPORARIES`."
126151

127152
@property
128153
def qualified_name(self):

src/gt4py/cartesian/frontend/gtscript_frontend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ def __init__(
799799
self.domain = domain or nodes.Domain.LatLonGrid()
800800
self.literal_int_precision = options.literal_int_precision
801801
self.literal_float_precision = options.literal_float_precision
802+
self.force_annotated_temporaries = options.force_annotated_temporaries
802803
self.temp_decls = temp_decls or {}
803804
self.parsing_context = None
804805
self.iteration_order = None
@@ -1705,6 +1706,11 @@ def _resolve_assign(
17051706
message="Temporaries with data dimensions need to be declared explicitly.",
17061707
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
17071708
)
1709+
if self.force_annotated_temporaries and target_annotation is None:
1710+
raise GTScriptSyntaxError(
1711+
message=f"Missing type hint for '{name}' in stencil '{self.stencil_name}'.",
1712+
loc=nodes.Location.from_ast_node(t, scope=self.stencil_name),
1713+
)
17081714
dtype = nodes.DataType.AUTO
17091715
axes = nodes.Domain.LatLonGrid().axes_names
17101716
if target_annotation is not None:

tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def parse_definition(
5757
dtypes: Dict[Type, Type] = None,
5858
literal_int_precision: int | None = None,
5959
literal_float_precision: int | None = None,
60+
force_annotated_temporaries: bool | None = None,
6061
rebuild=False,
6162
**kwargs,
6263
) -> nodes.StencilDefinition:
@@ -73,6 +74,8 @@ def parse_definition(
7374
build_args["literal_int_precision"] = literal_int_precision
7475
if literal_float_precision is not None:
7576
build_args["literal_float_precision"] = literal_float_precision
77+
if force_annotated_temporaries is not None:
78+
build_args["force_annotated_temporaries"] = force_annotated_temporaries
7679

7780
build_options = gt_definitions.BuildOptions(**build_args)
7881

@@ -2222,3 +2225,35 @@ def stencil(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
22222225
name=inspect.stack()[0][3],
22232226
module=self.__class__.__name__,
22242227
)
2228+
2229+
2230+
class TestForceAnnotatedTemporaries:
2231+
def test_missing_annotation(self):
2232+
def good_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
2233+
with computation(PARALLEL), interval(...):
2234+
tmp: float = 2 * in_field
2235+
out_field = tmp
2236+
2237+
parsed = parse_definition(
2238+
good_case, name=inspect.stack()[0][3], module=self.__class__.__name__
2239+
)
2240+
2241+
declaration = parsed.computations[0].body.stmts[0]
2242+
assert isinstance(declaration, nodes.FieldDecl)
2243+
assert declaration.data_type == nodes.DataType.FLOAT64
2244+
2245+
def bad_case(in_field: gtscript.Field[float], out_field: gtscript.Field[float]):
2246+
with computation(PARALLEL), interval(...):
2247+
tmp = 2 * in_field
2248+
out_field = tmp
2249+
2250+
with pytest.raises(
2251+
gt_frontend.GTScriptSyntaxError,
2252+
match="Missing type hint for 'tmp' in stencil 'bad_case'.",
2253+
):
2254+
parse_definition(
2255+
bad_case,
2256+
name=inspect.stack()[0][3],
2257+
module=self.__class__.__name__,
2258+
force_annotated_temporaries=True,
2259+
)

0 commit comments

Comments
 (0)