Skip to content

Commit f94a037

Browse files
committed
edit
1 parent f8180e2 commit f94a037

3 files changed

Lines changed: 25 additions & 50 deletions

File tree

src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -315,25 +315,18 @@ def translate_broadcast(
315315
# Retrieve the scalar argument, which could be either a literal value or the
316316
# result of a scalar expression.
317317
assert len(node.args) == 1
318+
arg = _parse_fieldop_arg(node.args[0], ctx, sdfg_builder, field_domain)
319+
assert isinstance(arg, gtir_dataflow.MemletExpr)
320+
assert arg.subset.num_elements() == 1
318321

319322
# Use a 'Fill' library node to write the scalar value to the result field.
320-
if isinstance(node.args[0], gtir.Literal):
321-
assert node.args[0].type == node.type.dtype
322-
value = field_dtype(node.args[0].value)
323-
fill_node = sdfg_library_nodes.Fill("fill", value)
324-
ctx.state.add_node(fill_node)
325-
else:
326-
arg = _parse_fieldop_arg(node.args[0], ctx, sdfg_builder, field_domain)
327-
assert isinstance(arg, gtir_dataflow.MemletExpr)
328-
assert arg.subset.num_elements() == 1
329-
330-
fill_node = sdfg_library_nodes.Fill("fill")
331-
ctx.state.add_node(fill_node)
332-
ctx.state.add_nedge(
333-
arg.dc_node,
334-
fill_node,
335-
dace.Memlet(data=arg.dc_node.data, subset=arg.subset),
336-
)
323+
fill_node = sdfg_library_nodes.Fill("fill")
324+
ctx.state.add_node(fill_node)
325+
ctx.state.add_nedge(
326+
arg.dc_node,
327+
fill_node,
328+
dace.Memlet(data=arg.dc_node.data, subset=arg.subset),
329+
)
337330

338331
ctx.state.add_nedge(
339332
fill_node,

src/gt4py/next/program_processors/runners/dace/sdfg_library_nodes.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,16 @@ def expansion(node: Fill, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG)
4444
out_mem = dace.Memlet(expr=f"{out}[{','.join(map_params)}]")
4545
outputs = {"_out": out_mem}
4646

47-
if node._value is None:
48-
assert len(parent_state.in_edges(node)) == 1
49-
inedge = parent_state.in_edges(node)[0]
50-
inp_desc = parent_sdfg.arrays[inedge.data.data]
51-
inner_inp_desc = inp_desc.clone()
52-
inner_inp_desc.transient = False
53-
inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc)
54-
inedge._dst_conn = _INPUT_NAME
55-
node.add_in_connector(_INPUT_NAME)
56-
inputs = {"_in": dace.Memlet(data=inp, subset="0")}
57-
code = "_out = _in"
58-
else:
59-
inputs = {}
60-
code = f"_out = {node._value}"
47+
assert len(parent_state.in_edges(node)) == 1
48+
inedge = parent_state.in_edges(node)[0]
49+
inp_desc = parent_sdfg.arrays[inedge.data.data]
50+
inner_inp_desc = inp_desc.clone()
51+
inner_inp_desc.transient = False
52+
inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc)
53+
inedge._dst_conn = _INPUT_NAME
54+
node.add_in_connector(_INPUT_NAME)
55+
inputs = {"_in": dace.Memlet(data=inp, subset="0")}
56+
code = "_out = _in"
6157

6258
state.add_mapped_tasklet(
6359
f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True
@@ -72,8 +68,6 @@ class Fill(dace_nodes.LibraryNode):
7268

7369
implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {"pure": ExpandPure}
7470
default_implementation: Final[str] = "pure"
75-
_value: dace.typeclass | None
7671

77-
def __init__(self, name: str, value: dace.typeclass | None = None):
72+
def __init__(self, name: str):
7873
super().__init__(name)
79-
self._value = value

tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2110,16 +2110,8 @@ def test_gtir_concat_where():
21102110
gtx_common.GridType.CARTESIAN, {IDim: (SUBSET_SIZE, gtir.InfinityLiteral.POSITIVE)}
21112111
)
21122112

2113-
concat_expr_lhs = im.concat_where(
2114-
domain_cond_lhs,
2115-
im.as_fieldop("deref")("x"),
2116-
im.as_fieldop("deref")("y"),
2117-
)
2118-
concat_expr_rhs = im.concat_where(
2119-
domain_cond_rhs,
2120-
im.as_fieldop("deref")("y"),
2121-
im.as_fieldop("deref")("x"),
2122-
)
2113+
concat_expr_lhs = im.concat_where(domain_cond_lhs, "x", "y")
2114+
concat_expr_rhs = im.concat_where(domain_cond_rhs, "y", "x")
21232115

21242116
a = np.random.rand(N)
21252117
b = np.random.rand(N)
@@ -2177,12 +2169,8 @@ def test_gtir_concat_where_two_dimensions():
21772169
gtir.SetAt(
21782170
expr=im.concat_where(
21792171
domain_cond1, # 0, 30; 10,20
2180-
im.concat_where(
2181-
domain_cond2,
2182-
im.as_fieldop("deref")("x"),
2183-
im.as_fieldop("deref")("y"),
2184-
),
2185-
im.as_fieldop("deref")("w"),
2172+
im.concat_where(domain_cond2, "x", "y"),
2173+
"w",
21862174
),
21872175
domain=domain,
21882176
target=gtir.SymRef(id="z"),

0 commit comments

Comments
 (0)