Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def test_typing_exports(session: nox.Session) -> None:
"-sv",
"--mypy-testing-base",
"typing_tests",
"--mypy-only-local-stub",
"--mypy-no-silence-site-packages",
"typing_tests",
*session.posargs,
)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ typing = [
typing_exports = [
# to test typing with gt4py in downstream code
{include-group = "typing"},
'types-six', # can not let mypy auto-install types as that leads to unexpected stderr output (which means test failure)
'pytest-mypy-plugins', # pytest plugin for running mypy on code snippets
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see change in noxfile

"xarray" # one of the regression tests requires xarray
'types-six>=1.17.0.20251009', # can not let mypy auto-install types as that leads to unexpected stderr output (which means test failure)
'pytest-mypy-plugins>=4.0.0', # pytest plugin for running mypy on code snippets
"xarray>=2024.1.0" # one of the regression tests requires xarray
]

# -- Standard project description options (PEP 621) --
Expand Down
20 changes: 11 additions & 9 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,21 +486,23 @@ def infer_expr(

if cpm.is_applied_as_fieldop(expr) and cpm.is_call_to(expr.fun.args[0], "scan"):
additional_dims = gtx_utils.tree_map(
lambda d: _extract_vertical_dims(d)
if isinstance(d, domain_utils.SymbolicDomain)
else {}
lambda d: (
_extract_vertical_dims(d) if isinstance(d, domain_utils.SymbolicDomain) else {}
)
)(domain)
else:
additional_dims = gtx_utils.tree_map(lambda d: {})(domain)

domain = gtx_utils.tree_map(
lambda d, t, a: _filter_domain_dimensions(
d,
type_info.extract_dims(t),
additional_dims=a,
lambda d, t, a: (
_filter_domain_dimensions(
d,
type_info.extract_dims(t),
additional_dims=a,
)
if not isinstance(t, ts.DeferredType) and isinstance(d, domain_utils.SymbolicDomain)
else d
)
if not isinstance(t, ts.DeferredType) and isinstance(d, domain_utils.SymbolicDomain)
else d
)(domain, el_types, additional_dims)

expr, accessed_domains = _infer_expr(
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,12 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None:
# the target can have fewer elements than the expr in which case the output from the
# expression is simply discarded.
expr_type = functools.reduce(
lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents
# `ts.DeferredType` only occurs for scans returning a tuple
if not isinstance(tuple_type, ts.DeferredType)
else ts.DeferredType(constraint=None),
lambda tuple_type, i: (
tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents
# `ts.DeferredType` only occurs for scans returning a tuple
if not isinstance(tuple_type, ts.DeferredType)
else ts.DeferredType(constraint=None)
),
path,
node.expr.type,
)
Expand Down
18 changes: 11 additions & 7 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ def _values_validator(
if not all(
isinstance(el, (SidFromScalar, SidComposite))
or _is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")),
lambda expr: (
isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index"))
),
el,
)
for el in value
Expand Down Expand Up @@ -184,11 +186,13 @@ def _arg_validator(
) -> None:
for inp in inputs:
if not _is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar))
or (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "index"
lambda expr: (
isinstance(expr, (SymRef, SidComposite, SidFromScalar))
or (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "index"
)
),
inp,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,15 +769,15 @@ def _visit_if_branch(
psymbol_tree = gtir_to_sdfg_utils.make_symbol_tree(pname, ptype)
deref_on_input_memlet = pname in direct_deref_iterators
inner_arg = gtx_utils.tree_map(
lambda tsym,
targ,
deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
str(tsym.id),
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: (
self._visit_if_branch_arg(
if_sdfg,
if_branch_state,
str(tsym.id),
targ,
deref_on_input_memlet,
if_sdfg_input_memlets,
)
)
)(psymbol_tree, arg)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,12 @@ def setup_nested_context(
eve.walk_values(expr)
.filter(lambda node: cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")))
.map(
lambda domain: eve.walk_values(domain)
.if_isinstance(gtir.SymRef)
.filter(lambda sym: str(sym.id) in lambda_symbols)
.to_set()
lambda domain: (
eve.walk_values(domain)
.if_isinstance(gtir.SymRef)
.filter(lambda sym: str(sym.id) in lambda_symbols)
.to_set()
)
)
.reduce(lambda x, y: x | y, init=set())
)
Expand Down Expand Up @@ -922,9 +924,11 @@ def _visit_expression(

if use_temp: # copy the full shape of global data to temporary storage
return gtx_utils.tree_map(
lambda x: x
if x is None or x.dc_node.desc(ctx.sdfg).transient
else ctx.copy_data(self, x, domain=None)
lambda x: (
x
if x is None or x.dc_node.desc(ctx.sdfg).transient
else ctx.copy_data(self, x, domain=None)
)
)(result)
else:
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,13 @@ def _create_scan_field_operator(
)

return gtx_utils.tree_map(
lambda edge, domain, sym: (
_create_scan_field_operator_impl(
ctx,
sdfg_builder,
edge,
domain,
sym.type,
map_exit,
)
lambda edge, domain, sym: _create_scan_field_operator_impl(
ctx,
sdfg_builder,
edge,
domain,
sym.type,
map_exit,
)
)(output, output_domain, dummy_output_symbol)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def tree_map(

>>> tree_map(
... collection_type=(list, tuple),
... result_collection_constructor=lambda value, elts: tuple(elts)
... if isinstance(value, list)
... else list(elts),
... result_collection_constructor=lambda value, elts: (
... tuple(elts) if isinstance(value, list) else list(elts)
... ),
... )(lambda x: x + 1)([(1, 2), 3])
([2, 3], 4)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def contains_cast_node(cast_node, expr):
expr.walk_values()
.if_isinstance(Cast)
.filter(
lambda node: node.dtype == cast_node.dtype
and (isinstance(cast_node.expr, Placeholder) or node.expr == cast_node.expr)
lambda node: (
node.dtype == cast_node.dtype
and (isinstance(cast_node.expr, Placeholder) or node.expr == cast_node.expr)
)
)
.to_list()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ def testee(a: tuple[int32, tuple[int32, cases.IField, int32]]) -> cases.IField:
cases.verify_with_default_data(
cartesian_case,
testee,
ref=lambda a: np.full(
[cartesian_case.default_sizes[IDim]], a[0] + 2 * a[1][0] + 5 * a[1][2], dtype=int32
)
+ 3 * a[1][1],
ref=lambda a: (
np.full(
[cartesian_case.default_sizes[IDim]], a[0] + 2 * a[1][0] + 5 * a[1][2], dtype=int32
)
+ 3 * a[1][1]
),
)


Expand Down Expand Up @@ -942,8 +944,8 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField:
cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a, b: (
np.sum(b[v2e_table], axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE)
ref=lambda a, b: np.sum(
b[v2e_table], axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE
),
)

Expand Down Expand Up @@ -992,9 +994,11 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]):
cartesian_case,
testee,
ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)),
comparison=lambda ref, out: np.all(out[0] == ref[0])
and np.all(out[1][0] == ref[1][0])
and np.all(out[1][1] == ref[1][1]),
comparison=lambda ref, out: (
np.all(out[0] == ref[0])
and np.all(out[1][0] == ref[1][0])
and np.all(out[1][1] == ref[1][1])
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ def fencil(edge_f: cases.EField, out: cases.VField):
cases.verify_with_default_data(
unstructured_case,
fencil,
ref=lambda edge_f: 3
* np.sum(
-(edge_f[v2e_table] ** 2) * 2,
axis=1,
initial=0,
where=v2e_table != common._DEFAULT_SKIP_VALUE,
ref=lambda edge_f: (
3
* np.sum(
-(edge_f[v2e_table] ** 2) * 2,
axis=1,
initial=0,
where=v2e_table != common._DEFAULT_SKIP_VALUE,
)
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ def basic_trig_fieldop(inp1: cases.IFloatField, inp2: cases.IFloatField) -> case
cases.verify_with_default_data(
cartesian_case,
basic_trig_fieldop,
ref=lambda inp1, inp2: np.sin(np.cos(inp1))
- np.sinh(np.cosh(inp2))
+ np.tan(inp1)
- np.tanh(inp2),
ref=lambda inp1, inp2: (
np.sin(np.cos(inp1)) - np.sinh(np.cosh(inp2)) + np.tan(inp1) - np.tanh(inp2)
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


def assert_close(expected, actual):
assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual)
assert np.allclose(expected, actual), "expected={}, actual={}".format(expected, actual)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manual change

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the new numpy version returns a 1d array with 1 element in the reduction that is passed to this function.



class nabla_setup:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
gtx.float64("1.0"),
),
],
ids=lambda param: f"Literal[{param.value}, {param.type}]"
if isinstance(param, itir.Literal)
else str(param),
ids=lambda param: (
f"Literal[{param.value}, {param.type}]" if isinstance(param, itir.Literal) else str(param)
),
)
def test_value_from_literal(value, expected):
result = misc.value_from_literal(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n
dace_propagation.propagate_states(sdfg)
sdfg.validate()

return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3))
return sdfg, lambda a, b: a + (2 * b.reshape((-1, 1)) + 3)


def _get_sdfg_with_empty_memlet(
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/unit_tests/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _constructor_test_cases():

@pytest.fixture(
params=_constructor_test_cases(),
ids=lambda x: f"{type(x.allocator).__name__ if x.allocator is not None else None}-device={x.device.device_type if x.device is not None and x.device.device_type is not None else None}-{x.expected_xp.__name__ if x.expected_xp is not None else None}",
ids=lambda x: (
f"{type(x.allocator).__name__ if x.allocator is not None else None}-device={x.device.device_type if x.device is not None and x.device.device_type is not None else None}-{x.expected_xp.__name__ if x.expected_xp is not None else None}"
),
)
def constructor_test_cases(request):
yield request.param
Expand Down
6 changes: 3 additions & 3 deletions tests/next_tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def testee(x):
def test_tree_map_multiple_input_types():
@utils.tree_map(
collection_type=(list, tuple),
result_collection_constructor=lambda value, elts: tuple(elts)
if isinstance(value, list)
else list(elts),
result_collection_constructor=lambda value, elts: (
tuple(elts) if isinstance(value, list) else list(elts)
),
)
def testee(x):
return x + 1
Expand Down
Loading
Loading