Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ def test_typing_exports(session: nox.Session) -> None:
"-sv",
"--mypy-testing-base",
"typing_tests",
"--mypy-only-local-stub",
"typing_tests",
*session.posargs,
)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ typing = [
typing_exports = [
# to test typing with gt4py in downstream code
{include-group = "typing"},
'types-six', # can not let mypy auto-install types as that leads to unexpected stderr output (which means test failure)
'pytest-mypy-plugins', # pytest plugin for running mypy on code snippets
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see change in noxfile

"xarray" # one of the regression tests requires xarray
'types-six>=1.17.0.20251009', # can not let mypy auto-install types as that leads to unexpected stderr output (which means test failure)
'pytest-mypy-plugins>=4.0.0', # pytest plugin for running mypy on code snippets
"xarray>=2024.1.0" # one of the regression tests requires xarray
]

# -- Standard project description options (PEP 621) --
Expand Down
20 changes: 11 additions & 9 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,21 +486,23 @@ def infer_expr(

if cpm.is_applied_as_fieldop(expr) and cpm.is_call_to(expr.fun.args[0], "scan"):
additional_dims = gtx_utils.tree_map(
lambda d: _extract_vertical_dims(d)
if isinstance(d, domain_utils.SymbolicDomain)
else {}
lambda d: (
_extract_vertical_dims(d) if isinstance(d, domain_utils.SymbolicDomain) else {}
)
)(domain)
else:
additional_dims = gtx_utils.tree_map(lambda d: {})(domain)

domain = gtx_utils.tree_map(
lambda d, t, a: _filter_domain_dimensions(
d,
type_info.extract_dims(t),
additional_dims=a,
lambda d, t, a: (
_filter_domain_dimensions(
d,
type_info.extract_dims(t),
additional_dims=a,
)
if not isinstance(t, ts.DeferredType) and isinstance(d, domain_utils.SymbolicDomain)
else d
)
if not isinstance(t, ts.DeferredType) and isinstance(d, domain_utils.SymbolicDomain)
else d
)(domain, el_types, additional_dims)

expr, accessed_domains = _infer_expr(
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,12 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None:
# the target can have fewer elements than the expr in which case the output from the
# expression is simply discarded.
expr_type = functools.reduce(
lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents
# `ts.DeferredType` only occurs for scans returning a tuple
if not isinstance(tuple_type, ts.DeferredType)
else ts.DeferredType(constraint=None),
lambda tuple_type, i: (
tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents
# `ts.DeferredType` only occurs for scans returning a tuple
if not isinstance(tuple_type, ts.DeferredType)
else ts.DeferredType(constraint=None)
),
path,
node.expr.type,
)
Expand Down
18 changes: 11 additions & 7 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ def _values_validator(
if not all(
isinstance(el, (SidFromScalar, SidComposite))
or _is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")),
lambda expr: (
isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index"))
),
el,
)
for el in value
Expand Down Expand Up @@ -184,11 +186,13 @@ def _arg_validator(
) -> None:
for inp in inputs:
if not _is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar))
or (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "index"
lambda expr: (
isinstance(expr, (SymRef, SidComposite, SidFromScalar))
or (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "index"
)
),
inp,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,15 +769,15 @@ def _visit_if_branch(
psymbol_tree = gtir_to_sdfg_utils.make_symbol_tree(pname, ptype)
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = gtx_utils.tree_map(
lambda tsym,
targ,
deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
str(tsym.id),
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: (
self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
str(tsym.id),
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
)
)
)(psymbol_tree, arg)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,12 @@ def setup_nested_context(
eve.walk_values(expr)
.filter(lambda node: cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")))
.map(
lambda domain: eve.walk_values(domain)
.if_isinstance(gtir.SymRef)
.filter(lambda sym: str(sym.id) in lambda_symbols)
.to_set()
lambda domain: (
eve.walk_values(domain)
.if_isinstance(gtir.SymRef)
.filter(lambda sym: str(sym.id) in lambda_symbols)
.to_set()
)
)
.reduce(lambda x, y: x | y, init=set())
)
Expand Down Expand Up @@ -922,9 +924,11 @@ def _visit_expression(

if use_temp: # copy the full shape of global data to temporary storage
return gtx_utils.tree_map(
lambda x: x
if x is None or x.dc_node.desc(ctx.sdfg).transient
else ctx.copy_data(self, x, domain=None)
lambda x: (
x
if x is None or x.dc_node.desc(ctx.sdfg).transient
else ctx.copy_data(self, x, domain=None)
)
)(result)
else:
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,13 @@ def _create_scan_field_operator(
)

return gtx_utils.tree_map(
lambda edge, domain, sym: (
_create_scan_field_operator_impl(
ctx,
sdfg_builder,
edge,
domain,
sym.type,
map_exit,
)
lambda edge, domain, sym: _create_scan_field_operator_impl(
ctx,
sdfg_builder,
edge,
domain,
sym.type,
map_exit,
)
)(output, output_domain, dummy_output_symbol)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def tree_map(

>>> tree_map(
... collection_type=(list, tuple),
... result_collection_constructor=lambda value, elts: tuple(elts)
... if isinstance(value, list)
... else list(elts),
... result_collection_constructor=lambda value, elts: (
... tuple(elts) if isinstance(value, list) else list(elts)
... ),
... )(lambda x: x + 1)([(1, 2), 3])
([2, 3], 4)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def contains_cast_node(cast_node, expr):
expr.walk_values()
.if_isinstance(Cast)
.filter(
lambda node: node.dtype == cast_node.dtype
and (isinstance(cast_node.expr, Placeholder) or node.expr == cast_node.expr)
lambda node: (
node.dtype == cast_node.dtype
and (isinstance(cast_node.expr, Placeholder) or node.expr == cast_node.expr)
)
)
.to_list()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ def testee(a: tuple[int32, tuple[int32, cases.IField, int32]]) -> cases.IField:
cases.verify_with_default_data(
cartesian_case,
testee,
ref=lambda a: np.full(
[cartesian_case.default_sizes[IDim]], a[0] + 2 * a[1][0] + 5 * a[1][2], dtype=int32
)
+ 3 * a[1][1],
ref=lambda a: (
np.full(
[cartesian_case.default_sizes[IDim]], a[0] + 2 * a[1][0] + 5 * a[1][2], dtype=int32
)
+ 3 * a[1][1]
),
)


Expand Down Expand Up @@ -942,8 +944,8 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField:
cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a, b: (
np.sum(b[v2e_table], axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE)
ref=lambda a, b: np.sum(
b[v2e_table], axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE
),
)

Expand Down Expand Up @@ -992,9 +994,11 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]):
cartesian_case,
testee,
ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)),
comparison=lambda ref, out: np.all(out[0] == ref[0])
and np.all(out[1][0] == ref[1][0])
and np.all(out[1][1] == ref[1][1]),
comparison=lambda ref, out: (
np.all(out[0] == ref[0])
and np.all(out[1][0] == ref[1][0])
and np.all(out[1][1] == ref[1][1])
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ def fencil(edge_f: cases.EField, out: cases.VField):
cases.verify_with_default_data(
unstructured_case,
fencil,
ref=lambda edge_f: 3
* np.sum(
-(edge_f[v2e_table] ** 2) * 2,
axis=1,
initial=0,
where=v2e_table != common._DEFAULT_SKIP_VALUE,
ref=lambda edge_f: (
3
* np.sum(
-(edge_f[v2e_table] ** 2) * 2,
axis=1,
initial=0,
where=v2e_table != common._DEFAULT_SKIP_VALUE,
)
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> case
cases.verify_with_default_data(
cartesian_case,
basic_trig_fieldop,
ref=lambda inp1, inp2: np.sin(np.cos(inp1))
- np.sinh(np.cosh(inp2))
+ np.tan(inp1)
- np.tanh(inp2),
ref=lambda inp1, inp2: (
np.sin(np.cos(inp1)) - np.sinh(np.cosh(inp2)) + np.tan(inp1) - np.tanh(inp2)
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


def assert_close(expected, actual):
assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual)
assert np.allclose(expected, actual), "expected={}, actual={}".format(expected, actual)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manual change

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the new numpy version returns a 1d array with 1 element in the reduction that is passed to this function.



class nabla_setup:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
gtx.float64("1.0"),
),
],
ids=lambda param: f"Literal[{param.value}, {param.type}]"
if isinstance(param, itir.Literal)
else str(param),
ids=lambda param: (
f"Literal[{param.value}, {param.type}]" if isinstance(param, itir.Literal) else str(param)
),
)
def test_value_from_literal(value, expected):
result = misc.value_from_literal(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n
dace_propagation.propagate_states(sdfg)
sdfg.validate()

return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3))
return sdfg, lambda a, b: a + (2 * b.reshape((-1, 1)) + 3)


def _get_sdfg_with_empty_memlet(
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/unit_tests/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _constructor_test_cases():

@pytest.fixture(
params=_constructor_test_cases(),
ids=lambda x: f"{type(x.allocator).__name__ if x.allocator is not None else None}-device={x.device.device_type if x.device is not None and x.device.device_type is not None else None}-{x.expected_xp.__name__ if x.expected_xp is not None else None}",
ids=lambda x: (
f"{type(x.allocator).__name__ if x.allocator is not None else None}-device={x.device.device_type if x.device is not None and x.device.device_type is not None else None}-{x.expected_xp.__name__ if x.expected_xp is not None else None}"
),
)
def constructor_test_cases(request):
yield request.param
Expand Down
6 changes: 3 additions & 3 deletions tests/next_tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def testee(x):
def test_tree_map_multiple_input_types():
@utils.tree_map(
collection_type=(list, tuple),
result_collection_constructor=lambda value, elts: tuple(elts)
if isinstance(value, list)
else list(elts),
result_collection_constructor=lambda value, elts: (
tuple(elts) if isinstance(value, list) else list(elts)
),
)
def testee(x):
return x + 1
Expand Down
Loading
Loading