Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

import tvm

from . import dispatch, doc
from .error import ParserError

Expand Down Expand Up @@ -173,18 +175,19 @@ def _visit(self, node: doc.AST) -> Any:
isinstance(node, doc.Call)
and hasattr(node.func, "attr")
and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)):
if isinstance(node, doc.BinOp):
args = [node.left, node.right]
elif isinstance(node, doc.UnaryOp):
args = [node.operand]
elif isinstance(node, doc.Compare):
args = [node.left, *node.comparators]
else:
if isinstance(node, doc.Call):
args = node.args
elif isinstance(node, doc.BoolOp):
args = node.values
elif isinstance(node, doc.IfExp):
args = [node.test, node.body, node.orelse]
elif isinstance(node, doc.Call):
args = node.args
elif isinstance(node, doc.BoolOp):
args = node.values
for arg in args:
if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
if isinstance(arg.slice, doc.Slice):
Expand Down Expand Up @@ -256,6 +259,8 @@ def _visit(self, node: doc.AST) -> Any:
value = self._eval_unary_op(fields)
elif isinstance(node, doc.BinOp):
value = self._eval_bin_op(fields)
elif isinstance(node, doc.IfExp):
value = self._eval_if_exp(fields)
elif isinstance(node, doc.Slice):
value = self._eval_slice(fields)
else:
Expand Down Expand Up @@ -364,6 +369,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
],
)

def _eval_if_exp(self, fields: Dict[str, Any]) -> Any:
"""The doc AST if-else expression node evaluating method.

Parameters
----------
fields : Dict[str, Any]
The dictionary of if-else expression information,
e.g., test, body, orelse.

Returns
-------
res : Any
The evaluation result.
"""
test = self._eval_expr(fields["test"])
body = self._eval_expr(fields["body"])
orelse = self._eval_expr(fields["orelse"])
if isinstance(test, bool):
return body if test else orelse
elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool":
return tvm.tir.op.if_then_else(test, body, orelse)
else:
raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}")

def _eval_slice(self, fields: Dict[str, Any]) -> slice:
"""The doc AST slice node evaluating method.

Expand Down
14 changes: 14 additions & 0 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,5 +612,19 @@ def expected() -> None:
tvm.ir.assert_structural_equal(func, expected)


def test_ifexp():
@T.prim_func(private=True)
def func(A: T.buffer((128, 128), "float32")):
for i, j in T.grid(128, 128):
A[i, j] = i if i < j else j

@T.prim_func(private=True)
def expected(A: T.buffer((128, 128), "float32")):
for i, j in T.grid(128, 128):
A[i, j] = T.if_then_else(i < j, i, j)

tvm.ir.assert_structural_equal(func, expected)


if __name__ == "__main__":
tvm.testing.main()
Loading