diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1de959acd..f581abf84 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,8 +80,8 @@ jobs: env: ASCEND_HOME_PATH: /usr/local/Ascend/cann-8.5.0 PTOAS_ROOT: ${{ github.workspace }}/ptoas-bin - PTOAS_VERSION: v0.36 - PTOAS_SHA256: 698f753b67ca4387e2e4ef96dfdbb35d71295e94f9f5b20a545b34647f548efc + PTOAS_VERSION: v0.37 + PTOAS_SHA256: 60ddce76c69b6aba847f96dbbdbce2b3173cb5fb4143c6c5bd1a87a7176d8514 CMAKE_BUILD_PARALLEL_LEVEL: 16 CMAKE_C_COMPILER_LAUNCHER: ccache CMAKE_CXX_COMPILER_LAUNCHER: ccache @@ -171,8 +171,8 @@ jobs: env: ASCEND_HOME_PATH: /usr/local/Ascend/cann-8.5.0 PTOAS_ROOT: ${{ github.workspace }}/ptoas-bin - PTOAS_VERSION: v0.36 - PTOAS_SHA256: 698f753b67ca4387e2e4ef96dfdbb35d71295e94f9f5b20a545b34647f548efc + PTOAS_VERSION: v0.37 + PTOAS_SHA256: 60ddce76c69b6aba847f96dbbdbce2b3173cb5fb4143c6c5bd1a87a7176d8514 CMAKE_BUILD_PARALLEL_LEVEL: 16 CMAKE_C_COMPILER_LAUNCHER: ccache CMAKE_CXX_COMPILER_LAUNCHER: ccache @@ -264,8 +264,8 @@ jobs: runs-on: ubuntu-latest env: PTOAS_ROOT: ${{ github.workspace }}/ptoas-bin - PTOAS_VERSION: v0.36 - PTOAS_SHA256: 07bfedea5a9ba70266925ead70d87d129fc143f4e9ea280651d28cb2942a1055 + PTOAS_VERSION: v0.37 + PTOAS_SHA256: a4f52a6f2088e451ebb06a783141d556001d320df0a8081f09e4d68f361a12c2 container: image: ghcr.io/hw-native-sys/pypto/github-ci:latest steps: diff --git a/docs/en/dev/ir/05-operators.md b/docs/en/dev/ir/05-operators.md index 6a467a201..b212bf58a 100644 --- a/docs/en/dev/ir/05-operators.md +++ b/docs/en/dev/ir/05-operators.md @@ -406,6 +406,27 @@ class CrossCoreExample: See [TPUSH/TPOP ISA Reference](../../reference/pto-isa/01-tpush_tpop.md) and [Buffer Management](../../reference/pto-isa/02-buffer_management.md) for hardware details. +### Cross-Rank Signal Operations + +| Operation | Args | Description | Kwargs | +| --------- | ---- | ----------- | ------ | +| `tile.comm_notify` | 2 (signal, value) | Write or atomic-add an INT32 value into a remote rank's signal slot | `op` (`"atomic_add"` or `"set"`) | +| `tile.comm_wait` | 2 (signal, cmp_value) | Block until a local INT32 signal slot satisfies the given comparison | `cmp` (`"eq"`, `"ne"`, `"gt"`, `"ge"`, `"lt"`, `"le"`) | +| `tile.comm_test` | 2 (signal, cmp_value) | Non-blocking poll: returns BOOL = (local INT32 signal slot `` cmp_value) | `cmp` (`"eq"`, `"ne"`, `"gt"`, `"ge"`, `"lt"`, `"le"`) | + +For all three ops, `signal` is a 1-element INT32 tensor that views a GM signal slot. `tile.comm_notify` targets a remote rank's slot (typically obtained via `pl.import_peer_buffer`); `tile.comm_wait` / `tile.comm_test` poll the local rank's slot. The integer operand (`value` / `cmp_value`) is a Python int, `Scalar`, or `Expr`. They lower to `pto::comm::TNOTIFY` / `pto::comm::TWAIT` / `pto::comm::TTEST` on the AIV side. `tile.comm_test` returns `pl.Scalar[pl.BOOL]` (PTO `i1`); the others have no return value. + +**Pipeline ordering note.** Cross-rank communication ops require pipe-level ordering between GM payload writes and signal writes (the cross-rank done-barrier pattern). Consistent with the rest of PyPTO, pipe synchronization is **not** inserted at the IR or codegen layer — it is the responsibility of the downstream PTOAS lowering. PyPTO users do not need to (and cannot) manually insert pipe barriers around `comm_notify` / `comm_wait` / `comm_test`. + +```python +import pypto.language as pl + +# inside an InCore function on AIV side: +pl.tile.comm_notify(remote_signal, 1, op="atomic_add") # producer side +pl.tile.comm_wait(local_signal, 1, cmp="ge") # consumer side (blocking) +ok = pl.tile.comm_test(local_signal, 1, cmp="ge") # consumer side (non-blocking, BOOL) +``` + ## File Organization | Directory/File | Contents | diff --git a/docs/zh-cn/dev/ir/05-operators.md b/docs/zh-cn/dev/ir/05-operators.md index f8e471479..384492d1b 100644 --- a/docs/zh-cn/dev/ir/05-operators.md +++ b/docs/zh-cn/dev/ir/05-operators.md @@ -400,6 +400,27 @@ class CrossCoreExample: 参阅 [TPUSH/TPOP ISA 参考](../../reference/pto-isa/01-tpush_tpop.md) 和[缓冲区管理](../../reference/pto-isa/02-buffer_management.md)了解硬件细节。 +### 跨 Rank 信号操作 + +| 操作 | 参数 | 说明 | Kwargs | +| ---- | ---- | ---- | ------ | +| `tile.comm_notify` | 2 (signal, value) | 向远端 rank 信号槽写入或原子加 INT32 值 | `op`(`"atomic_add"` 或 `"set"`) | +| `tile.comm_wait` | 2 (signal, cmp_value) | 阻塞直至本地 INT32 信号槽满足给定比较 | `cmp`(`"eq"`、`"ne"`、`"gt"`、`"ge"`、`"lt"`、`"le"`) | +| `tile.comm_test` | 2 (signal, cmp_value) | 非阻塞轮询:返回 BOOL = (本地 INT32 信号槽 `` cmp_value) | `cmp`(`"eq"`、`"ne"`、`"gt"`、`"ge"`、`"lt"`、`"le"`) | + +三个 op 的 `signal` 都是一个 1 元素 INT32 tensor,视图指向 GM 中的信号槽:`tile.comm_notify` 写远端 rank 的槽(通常通过 `pl.import_peer_buffer` 获取),`tile.comm_wait` / `tile.comm_test` 轮询本地 rank 的槽。整数操作数(`value` / `cmp_value`)可以是 Python `int`、`Scalar` 或 `Expr`。在 AIV 侧分别 lowering 为 `pto::comm::TNOTIFY` / `pto::comm::TWAIT` / `pto::comm::TTEST`。`tile.comm_test` 返回 `pl.Scalar[pl.BOOL]`(PTO `i1`),其余两者无返回值。 + +**流水排序说明。** 跨 rank 通信 op 需要在 GM payload 写与 signal 写之间保证 pipe 级别的顺序(即跨 rank done-barrier 模式)。与 PyPTO 其它部分一致,pipe 同步**不在** IR 或 codegen 层插入,而由下游的 PTOAS 在 lowering 阶段处理。PyPTO 用户无需(也无法)手工在 `comm_notify` / `comm_wait` / `comm_test` 周围插入 pipe-barrier。 + +```python +import pypto.language as pl + +# 在 AIV 侧 InCore 函数内部: +pl.tile.comm_notify(remote_signal, 1, op="atomic_add") # 生产者 +pl.tile.comm_wait(local_signal, 1, cmp="ge") # 消费者(阻塞) +ok = pl.tile.comm_test(local_signal, 1, cmp="ge") # 消费者(非阻塞,返回 BOOL) +``` + ## 文件组织 | 目录/文件 | 内容 | diff --git a/python/pypto/ir/op/tile_ops.py b/python/pypto/ir/op/tile_ops.py index d18200dbe..da3f05056 100644 --- a/python/pypto/ir/op/tile_ops.py +++ b/python/pypto/ir/op/tile_ops.py @@ -2323,6 +2323,72 @@ def tpop_from_aiv( return _ir_core.create_op_call("tile.tpop_from_aiv", [], kwargs, actual_span) +_NOTIFY_OPS = ("atomic_add", "set") + + +def comm_notify(signal: Expr, value: Expr, *, op: str, span: Span | None = None) -> Call: + """Send a flag notification to a remote rank's signal slot. + + Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The + signal is a 1-element INT32 Tensor (GM) that views the destination rank's + signal location in its HCCL window — typically obtained via + ``pl.import_peer_buffer``. + + Args: + signal: Destination signal tensor (1-element INT32) in remote rank's window + value: INT32 scalar value to write or atomic-add + op: Notify operation, ``"atomic_add"`` or ``"set"`` + span: Optional source span + """ + if op not in _NOTIFY_OPS: + raise ValueError(f"tile.comm_notify: op must be one of {_NOTIFY_OPS}, got {op!r}") + actual_span = _get_span_or_capture(span, frame_offset=1) + return _ir_core.create_op_call("tile.comm_notify", [signal, value], {"op": op}, actual_span) + + +_WAIT_CMPS = ("eq", "ne", "gt", "ge", "lt", "le") + + +def comm_wait(signal: Expr, cmp_value: Expr, *, cmp: str, span: Span | None = None) -> Call: + """Block until a local INT32 signal slot satisfies a comparison. + + Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal + is a 1-element INT32 Tensor (GM) in the local rank's window — the slot + peers ``tile.comm_notify`` into. + + Args: + signal: Local signal tensor (1-element INT32) to poll + cmp_value: INT32 scalar comparison value + cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` | + ``"ge"`` | ``"lt"`` | ``"le"`` + span: Optional source span + """ + if cmp not in _WAIT_CMPS: + raise ValueError(f"tile.comm_wait: cmp must be one of {_WAIT_CMPS}, got {cmp!r}") + actual_span = _get_span_or_capture(span, frame_offset=1) + return _ir_core.create_op_call("tile.comm_wait", [signal, cmp_value], {"cmp": cmp}, actual_span) + + +def comm_test(signal: Expr, cmp_value: Expr, *, cmp: str, span: Span | None = None) -> Call: + """Non-blocking poll of a local INT32 signal slot, returning a BOOL. + + Lowers to ``pto::comm::TTEST`` via PTOAS ``pto.comm.ttest``. Same operand + shape as :func:`comm_wait`, but does not block — the result is BOOL and + equals ``signal cmp_value``. + + Args: + signal: Local signal tensor (1-element INT32) to poll + cmp_value: INT32 scalar comparison value + cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` | + ``"ge"`` | ``"lt"`` | ``"le"`` + span: Optional source span + """ + if cmp not in _WAIT_CMPS: + raise ValueError(f"tile.comm_test: cmp must be one of {_WAIT_CMPS}, got {cmp!r}") + actual_span = _get_span_or_capture(span, frame_offset=1) + return _ir_core.create_op_call("tile.comm_test", [signal, cmp_value], {"cmp": cmp}, actual_span) + + # ============================================================================ # Sorting Operations # ============================================================================ diff --git a/python/pypto/language/op/system_ops.py b/python/pypto/language/op/system_ops.py index f7590f5b3..a067b79c2 100644 --- a/python/pypto/language/op/system_ops.py +++ b/python/pypto/language/op/system_ops.py @@ -15,6 +15,7 @@ """ from pypto.ir.op import system_ops as _ir_ops +from pypto.ir.op import tile_ops as _ir_tile_ops from pypto.ir.op.system_ops import ( AUTO, aic_initialize_pipe, @@ -26,9 +27,9 @@ sync_src, ) from pypto.pypto_core import DataType -from pypto.pypto_core.ir import Call, Span +from pypto.pypto_core.ir import Call, ConstInt, Expr, Span -from ..typing import Scalar, Tile +from ..typing import Scalar, Tensor, Tile __all__ = [ "AUTO", @@ -47,6 +48,9 @@ "import_peer_buffer", "tfree_to_aic", "tfree_to_aiv", + "comm_notify", + "comm_wait", + "comm_test", ] @@ -143,3 +147,116 @@ def import_peer_buffer(*, name: str, peer_func: str, span: Span | None = None) - """ call = _ir_ops.import_peer_buffer(name=name, peer_func=peer_func, span=span) return Scalar(DataType.INT32, call) + + +def _value_to_int32_expr(value: int | Scalar | Expr, arg_name: str) -> Expr: + """Coerce an ``int | Scalar | Expr`` argument to an INT32 ``Expr``. + + Frontend callers can pass any of the three forms; the IR binding expects a + single ``Expr`` whose ScalarType dtype is ``INT32``. The DSL parser turns + literal Python ints into ``ConstInt`` with ``DataType.INDEX`` by default, + so an integer constant arriving here is rewrapped as ``INT32`` to satisfy + the IR-level contract of ``tile.comm_notify`` / ``tile.comm_wait`` / + ``tile.comm_test``. Non-constant ``Expr`` and ``Scalar`` values are + passed through unchanged. + """ + if isinstance(value, Scalar): + return value.unwrap() + if isinstance(value, ConstInt): + return ConstInt(int(value.value), DataType.INT32, value.span or Span.unknown()) + if isinstance(value, Expr): + return value + if isinstance(value, int) and not isinstance(value, bool): + return ConstInt(value, DataType.INT32, Span.unknown()) + raise TypeError(f"Argument '{arg_name}' must be int, pl.Scalar, or pl.Expr, got {type(value).__name__}") + + +def comm_notify( + signal: Tensor, + value: int | Scalar | Expr, + *, + op: str, + span: Span | None = None, +) -> Call: + """Send a flag notification to a remote rank's signal slot. + + Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The + signal is a 1-element INT32 ``pl.Tensor`` (GM) that views the destination + rank's signal location in its HCCL window — typically obtained via + :func:`import_peer_buffer`. + + Note: + Cross-rank communication ops require pipe-level ordering between GM + payload writes and signal writes (the cross-rank done-barrier + pattern). Consistent with the rest of PyPTO, pipe synchronization is + **not** inserted at the IR or codegen layer — it is the + responsibility of the downstream PTOAS lowering. Users do not need + to (and cannot) manually insert pipe barriers around ``comm_notify``. + + Args: + signal: Destination signal tensor (1-element INT32) in remote rank's window. + value: INT32 scalar value to write or atomic-add (Python int, Scalar, or Expr). + op: Notify operation, ``"atomic_add"`` or ``"set"``. + span: Optional source span. + + Returns: + The IR ``Call`` for ``tile.comm_notify`` (used for its side effect; no return value). + """ + return _ir_tile_ops.comm_notify(signal.unwrap(), _value_to_int32_expr(value, "value"), op=op, span=span) + + +def comm_wait( + signal: Tensor, + cmp_value: int | Scalar | Expr, + *, + cmp: str, + span: Span | None = None, +) -> Call: + """Block until a local INT32 signal slot satisfies a comparison. + + Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal + is a 1-element INT32 ``pl.Tensor`` (GM) in the local rank's window — the + slot peers ``pl.tile.comm_notify`` into. + + Args: + signal: Local signal tensor (1-element INT32) to poll. + cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr). + cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` | + ``"ge"`` | ``"lt"`` | ``"le"``. + span: Optional source span. + + Returns: + The IR ``Call`` for ``tile.comm_wait`` (used for its side effect; no return value). + """ + return _ir_tile_ops.comm_wait( + signal.unwrap(), _value_to_int32_expr(cmp_value, "cmp_value"), cmp=cmp, span=span + ) + + +def comm_test( + signal: Tensor, + cmp_value: int | Scalar | Expr, + *, + cmp: str, + span: Span | None = None, +) -> Scalar: + """Non-blocking poll of a local INT32 signal slot, returning a BOOL Scalar. + + Lowers to ``pto::comm::TTEST`` via PTOAS ``pto.comm.ttest``. Same operand + shape as :func:`comm_wait`, but does not block — the result is + ``pl.Scalar[pl.BOOL]`` and equals ``signal cmp_value``. + + Args: + signal: Local signal tensor (1-element INT32) to poll. + cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr). + cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` | + ``"ge"`` | ``"lt"`` | ``"le"``. + span: Optional source span. + + Returns: + ``pl.Scalar[pl.BOOL]`` wrapping the ``tile.comm_test`` IR call (PTO ``... -> i1``). + """ + call = _ir_tile_ops.comm_test( + signal.unwrap(), _value_to_int32_expr(cmp_value, "cmp_value"), cmp=cmp, span=span + ) + return Scalar(DataType.BOOL, call) diff --git a/python/pypto/language/op/tile_ops.py b/python/pypto/language/op/tile_ops.py index fe56fd125..73d05ddba 100644 --- a/python/pypto/language/op/tile_ops.py +++ b/python/pypto/language/op/tile_ops.py @@ -122,6 +122,9 @@ "tpush_to_aic", "tpop_from_aic", "tpop_from_aiv", + "comm_notify", + "comm_wait", + "comm_test", "sort32", "gather", "gather_mask", @@ -139,6 +142,9 @@ from ..typing import IntLike, Scalar, Tensor, Tile from .system_ops import ( # noqa: F401 + comm_notify, + comm_test, + comm_wait, tpop_from_aic, tpop_from_aiv, tpush_to_aic, diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp index 89a240b24..2eea0cf9e 100644 --- a/src/backend/common/pto_ops_common.cpp +++ b/src/backend/common/pto_ops_common.cpp @@ -1939,6 +1939,167 @@ static std::string GetPipeBufOperandI32SSA(codegen::PTOCodegen& codegen, const i return codegen.GetExprAsCode(expr); } +// Lower a comm-op signal Var to a !pto.partition_tensor_view covering the +// full ranked tensor (offsets=[0,...], sizes=[shape...]). PTOAS requires the +// signal operand of pto.comm.tnotify / pto.comm.twait to be a partition view, +// not a !pto.ptr or a raw !pto.tensor_view. The lowering chain is: +// %arg : !pto.ptr +// → pto.make_tensor_view → !pto.tensor_view (GetOrCreateTensorView) +// → pto.partition_view → !pto.partition_tensor_view +static std::string EmitCommSignalPartitionView(const ir::ExprPtr& signal_arg, + const ir::TensorType* signal_type, + codegen::PTOCodegen& codegen, const ir::Span& span, + const char* op_name) { + auto signal_var = ir::AsVarLike(signal_arg); + INTERNAL_CHECK_SPAN(signal_var, span) << op_name << " signal must be a Var or IterArg (kernel-arg tensor)"; + + std::string tensor_view = codegen.GetOrCreateTensorView(signal_var); + std::string tensor_view_type = codegen.GetTensorViewTypeString(signal_type); + std::string dtype_str = codegen.GetTypeString(signal_type->dtype_); + + // Build offsets=[%c0, ...] and sizes=[shape...] in index dtype so partition_view + // covers the entire signal region (typically rank-1, length 1 for HCCL slots). + std::vector offset_codes; + std::vector size_codes; + std::vector dim_strings; + offset_codes.reserve(signal_type->shape_.size()); + size_codes.reserve(signal_type->shape_.size()); + dim_strings.reserve(signal_type->shape_.size()); + for (const auto& dim_expr : signal_type->shape_) { + offset_codes.push_back(codegen.GetOrEmitConstant(int64_t{0}, DataType::INDEX)); + if (auto c = ir::As(dim_expr)) { + size_codes.push_back(codegen.GetOrEmitConstant(c->value_, DataType::INDEX)); + dim_strings.push_back(std::to_string(c->value_)); + } else { + size_codes.push_back(codegen.GetExprAsCode(dim_expr)); + dim_strings.emplace_back("?"); + } + } + std::string partition_type = MakePartitionTensorViewType(dim_strings, dtype_str); + return EmitPartitionViewPTO(signal_var->name_hint_, tensor_view, tensor_view_type, partition_type, + offset_codes, size_codes, codegen) + + " : " + partition_type; +} + +// Helper: split "ssa : type" into separate {ssa, type} components. EmitCommSignalPartitionView +// returns a packed "ssa : type" so we keep the partition_type string in lock-step with the +// SSA name it was emitted for. +static std::pair SplitSSAAndType(const std::string& ssa_with_type) { + auto pos = ssa_with_type.rfind(" : "); + INTERNAL_CHECK(pos != std::string::npos) << "Expected 'ssa : type' format, got '" << ssa_with_type << "'"; + return {ssa_with_type.substr(0, pos), ssa_with_type.substr(pos + 3)}; +} + +// Recover the signal TensorType for codegen. The full operand contract +// (rank >= 1, INT32, single-slot) is enforced at IR construction by the op's +// f_deduce_type (see src/ir/op/tile_ops/cross_core.cpp::CheckCommSignalArgs). +// Any violation reaching codegen is a PyPTO bug, so we only assert here that +// the type is recoverable for downstream emission. +static ir::TensorTypePtr CheckCommSignalType(const ir::ExprPtr& signal_arg, codegen::PTOCodegen& codegen, + const ir::Span& span, const char* op_name) { + auto signal_tensor_type = As(signal_arg->GetType()); + INTERNAL_CHECK_SPAN(signal_tensor_type, span) + << op_name + << " signal must be a TensorType (GM signal slot); IR validator should have rejected this earlier"; + return signal_tensor_type; +} + +// tile.comm_notify: cross-rank signal write/atomic-add → pto.comm.tnotify +// signal: 1+-dim INT32 Tensor viewing remote rank's HCCL signal slot (lowered +// to !pto.partition_tensor_view) +// value: signless integer scalar (ConstInt or i32 SSA) +// kwarg `op`: "atomic_add" | "set" → MLIR enum #pto.notify_op<...> +static std::string MakeTileNotifyCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + + INTERNAL_CHECK_SPAN(op->args_.size() == 2, op->span_) + << "tile.comm_notify requires 2 arguments (signal, value), got " << op->args_.size(); + + const auto notify_op = op->GetKwarg("op"); + INTERNAL_CHECK_SPAN(notify_op == "atomic_add" || notify_op == "set", op->span_) + << "tile.comm_notify 'op' attribute must be 'atomic_add' or 'set', got '" << notify_op << "'"; + + auto signal_type = CheckCommSignalType(op->args_[0], codegen, op->span_, "tile.comm_notify"); + + auto [sig_ssa, sig_type] = SplitSSAAndType( + EmitCommSignalPartitionView(op->args_[0], signal_type.get(), codegen, op->span_, "tile.comm_notify")); + std::string val = GetPipeBufOperandI32SSA(codegen, op->args_[1]); + + // PTOAS custom assembly: pto.comm.tnotify(%sig, %v : , i32) {notifyOp = #pto} + std::ostringstream oss; + oss << "pto.comm.tnotify(" << sig_ssa << ", " << val << " : " << sig_type << ", i32)" + << " {notifyOp = #pto}"; + codegen.Emit(oss.str()); + return ""; +} + +// tile.comm_wait: cross-rank signal poll → pto.comm.twait +// signal: 1+-dim INT32 Tensor in local rank's HCCL window (lowered to +// !pto.partition_tensor_view) +// cmp_value: signless integer scalar (ConstInt or i32 SSA) +// kwarg `cmp`: "eq"|"ne"|"gt"|"ge"|"lt"|"le" → MLIR enum #pto.wait_cmp<...> +static std::string MakeTileWaitCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + + INTERNAL_CHECK_SPAN(op->args_.size() == 2, op->span_) + << "tile.comm_wait requires 2 arguments (signal, cmp_value), got " << op->args_.size(); + + const auto cmp = op->GetKwarg("cmp"); + INTERNAL_CHECK_SPAN(cmp == "eq" || cmp == "ne" || cmp == "gt" || cmp == "ge" || cmp == "lt" || cmp == "le", + op->span_) + << "tile.comm_wait 'cmp' attribute must be one of eq|ne|gt|ge|lt|le, got '" << cmp << "'"; + + auto signal_type = CheckCommSignalType(op->args_[0], codegen, op->span_, "tile.comm_wait"); + + auto [sig_ssa, sig_type] = SplitSSAAndType( + EmitCommSignalPartitionView(op->args_[0], signal_type.get(), codegen, op->span_, "tile.comm_wait")); + std::string val = GetPipeBufOperandI32SSA(codegen, op->args_[1]); + + // PTOAS custom assembly: pto.comm.twait(%sig, %v : , i32) {cmp = #pto} + std::ostringstream oss; + oss << "pto.comm.twait(" << sig_ssa << ", " << val << " : " << sig_type << ", i32)" + << " {cmp = #pto}"; + codegen.Emit(oss.str()); + return ""; +} + +// tile.comm_test: non-blocking cross-rank signal poll → pto.comm.ttest +// signal: 1+-dim INT32 Tensor in local rank's HCCL window (lowered to +// !pto.partition_tensor_view) +// cmp_value: signless integer scalar (ConstInt or i32 SSA) +// kwarg `cmp`: "eq"|"ne"|"gt"|"ge"|"lt"|"le" → MLIR enum #pto +// returns: i1 (BOOL) — true iff `signal cmp_value` +static std::string MakeTileTestCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { + auto& codegen = dynamic_cast(codegen_base); + + INTERNAL_CHECK_SPAN(op->args_.size() == 2, op->span_) + << "tile.comm_test requires 2 arguments (signal, cmp_value), got " << op->args_.size(); + + const auto cmp = op->GetKwarg("cmp"); + INTERNAL_CHECK_SPAN(cmp == "eq" || cmp == "ne" || cmp == "gt" || cmp == "ge" || cmp == "lt" || cmp == "le", + op->span_) + << "tile.comm_test 'cmp' attribute must be one of eq|ne|gt|ge|lt|le, got '" << cmp << "'"; + + auto signal_type = CheckCommSignalType(op->args_[0], codegen, op->span_, "tile.comm_test"); + + auto [sig_ssa, sig_type] = SplitSSAAndType( + EmitCommSignalPartitionView(op->args_[0], signal_type.get(), codegen, op->span_, "tile.comm_test")); + std::string val = GetPipeBufOperandI32SSA(codegen, op->args_[1]); + + std::string result_ssa = codegen.GetCurrentResultTarget(); + INTERNAL_CHECK_SPAN(!result_ssa.empty(), op->span_) + << "tile.comm_test result must be bound to a Var (e.g. `ok = pl.tile.comm_test(...)`); " + << "discarding the BOOL return value is not supported"; + + // PTOAS custom assembly: + // %ok = pto.comm.ttest(%sig, %v : , i32) {cmp = #pto} -> i1 + std::ostringstream oss; + oss << result_ssa << " = pto.comm.ttest(" << sig_ssa << ", " << val << " : " << sig_type << ", i32)" + << " {cmp = #pto} -> i1"; + codegen.Emit(oss.str()); + return ""; +} + // Helper to format initialize_pipe operand list static void EmitInitializePipeOperands(std::ostringstream& oss, const std::string& gm_ssa, const std::string& c2v_ssa, const std::string& v2c_ssa) { @@ -2495,6 +2656,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set& exc reg("tile.tpop_from_aic", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { return MakeTpopFromAicCodegenPTO(op, codegen); }); + reg("tile.comm_notify", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeTileNotifyCodegenPTO(op, codegen); + }); + reg("tile.comm_wait", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeTileWaitCodegenPTO(op, codegen); + }); + reg("tile.comm_test", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeTileTestCodegenPTO(op, codegen); + }); reg("system.tfree_to_aic", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { return MakeTfreeToAicCodegenPTO(op, codegen); }); diff --git a/src/ir/op/tile_ops/cross_core.cpp b/src/ir/op/tile_ops/cross_core.cpp index 3378a8c06..fd6f12b36 100644 --- a/src/ir/op/tile_ops/cross_core.cpp +++ b/src/ir/op/tile_ops/cross_core.cpp @@ -14,8 +14,11 @@ #include #include +#include "pypto/core/dtype.h" +#include "pypto/core/error.h" #include "pypto/ir/core_affinity_kind.h" #include "pypto/ir/expr.h" +#include "pypto/ir/kind_traits.h" #include "pypto/ir/op_registry.h" #include "pypto/ir/type.h" @@ -29,6 +32,84 @@ TypePtr DeduceUnknownType(const std::vector& args, return GetUnknownType(); } +// Shared validation for the (signal: 1-element INT32 tensor, value: INT32 scalar) +// operand contract of tile.comm_notify / tile.comm_wait / tile.comm_test. +void CheckCommSignalArgs(const std::vector& args, const char* op_name, const char* value_arg_name) { + CHECK(args.size() == 2) << op_name << " requires 2 arguments (signal, " << value_arg_name << "), got " + << args.size(); + auto sig_ty = As(args[0]->GetType()); + CHECK(sig_ty) << op_name << " signal must be a TensorType, got " << args[0]->GetType()->TypeName(); + CHECK(sig_ty->dtype_ == DataType::INT32) + << op_name << " signal must be INT32, got " << DataTypeToString(sig_ty->dtype_); + CHECK(!sig_ty->shape_.empty()) << op_name << " signal must be rank >= 1, got rank-0 tensor"; + + // Enforce the single-slot contract: when every shape dim is a ConstInt, + // their product must equal 1. Dynamic dims are allowed (could be 1 at + // runtime) but a statically-known non-singleton extent is rejected here so + // the error surfaces at IR construction instead of late during PTO lowering. + bool all_static = true; + int64_t prod = 1; + for (const auto& d : sig_ty->shape_) { + auto c = As(d); + if (!c) { + all_static = false; + continue; + } + CHECK(c->value_ >= 1) << op_name << " signal shape dim must be positive, got " << c->value_; + CHECK(c->value_ == 1) << op_name << " signal must hold exactly one INT32 slot, got static dim " + << c->value_; + prod *= c->value_; + } + if (all_static) { + CHECK(prod == 1) << op_name << " signal must hold exactly one INT32 slot, got element count " << prod; + } + + auto val_ty = As(args[1]->GetType()); + CHECK(val_ty) << op_name << " " << value_arg_name << " must be a ScalarType, got " + << args[1]->GetType()->TypeName(); + CHECK(val_ty->dtype_ == DataType::INT32) + << op_name << " " << value_arg_name << " must be INT32 scalar, got " + << DataTypeToString(val_ty->dtype_); +} + +TypePtr DeduceTileCommNotifyType(const std::vector& args, + const std::vector>& kwargs) { + CheckCommSignalArgs(args, "tile.comm_notify", "value"); + for (const auto& [k, v] : kwargs) { + if (k == "op") { + const auto& s = std::any_cast(v); + CHECK(s == "atomic_add" || s == "set") + << "tile.comm_notify 'op' attribute must be 'atomic_add' or 'set', got '" << s << "'"; + } + } + return GetUnknownType(); +} + +static void CheckCommCmpKwarg(const std::vector>& kwargs, + const char* op_name) { + for (const auto& [k, v] : kwargs) { + if (k == "cmp") { + const auto& s = std::any_cast(v); + CHECK(s == "eq" || s == "ne" || s == "gt" || s == "ge" || s == "lt" || s == "le") + << op_name << " 'cmp' attribute must be one of eq|ne|gt|ge|lt|le, got '" << s << "'"; + } + } +} + +TypePtr DeduceTileCommWaitType(const std::vector& args, + const std::vector>& kwargs) { + CheckCommSignalArgs(args, "tile.comm_wait", "cmp_value"); + CheckCommCmpKwarg(kwargs, "tile.comm_wait"); + return GetUnknownType(); +} + +TypePtr DeduceTileCommTestType(const std::vector& args, + const std::vector>& kwargs) { + CheckCommSignalArgs(args, "tile.comm_test", "cmp_value"); + CheckCommCmpKwarg(kwargs, "tile.comm_test"); + return std::static_pointer_cast(std::make_shared(DataType::BOOL)); +} + } // namespace // ============================================================================ @@ -83,5 +164,54 @@ REGISTER_OP("tile.tpop_from_aiv") .no_memory_spec() .f_deduce_type(DeduceUnknownType); +// ============================================================================ +// Cross-Rank Signal Operations (notify / wait) +// ============================================================================ + +// Notify a remote rank by writing or atomic-adding a value into its signal slot. +// The signal operand is a 1-element INT32 Tensor that views the destination +// rank's GM signal location (typically obtained via import_peer_buffer); +// codegen lowers this to `pto::comm::TNOTIFY` via PTOAS `pto.comm.tnotify`. +REGISTER_OP("tile.comm_notify") + .set_description( + "Send a flag notification to a remote rank: write or atomic-add an INT32 value " + "into the destination signal slot") + .set_op_category("CrossCoreOp") + .set_core_affinity(core_affinity::CoreAffinity::VECTOR) + .add_argument("signal", "Destination signal tensor (1-element INT32, GM, remote-rank window)") + .add_argument("value", "INT32 scalar value to write or atomic-add") + .set_attr("op") // "atomic_add" | "set" + .no_memory_spec() + .f_deduce_type(DeduceTileCommNotifyType); + +// Block until the local rank's signal slot satisfies `signal cmp_value`. +// The signal operand is a 1-element INT32 Tensor in local-rank GM (the slot +// peers atomic-add or set into via tile.comm_notify); codegen lowers this to +// `pto::comm::TWAIT` via PTOAS `pto.comm.twait`. +REGISTER_OP("tile.comm_wait") + .set_description("Block until a local INT32 signal slot satisfies the given comparison against cmp_value") + .set_op_category("CrossCoreOp") + .set_core_affinity(core_affinity::CoreAffinity::VECTOR) + .add_argument("signal", "Local signal tensor (1-element INT32, GM) to poll") + .add_argument("cmp_value", "INT32 scalar comparison value") + .set_attr("cmp") // "eq" | "ne" | "gt" | "ge" | "lt" | "le" + .no_memory_spec() + .f_deduce_type(DeduceTileCommWaitType); + +// Non-blocking poll of the local signal slot: returns a BOOL result equal to +// `signal cmp_value`. Same operand shape as tile.comm_wait, but does not +// block; codegen lowers this to `pto::comm::TTEST` via PTOAS `pto.comm.ttest`, +// which produces an MLIR `i1`. +REGISTER_OP("tile.comm_test") + .set_description( + "Non-blocking poll: return BOOL = (local INT32 signal slot cmp_value); does not block") + .set_op_category("CrossCoreOp") + .set_core_affinity(core_affinity::CoreAffinity::VECTOR) + .add_argument("signal", "Local signal tensor (1-element INT32, GM) to poll") + .add_argument("cmp_value", "INT32 scalar comparison value") + .set_attr("cmp") // "eq" | "ne" | "gt" | "ge" | "lt" | "le" + .no_memory_spec() + .f_deduce_type(DeduceTileCommTestType); + } // namespace ir } // namespace pypto diff --git a/tests/st/runtime/test_notify_wait.py b/tests/st/runtime/test_notify_wait.py new file mode 100644 index 000000000..b326bcc46 --- /dev/null +++ b/tests/st/runtime/test_notify_wait.py @@ -0,0 +1,196 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""On-board ST for ``pl.tile.comm_notify`` + ``pl.tile.comm_wait``. + +Single-rank loopback exercises three codegen paths: + +1. ``count_exchange`` — notify-only: atomic-add 5 into a slot pre-set to 3, + expect 8. +2. ``wait_only`` — wait-only: pre-set the slot to 7, ``comm_wait ge 1`` + must return immediately without touching the slot. +3. ``done_barrier`` — notify + wait combined: atomic-add 1, then wait ge 1. +""" + +import os +from typing import Any + +import pypto.language as pl +import pytest +import torch +from harness.core.harness import DataType, PTOTestCase, TensorSpec +from pypto.backend import BackendType +from pypto.ir.pass_manager import OptimizationStrategy + +# These ST cases depend on PTOAS exposing `pto.comm.tnotify` / `pto.comm.twait`, +# which is staged behind a separate PTOAS release. CI runners that pull older +# PTOAS images would otherwise show infrastructure-driven red runs unrelated to +# this PR. Opt in by exporting `PTOAS_HAS_COMM_NOTIFY_WAIT=1` once the runner's +# PTOAS build supports the comm ops. +pytestmark = pytest.mark.skipif( + os.getenv("PTOAS_HAS_COMM_NOTIFY_WAIT") != "1", + reason="Requires PTOAS build with pto.comm.tnotify / pto.comm.twait support", +) + +# --- Programs --- + + +@pl.program +class CountExchangeProgram: + """Pattern 1: atomic-add a count into a remote rank's slot. + + Slot pre-initialized to 3; kernel adds 5; expected final value 8. + """ + + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_notify(signal, 5, op="atomic_add") + return signal + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + signal = self.kernel(signal) + return signal + + +@pl.program +class WaitOnlyProgram: + """Wait-only: slot pre-set to 7; ``comm_wait ge 1`` returns immediately. + + Isolates the ``pto.comm.twait`` codegen path with no notify in the kernel. + """ + + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_wait(signal, 1, cmp="ge") + return signal + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + signal = self.kernel(signal) + return signal + + +@pl.program +class DoneBarrierProgram: + """Pattern 2: notify(atomic_add, 1) → wait(ge, 1) — the done barrier. + + Slot pre-initialized to 0; kernel adds 1 then spin-waits on ge 1. + Expected final value 1. + """ + + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_notify(signal, 1, op="atomic_add") + pl.tile.comm_wait(signal, 1, cmp="ge") + return signal + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + signal: pl.Out[pl.Tensor[[1], pl.INT32]], + ) -> pl.Tensor[[1], pl.INT32]: + signal = self.kernel(signal) + return signal + + +# --- Test cases --- + + +class _CommSignalBase(PTOTestCase): + __test__ = False + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + +class CountExchangeTestCase(_CommSignalBase): + def get_name(self) -> str: + return "comm_signal_count_exchange" + + def define_tensors(self) -> list[TensorSpec]: + # Pre-initialize the slot to 3; atomic_add 5 → 8. + return [TensorSpec("signal", [1], DataType.INT32, init_value=3, is_output=True)] + + def get_program(self) -> Any: + return CountExchangeProgram + + def compute_expected(self, tensors, params=None): + tensors["signal"][:] = torch.tensor([8], dtype=torch.int32) + + +class DoneBarrierTestCase(_CommSignalBase): + def get_name(self) -> str: + return "comm_signal_done_barrier" + + def define_tensors(self) -> list[TensorSpec]: + # Pre-initialize the slot to 0; atomic_add 1 → 1; wait ge 1 returns immediately. + return [TensorSpec("signal", [1], DataType.INT32, init_value=0, is_output=True)] + + def get_program(self) -> Any: + return DoneBarrierProgram + + def compute_expected(self, tensors, params=None): + tensors["signal"][:] = torch.tensor([1], dtype=torch.int32) + + +class WaitOnlyTestCase(_CommSignalBase): + def get_name(self) -> str: + return "comm_signal_wait_only" + + def define_tensors(self) -> list[TensorSpec]: + # Pre-initialize the slot to 7; wait ge 1 returns immediately, slot unchanged. + return [TensorSpec("signal", [1], DataType.INT32, init_value=7, is_output=True)] + + def get_program(self) -> Any: + return WaitOnlyProgram + + def compute_expected(self, tensors, params=None): + tensors["signal"][:] = torch.tensor([7], dtype=torch.int32) + + +# --- Tests --- + + +class TestCommNotifyWait: + """On-board verification of pl.tile.comm_notify + pl.tile.comm_wait (single-rank loopback).""" + + def test_count_exchange(self, test_runner): + result = test_runner.run(CountExchangeTestCase()) + assert result.passed, f"Test failed: {result.error}" + + def test_wait_only(self, test_runner): + result = test_runner.run(WaitOnlyTestCase()) + assert result.passed, f"Test failed: {result.error}" + + def test_done_barrier(self, test_runner): + result = test_runner.run(DoneBarrierTestCase()) + assert result.passed, f"Test failed: {result.error}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/codegen/test_pto_codegen_ops.py b/tests/ut/codegen/test_pto_codegen_ops.py index fa3576062..6fe13ddc6 100644 --- a/tests/ut/codegen/test_pto_codegen_ops.py +++ b/tests/ut/codegen/test_pto_codegen_ops.py @@ -1721,5 +1721,183 @@ def kernel( ) +class TestTileCommNotifyPtoCodegen: + """PTO codegen tests for tile.comm_notify (cross-rank signal write/atomic-add). + + Exercises the single notify shape simpler actually uses in real kernels + (``examples/workers/l3/ep_dispatch_combine``): ``op="atomic_add"`` — + used both for count exchange and the done-barrier pattern. + """ + + def _generate_mlir(self, program_cls) -> str: + backend.reset_for_testing() + backend.set_backend_type(BackendType.Ascend910B) + + pm = PassManager.get_strategy(OptimizationStrategy.Default) + optimized = pm.run_passes(program_cls) + codegen_instance = codegen.PTOCodegen() + funcs = list(optimized.functions.values()) + assert funcs, "Program has no functions" + single = ir.Program([funcs[0]], funcs[0].name, optimized.span) + return codegen_instance.generate(single) + + def test_tile_comm_notify_atomic_add_codegen(self): + """pl.tile.comm_notify(sig, v, op='atomic_add') emits #pto. + + Mirrors simpler's TNOTIFY(sig, v, NotifyOp::AtomicAdd) shape used both + for cross-rank count exchange and done-barrier signaling. + """ + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_notify_atomic( + self, + signal: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_notify(signal, 1, op="atomic_add") + return signal + + mlir = self._generate_mlir(Prog) + assert "pto.comm.tnotify" in mlir, f"pto.comm.tnotify not emitted:\n{mlir}" + assert "#pto" in mlir, f"notifyOp missing:\n{mlir}" + notify_line = next((line for line in mlir.splitlines() if "pto.comm.tnotify" in line), "") + assert "!pto.partition_tensor_view<1xi32>, i32)" in notify_line, ( + f"Expected partition_tensor_view + i32 operand-type list:\n{notify_line}" + ) + + def test_tile_comm_notify_rejects_non_int32_signal(self): + """tile.comm_notify rejects non-INT32 signal tensors at IR construction.""" + + with pytest.raises(Exception, match=r"tile\.comm_notify signal must be INT32"): + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_notify_bad_dtype( + self, + signal: pl.Tensor[[1], pl.FP32], + ) -> pl.Tensor[[1], pl.FP32]: + pl.tile.comm_notify(signal, 1, op="atomic_add") + return signal + + +class TestTileCommWaitPtoCodegen: + """PTO codegen tests for tile.comm_wait (cross-rank signal poll). + + Exercises the single wait shape simpler actually uses in real kernels: + ``cmp="ge"`` paired with ``comm_notify(..., op="atomic_add")`` — the + cross-rank done-barrier pattern. + """ + + def _generate_mlir(self, program_cls) -> str: + backend.reset_for_testing() + backend.set_backend_type(BackendType.Ascend910B) + + pm = PassManager.get_strategy(OptimizationStrategy.Default) + optimized = pm.run_passes(program_cls) + codegen_instance = codegen.PTOCodegen() + funcs = list(optimized.functions.values()) + assert funcs, "Program has no functions" + single = ir.Program([funcs[0]], funcs[0].name, optimized.span) + return codegen_instance.generate(single) + + def test_tile_comm_wait_ge_codegen(self): + """pl.tile.comm_wait(sig, 1, cmp='ge') emits pto.comm.twait #pto. + + Mirrors simpler's TWAIT(sig, 1, WaitCmp::GE) — the done-barrier shape. + """ + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_wait_ge( + self, + signal: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_wait(signal, 1, cmp="ge") + return signal + + mlir = self._generate_mlir(Prog) + assert "pto.comm.twait" in mlir, f"pto.comm.twait not emitted:\n{mlir}" + assert "#pto" in mlir, f"wait_cmp attribute missing:\n{mlir}" + wait_line = next((line for line in mlir.splitlines() if "pto.comm.twait" in line), "") + assert "!pto.partition_tensor_view<1xi32>, i32)" in wait_line, ( + f"Expected partition_tensor_view + i32 operand-type list:\n{wait_line}" + ) + + def test_tile_comm_wait_rejects_non_int32_signal(self): + """tile.comm_wait rejects non-INT32 signal tensors at IR construction.""" + + with pytest.raises(Exception, match=r"tile\.comm_wait signal must be INT32"): + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_wait_bad_dtype( + self, + signal: pl.Tensor[[1], pl.FP32], + ) -> pl.Tensor[[1], pl.FP32]: + pl.tile.comm_wait(signal, 1, cmp="ge") + return signal + + +class TestTileCommTestPtoCodegen: + """PTO codegen tests for tile.comm_test (non-blocking signal poll, returns BOOL). + + Lowers to ``pto::comm::TTEST`` via PTOAS ``pto.comm.ttest``. Same operand + shape as ``tile.comm_wait`` but produces an MLIR ``i1`` result. + """ + + def _generate_mlir(self, program_cls) -> str: + backend.reset_for_testing() + backend.set_backend_type(BackendType.Ascend910B) + + pm = PassManager.get_strategy(OptimizationStrategy.Default) + optimized = pm.run_passes(program_cls) + codegen_instance = codegen.PTOCodegen() + funcs = list(optimized.functions.values()) + assert funcs, "Program has no functions" + single = ir.Program([funcs[0]], funcs[0].name, optimized.span) + return codegen_instance.generate(single) + + def test_tile_comm_test_eq_codegen(self): + """pl.tile.comm_test(sig, 1, cmp='eq') emits pto.comm.ttest #pto -> i1.""" + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_test_eq( + self, + signal: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + ok = pl.tile.comm_test(signal, 1, cmp="eq") # noqa: F841 + return signal + + mlir = self._generate_mlir(Prog) + assert "pto.comm.ttest" in mlir, f"pto.comm.ttest not emitted:\n{mlir}" + assert "#pto" in mlir, f"wait_cmp attribute missing:\n{mlir}" + test_line = next((line for line in mlir.splitlines() if "pto.comm.ttest" in line), "") + assert "!pto.partition_tensor_view<1xi32>, i32)" in test_line, ( + f"Expected partition_tensor_view + i32 operand-type list:\n{test_line}" + ) + assert "-> i1" in test_line, f"Expected i1 return type:\n{test_line}" + + def test_tile_comm_test_rejects_non_int32_signal(self): + """tile.comm_test rejects non-INT32 signal tensors at IR construction.""" + + with pytest.raises(Exception, match=r"tile\.comm_test signal must be INT32"): + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def kernel_test_bad_dtype( + self, + signal: pl.Tensor[[1], pl.FP32], + ) -> pl.Tensor[[1], pl.FP32]: + ok = pl.tile.comm_test(signal, 1, cmp="eq") # noqa: F841 + return signal + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/operators/test_tile_ops.py b/tests/ut/ir/operators/test_tile_ops.py index 3944f9099..50c530809 100644 --- a/tests/ut/ir/operators/test_tile_ops.py +++ b/tests/ut/ir/operators/test_tile_ops.py @@ -2819,5 +2819,177 @@ def test_tile_arange_alias_is_ci(self): assert pl.tile.arange is pl.tile.ci +class TestTileCommNotifyOp: + """Tests for tile.comm_notify (cross-rank signal write/atomic-add, pto::comm::TNOTIFY). + + Mirrors the two real usage patterns in simpler's + ``examples/workers/l3/ep_dispatch_combine``: + + 1. Count exchange — ``comm_notify(remote_count_slot, n, op="atomic_add")`` + 2. Done barrier set — ``comm_notify(remote_done_slot, 1, op="atomic_add")`` + paired with ``comm_wait(local_done_slot, 1, cmp="ge")`` + + Only ``op="atomic_add"`` is exercised in real kernels; "set" exists in the + op definition for initialization/reset use. + """ + + @staticmethod + def _make_signal_var(span): + """Build a 1-element INT32 tensor var to stand in for a Signal location.""" + dim1 = ir.ConstInt(1, DataType.INT32, span) + signal_type = ir.TensorType([dim1], DataType.INT32) + return ir.Var("signal", signal_type, span) + + def test_tile_comm_notify_atomic_add(self): + """Pattern 1/2 — atomic_add is the only op simpler uses in real kernels.""" + span = ir.Span.unknown() + signal = self._make_signal_var(span) + value = ir.ConstInt(1, DataType.INT32, span) + + call = tile.comm_notify(signal, value, op="atomic_add") + + assert isinstance(call, ir.Call) + assert call.op.name == "tile.comm_notify" + assert len(call.args) == 2 + assert call.args[0] is signal + assert "tile.comm_notify" in str(call) + assert 'op="atomic_add"' in str(call) or "op='atomic_add'" in str(call) + + def test_tile_comm_notify_rejects_invalid_op(self): + span = ir.Span.unknown() + signal = self._make_signal_var(span) + value = ir.ConstInt(1, DataType.INT32, span) + + with pytest.raises(ValueError, match=r"atomic_add.*set"): + tile.comm_notify(signal, value, op="bogus") + + def test_tile_comm_notify_in_program(self): + """End-to-end: tile.comm_notify appears in printed IR of a @pl.program.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + signal_buf: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_notify(signal_buf, 1, op="atomic_add") + return signal_buf + + ir_str = str(Program) + assert "tile.comm_notify" in ir_str + assert "atomic_add" in ir_str + + +class TestTileCommWaitOp: + """Tests for tile.comm_wait (cross-rank signal poll, pto::comm::TWAIT). + + Simpler's ``dispatch.cpp`` / ``combine.cpp`` use only ``cmp="ge"`` (paired + with ``comm_notify(..., op="atomic_add")``) — this is the cross-rank + done-barrier pattern. The other comparators (eq/ne/gt/lt/le) exist in the + op definition for completeness but are not exercised by real kernels. + """ + + @staticmethod + def _make_signal_var(span): + """Build a 1-element INT32 tensor var to stand in for a Signal location.""" + dim1 = ir.ConstInt(1, DataType.INT32, span) + signal_type = ir.TensorType([dim1], DataType.INT32) + return ir.Var("signal", signal_type, span) + + def test_tile_comm_wait_ge_done_barrier(self): + """Pattern 2 — wait ge N is the done-barrier shape simpler actually uses.""" + span = ir.Span.unknown() + signal = self._make_signal_var(span) + cmp_value = ir.ConstInt(1, DataType.INT32, span) + + call = tile.comm_wait(signal, cmp_value, cmp="ge") + + assert isinstance(call, ir.Call) + assert call.op.name == "tile.comm_wait" + assert len(call.args) == 2 + assert call.args[0] is signal + assert "tile.comm_wait" in str(call) + assert 'cmp="ge"' in str(call) or "cmp='ge'" in str(call) + + def test_tile_comm_wait_rejects_invalid_cmp(self): + span = ir.Span.unknown() + signal = self._make_signal_var(span) + cmp_value = ir.ConstInt(1, DataType.INT32, span) + + with pytest.raises(ValueError, match=r"eq.*ne.*gt.*ge.*lt.*le"): + tile.comm_wait(signal, cmp_value, cmp="bogus") + + def test_tile_comm_wait_in_program(self): + """End-to-end: tile.comm_wait appears in printed IR of a @pl.program.""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + signal_buf: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + pl.tile.comm_wait(signal_buf, 1, cmp="ge") + return signal_buf + + ir_str = str(Program) + assert "tile.comm_wait" in ir_str + assert "ge" in ir_str + + +class TestTileCommTestOp: + """Tests for tile.comm_test (non-blocking cross-rank signal poll, pto::comm::TTEST). + + Same operand shape as tile.comm_wait but does not block — returns a BOOL + Scalar equal to ``signal cmp_value``. + """ + + @staticmethod + def _make_signal_var(span): + dim1 = ir.ConstInt(1, DataType.INT32, span) + signal_type = ir.TensorType([dim1], DataType.INT32) + return ir.Var("signal", signal_type, span) + + def test_tile_comm_test_eq(self): + span = ir.Span.unknown() + signal = self._make_signal_var(span) + cmp_value = ir.ConstInt(1, DataType.INT32, span) + + call = tile.comm_test(signal, cmp_value, cmp="eq") + + assert isinstance(call, ir.Call) + assert call.op.name == "tile.comm_test" + assert len(call.args) == 2 + assert call.args[0] is signal + assert "tile.comm_test" in str(call) + assert 'cmp="eq"' in str(call) or "cmp='eq'" in str(call) + + def test_tile_comm_test_returns_bool_scalar(self): + """DSL wrapper returns pl.Scalar[pl.BOOL].""" + + @pl.program + class Program: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + signal_buf: pl.Tensor[[1], pl.INT32], + ) -> pl.Tensor[[1], pl.INT32]: + ok = pl.tile.comm_test(signal_buf, 1, cmp="eq") # noqa: F841 + return signal_buf + + ir_str = str(Program) + assert "tile.comm_test" in ir_str + assert "eq" in ir_str + + def test_tile_comm_test_rejects_invalid_cmp(self): + span = ir.Span.unknown() + signal = self._make_signal_var(span) + cmp_value = ir.ConstInt(1, DataType.INT32, span) + + with pytest.raises(ValueError, match=r"eq.*ne.*gt.*ge.*lt.*le"): + tile.comm_test(signal, cmp_value, cmp="bogus") + + if __name__ == "__main__": pytest.main([__file__, "-v"])