Skip to content

Commit 642bd77

Browse files
committed
Add support for conditional expressions in TVMScript
This PR adds support for conditional expressions in TVMScript parser, which allows developers to use Python-style conditional expressions ```python @T.prim_func def func(A: T.buffer((128, 128), "float32")): for i, j in T.grid(128, 128): A[i, j] = i if i < j else j @T.prim_func def expected(A: T.buffer((128, 128), "float32")): for i, j in T.grid(128, 128): A[i, j] = T.if_then_else(i < j, i, j) ```
1 parent af82187 commit 642bd77

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

python/tvm/script/parser/core/evaluator.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import ast
2020
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
2121

22+
import tvm
23+
2224
from . import dispatch, doc
2325
from .error import ParserError
2426

@@ -173,18 +175,19 @@ def _visit(self, node: doc.AST) -> Any:
173175
isinstance(node, doc.Call)
174176
and hasattr(node.func, "attr")
175177
and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
176-
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
178+
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)):
177179
if isinstance(node, doc.BinOp):
178180
args = [node.left, node.right]
179181
elif isinstance(node, doc.UnaryOp):
180182
args = [node.operand]
181183
elif isinstance(node, doc.Compare):
182184
args = [node.left, *node.comparators]
183-
else:
184-
if isinstance(node, doc.Call):
185-
args = node.args
186-
elif isinstance(node, doc.BoolOp):
187-
args = node.values
185+
elif isinstance(node, doc.IfExp):
186+
args = [node.test, node.body, node.orelse]
187+
elif isinstance(node, doc.Call):
188+
args = node.args
189+
elif isinstance(node, doc.BoolOp):
190+
args = node.values
188191
for arg in args:
189192
if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
190193
if isinstance(arg.slice, doc.Slice):
@@ -256,6 +259,8 @@ def _visit(self, node: doc.AST) -> Any:
256259
value = self._eval_unary_op(fields)
257260
elif isinstance(node, doc.BinOp):
258261
value = self._eval_bin_op(fields)
262+
elif isinstance(node, doc.IfExp):
263+
value = self._eval_if_exp(fields)
259264
elif isinstance(node, doc.Slice):
260265
value = self._eval_slice(fields)
261266
else:
@@ -364,6 +369,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
364369
],
365370
)
366371

372+
def _eval_if_exp(self, fields: Dict[str, Any]) -> Any:
373+
"""The doc AST if-else expression node evaluating method.
374+
375+
Parameters
376+
----------
377+
fields : Dict[str, Any]
378+
The dictionary of if-else expression information,
379+
e.g., test, body, orelse.
380+
381+
Returns
382+
-------
383+
res : Any
384+
The evaluation result.
385+
"""
386+
test = self._eval_expr(fields["test"])
387+
body = self._eval_expr(fields["body"])
388+
orelse = self._eval_expr(fields["orelse"])
389+
if isinstance(test, bool):
390+
return body if test else orelse
391+
elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool":
392+
return tvm.tir.op.if_then_else(test, body, orelse)
393+
else:
394+
raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}")
395+
367396
def _eval_slice(self, fields: Dict[str, Any]) -> slice:
368397
"""The doc AST slice node evaluating method.
369398

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,5 +612,19 @@ def expected() -> None:
612612
tvm.ir.assert_structural_equal(func, expected)
613613

614614

615+
def test_ifexp():
616+
@T.prim_func(private=True)
617+
def func(A: T.buffer((128, 128), "float32")):
618+
for i, j in T.grid(128, 128):
619+
A[i, j] = i if i < j else j
620+
621+
@T.prim_func(private=True)
622+
def expected(A: T.buffer((128, 128), "float32")):
623+
for i, j in T.grid(128, 128):
624+
A[i, j] = T.if_then_else(i < j, i, j)
625+
626+
tvm.ir.assert_structural_equal(func, expected)
627+
628+
615629
if __name__ == "__main__":
616630
tvm.testing.main()

0 commit comments

Comments
 (0)