|
19 | 19 | import ast
|
20 | 20 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
|
21 | 21 |
|
| 22 | +import tvm |
| 23 | + |
22 | 24 | from . import dispatch, doc
|
23 | 25 | from .error import ParserError
|
24 | 26 |
|
@@ -173,18 +175,19 @@ def _visit(self, node: doc.AST) -> Any:
|
173 | 175 | isinstance(node, doc.Call)
|
174 | 176 | and hasattr(node.func, "attr")
|
175 | 177 | and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
|
176 |
| - ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)): |
| 178 | + ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)): |
177 | 179 | if isinstance(node, doc.BinOp):
|
178 | 180 | args = [node.left, node.right]
|
179 | 181 | elif isinstance(node, doc.UnaryOp):
|
180 | 182 | args = [node.operand]
|
181 | 183 | elif isinstance(node, doc.Compare):
|
182 | 184 | args = [node.left, *node.comparators]
|
183 |
| - else: |
184 |
| - if isinstance(node, doc.Call): |
185 |
| - args = node.args |
186 |
| - elif isinstance(node, doc.BoolOp): |
187 |
| - args = node.values |
| 185 | + elif isinstance(node, doc.IfExp): |
| 186 | + args = [node.test, node.body, node.orelse] |
| 187 | + elif isinstance(node, doc.Call): |
| 188 | + args = node.args |
| 189 | + elif isinstance(node, doc.BoolOp): |
| 190 | + args = node.values |
188 | 191 | for arg in args:
|
189 | 192 | if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
|
190 | 193 | if isinstance(arg.slice, doc.Slice):
|
@@ -256,6 +259,8 @@ def _visit(self, node: doc.AST) -> Any:
|
256 | 259 | value = self._eval_unary_op(fields)
|
257 | 260 | elif isinstance(node, doc.BinOp):
|
258 | 261 | value = self._eval_bin_op(fields)
|
| 262 | + elif isinstance(node, doc.IfExp): |
| 263 | + value = self._eval_if_exp(fields) |
259 | 264 | elif isinstance(node, doc.Slice):
|
260 | 265 | value = self._eval_slice(fields)
|
261 | 266 | else:
|
@@ -364,6 +369,30 @@ def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
|
364 | 369 | ],
|
365 | 370 | )
|
366 | 371 |
|
| 372 | + def _eval_if_exp(self, fields: Dict[str, Any]) -> Any: |
| 373 | + """The doc AST if-else expression node evaluating method. |
| 374 | +
|
| 375 | + Parameters |
| 376 | + ---------- |
| 377 | + fields : Dict[str, Any] |
| 378 | + The dictionary of if-else expression information, |
| 379 | + e.g., test, body, orelse. |
| 380 | +
|
| 381 | + Returns |
| 382 | + ------- |
| 383 | + res : Any |
| 384 | + The evaluation result. |
| 385 | + """ |
| 386 | + test = self._eval_expr(fields["test"]) |
| 387 | + body = self._eval_expr(fields["body"]) |
| 388 | + orelse = self._eval_expr(fields["orelse"]) |
| 389 | + if isinstance(test, bool): |
| 390 | + return body if test else orelse |
| 391 | + elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool": |
| 392 | + return tvm.tir.op.if_then_else(test, body, orelse) |
| 393 | + else: |
| 394 | + raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") |
| 395 | + |
367 | 396 | def _eval_slice(self, fields: Dict[str, Any]) -> slice:
|
368 | 397 | """The doc AST slice node evaluating method.
|
369 | 398 |
|
|
0 commit comments