Skip to content

Commit 7ec2d35

Browse files
authored
[TIR] Add support for conditional expressions in TVMScript (#18323)
Add support for conditional expressions in TVMScript This PR adds support for conditional expressions in TVMScript parser, which allows developers to use Python-style conditional expressions ```python @T.prim_func def func(A: T.buffer((128, 128), "float32")): for i, j in T.grid(128, 128): A[i, j] = i if i < j else j @T.prim_func def expected(A: T.buffer((128, 128), "float32")): for i, j in T.grid(128, 128): A[i, j] = T.if_then_else(i < j, i, j) ```
1 parent af82187 commit 7ec2d35

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

python/tvm/script/parser/core/evaluator.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import ast
2020
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
2121

22+
import tvm
23+
2224
from . import dispatch, doc
2325
from .error import ParserError
2426

@@ -173,18 +175,19 @@ def _visit(self, node: doc.AST) -> Any:
173175
isinstance(node, doc.Call)
174176
and hasattr(node.func, "attr")
175177
and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
176-
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
178+
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)):
177179
if isinstance(node, doc.BinOp):
178180
args = [node.left, node.right]
179181
elif isinstance(node, doc.UnaryOp):
180182
args = [node.operand]
181183
elif isinstance(node, doc.Compare):
182184
args = [node.left, *node.comparators]
183-
else:
184-
if isinstance(node, doc.Call):
185-
args = node.args
186-
elif isinstance(node, doc.BoolOp):
187-
args = node.values
185+
elif isinstance(node, doc.IfExp):
186+
args = [node.test, node.body, node.orelse]
187+
elif isinstance(node, doc.Call):
188+
args = node.args
189+
elif isinstance(node, doc.BoolOp):
190+
args = node.values
188191
for arg in args:
189192
if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
190193
if isinstance(arg.slice, doc.Slice):
@@ -256,6 +259,8 @@ def _visit(self, node: doc.AST) -> Any:
256259
value = self._eval_unary_op(fields)
257260
elif isinstance(node, doc.BinOp):
258261
value = self._eval_bin_op(fields)
262+
elif isinstance(node, doc.IfExp):
263+
value = self._eval_if_exp(fields)
259264
elif isinstance(node, doc.Slice):
260265
value = self._eval_slice(fields)
261266
else:
@@ -364,6 +369,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
364369
],
365370
)
366371

372+
def _eval_if_exp(self, fields: Dict[str, Any]) -> Any:
373+
"""The doc AST if-else expression node evaluating method.
374+
375+
Parameters
376+
----------
377+
fields : Dict[str, Any]
378+
The dictionary of if-else expression information,
379+
e.g., test, body, orelse.
380+
381+
Returns
382+
-------
383+
res : Any
384+
The evaluation result.
385+
"""
386+
test = self._eval_expr(fields["test"])
387+
body = self._eval_expr(fields["body"])
388+
orelse = self._eval_expr(fields["orelse"])
389+
if isinstance(test, bool):
390+
return body if test else orelse
391+
elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool":
392+
return tvm.tir.op.if_then_else(test, body, orelse)
393+
else:
394+
raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}")
395+
367396
def _eval_slice(self, fields: Dict[str, Any]) -> slice:
368397
"""The doc AST slice node evaluating method.
369398

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,5 +612,19 @@ def expected() -> None:
612612
tvm.ir.assert_structural_equal(func, expected)
613613

614614

615+
def test_ifexp():
616+
@T.prim_func(private=True)
617+
def func(A: T.buffer((128, 128), "float32")):
618+
for i, j in T.grid(128, 128):
619+
A[i, j] = i if i < j else j
620+
621+
@T.prim_func(private=True)
622+
def expected(A: T.buffer((128, 128), "float32")):
623+
for i, j in T.grid(128, 128):
624+
A[i, j] = T.if_then_else(i < j, i, j)
625+
626+
tvm.ir.assert_structural_equal(func, expected)
627+
628+
615629
if __name__ == "__main__":
616630
tvm.testing.main()

0 commit comments

Comments
 (0)