feat(ir): Add tile.notify and tile.wait cross-rank signal ops#1301

Open
Little-oil wants to merge 2 commits into hw-native-sys:main from Little-oil:add_notify

Conversation

@Little-oil
Contributor

Summary

Introduces a pair of cross-rank signaling operations on AIV:

  • tile.notify(signal, value, op) — write or atomic-add an INT32 value to a remote rank's signal slot (1-element INT32 GM tensor). Lowers to pto.comm.tnotify with notifyOp = #pto.notify_op<set|atomic_add>.
  • tile.wait(signal, cmp_value, cmp) — block until a local INT32 signal slot satisfies a comparison. Lowers to pto.comm.twait with cmp = #pto.wait_cmp<eq|ne|gt|ge|lt|le>.

Both ops take a 1-element INT32 pl.Tensor viewing a slot in the rank's HCCL window and an INT32 scalar (Python int, Scalar, or Expr).
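
The semantics described above can be modelled on the host in a few lines. This is only an illustrative sketch: the names `NOTIFY_OPS`, `WAIT_CMPS`, `notify`, and `wait_satisfied` are hypothetical stand-ins and are not the pypto API, the signal slot is modelled as a 1-element Python list, and the comparison direction (`signal <cmp> cmp_value`) follows the op registration comment in this PR.

```python
import operator

# Hypothetical host-side model of the two notify modes (not the pypto API).
NOTIFY_OPS = {
    "set": lambda old, value: value,               # overwrite the slot
    "atomic_add": lambda old, value: old + value,  # accumulate into the slot
}

# The six wait predicates, evaluated as `signal <cmp> cmp_value`.
WAIT_CMPS = {
    "eq": operator.eq,
    "ne": operator.ne,
    "gt": operator.gt,
    "ge": operator.ge,
    "lt": operator.lt,
    "le": operator.le,
}


def notify(slot: list, value: int, op: str) -> None:
    """Apply a notify op to a 1-element list standing in for the signal slot."""
    slot[0] = NOTIFY_OPS[op](slot[0], value)


def wait_satisfied(slot: list, cmp_value: int, cmp: str) -> bool:
    """Return True when a wait with this predicate would stop blocking."""
    return WAIT_CMPS[cmp](slot[0], cmp_value)


slot = [0]
notify(slot, 3, "set")
notify(slot, 2, "atomic_add")
print(slot[0])                        # 5
print(wait_satisfied(slot, 5, "ge"))  # True
```

A real `tile.wait` blocks on device until the predicate holds; the model only evaluates the predicate once, which is enough to reason about which notify/wait pairings terminate.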

Layers updated

  • C++ op registrations (src/ir/op/tile_ops/cross_core.cpp)
  • PTO codegen (src/backend/common/pto_ops_common.cpp)
  • Python IR wrappers (python/pypto/ir/op/tile_ops.py)
  • DSL wrappers (python/pypto/language/op/system_ops.py) re-exported via python/pypto/language/op/tile_ops.py
  • UT for IR + PTO codegen (tests/ut/ir/operators/test_tile_ops.py, tests/ut/codegen/test_pto_codegen_ops.py)
  • ST loopback covering all six cmp variants and both notify ops (tests/st/runtime/test_notify_wait.py)
  • English + Chinese docs (docs/en/dev/ir/05-operators.md, docs/zh-cn/dev/ir/05-operators.md)

Testing

  • UT pass: TestTileNotifyOp, TestTileWaitOp, TestTileNotifyPtoCodegen, TestTileWaitPtoCodegen (22/22).
  • On-board ST in tests/st/runtime/test_notify_wait.py requires a PTOAS build that exposes the pto.comm.tnotify / pto.comm.twait custom ops; the currently deployed PTOAS does not have them yet, so on-board ST is staged to enable later.

Notes

  • Lint/format/pyright pre-commit hooks all pass.
  • No new global state, no IR design changes — purely additive op registrations following the existing CrossCoreOp pattern (mirror of tpush_* / tpop_* / tfree_*).

@coderabbitai

coderabbitai Bot commented May 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds tile.comm_notify and tile.comm_wait: IR ops, IR-level Python bindings, language DSL wrappers, PTO backend lowering to pto.comm.tnotify/twait, English/Chinese docs, and unit + codegen + runtime tests validating INT32 signal semantics and attributes.

Changes

Tile Signal Operations (tile.comm_notify / tile.comm_wait)

| Layer | File(s) | Summary |
|---|---|---|
| IR Operation Registration | src/ir/op/tile_ops/cross_core.cpp | Registers tile.comm_notify and tile.comm_wait as CrossCoreOp with INT32 signal operands and string attributes (op for notify, cmp for wait). |
| Python IR Bindings | python/pypto/ir/op/tile_ops.py | Adds _NOTIFY_OPS/_WAIT_CMPS and implements comm_notify(signal, value, *, op, span) and comm_wait(signal, cmp_value, *, cmp, span) with validation, span capture, and IR Call emission. |
| Language DSL (system_ops) | python/pypto/language/op/system_ops.py | Adds public comm_notify / comm_wait that normalize Python int/Scalar/Expr to INT32 Expr (using ConstInt when needed) and forward to IR-layer functions. |
| Module Exports (tile_ops) | python/pypto/language/op/tile_ops.py | Imports/forwards new IR-level symbols and adds comm_notify and comm_wait to __all__ to expose pl.tile.comm_notify/pl.tile.comm_wait. |
| PTO Backend Codegen | src/backend/common/pto_ops_common.cpp | Adds partition-view helper, SSA/type parsing helper, INT32 signal validation, MakeTileNotifyCodegenPTO and MakeTileWaitCodegenPTO to emit pto.comm.tnotify/pto.comm.twait with enum attributes; registers handlers. |
| Documentation | docs/en/dev/ir/05-operators.md, docs/zh-cn/dev/ir/05-operators.md | Adds "Cross-Rank Signal Operations" sections documenting signatures, INT32 signal-slot semantics, lowering targets (TNOTIFY/TWAIT), pipeline ordering note, and examples. |
| Unit & Codegen Tests | tests/ut/ir/operators/test_tile_ops.py, tests/ut/codegen/test_pto_codegen_ops.py | IR unit tests for call construction and kwarg printing; PTO codegen tests asserting pto.comm.tnotify/pto.comm.twait emission, enum attributes, and INT32-signal rejection tests. |
| Runtime Tests | tests/st/runtime/test_notify_wait.py | On-device loopback tests (single-rank) exercising notify (atomic_add/set) and wait (eq/ne/gt/ge/lt/le) patterns and asserting final INT32 signal slot contents. |

Sequence Diagram(s)

sequenceDiagram
  participant Program as User Program
  participant Lang as language.system_ops
  participant IR as ir.tile_ops
  participant PTO as PTO Codegen
  participant GM as Device GM
  Program->>Lang: comm_notify(signal, value, op)
  Lang->>IR: normalized signal/value -> tile.comm_notify Call
  IR->>PTO: tile.comm_notify
  PTO->>GM: emit pto.comm.tnotify (partition_view, i32 value, notifyOp)
  Program->>Lang: comm_wait(signal, cmp_value, cmp)
  Lang->>IR: normalized cmp_value -> tile.comm_wait Call
  IR->>PTO: tile.comm_wait
  PTO->>GM: emit pto.comm.twait (partition_view, i32 cmp, cmp attr)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto#267: Related PTO codegen/partition_view changes and partition_tensor_view naming; likely intersects backend lowering infrastructure.
  • hw-native-sys/pypto#1312: Also modifies PTO ops registration and codegen handlers; may overlap in RegisterPTOOps edits.

Suggested reviewers

  • Hzfengsy
  • lyfne123

Poem

🐰
I nudge the slot, I add, I set,
Across the ranks my whispers get,
A wait that watches, eyes so bright,
Until the signal says "alright".
— hops, notifies, sleeps with delight

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 48.08% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely describes the main feature being added: two new cross-rank signal operations (tile.notify and tile.wait) at the IR level. |
| Description check | ✅ Passed | The description comprehensively covers the functionality, layers updated, testing status, and implementation details, all directly related to the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@Little-oil
Contributor Author

Waiting for the new PTOAS version.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces cross-rank signal operations, specifically tile.notify and tile.wait, to support synchronization between different ranks. The changes include documentation in both English and Chinese, Python IR and language-level API definitions, C++ backend codegen for PTO operations, and comprehensive unit and system tests. The feedback suggests refactoring the argument conversion logic in python/pypto/language/op/system_ops.py into a shared helper function to improve maintainability and ensure consistent validation of IntLike arguments across both operations.

Comment thread python/pypto/language/op/system_ops.py Outdated
Comment on lines +147 to +208
def notify(
    signal: Tensor,
    value: int | Scalar | Expr,
    *,
    op: str,
    span: Span | None = None,
) -> Call:
    """Send a flag notification to a remote rank's signal slot.

    Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The
    signal is a 1-element INT32 ``pl.Tensor`` (GM) that views the destination
    rank's signal location in its HCCL window — typically obtained via
    :func:`import_peer_buffer`.

    Args:
        signal: Destination signal tensor (1-element INT32) in remote rank's window.
        value: INT32 scalar value to write or atomic-add (Python int, Scalar, or Expr).
        op: Notify operation, ``"atomic_add"`` or ``"set"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.notify`` (used for its side effect; no return value).
    """
    if isinstance(value, Scalar):
        value_expr: Expr = value.unwrap()
    elif isinstance(value, Expr):
        value_expr = value
    else:
        value_expr = ConstInt(int(value), DataType.INT32, Span.unknown())
    return _ir_tile_ops.notify(signal.unwrap(), value_expr, op=op, span=span)


def wait(
    signal: Tensor,
    cmp_value: int | Scalar | Expr,
    *,
    cmp: str,
    span: Span | None = None,
) -> Call:
    """Block until a local INT32 signal slot satisfies a comparison.

    Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal
    is a 1-element INT32 ``pl.Tensor`` (GM) in the local rank's window — the
    slot peers ``pl.tile.notify`` into.

    Args:
        signal: Local signal tensor (1-element INT32) to poll.
        cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr).
        cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` |
            ``"ge"`` | ``"lt"`` | ``"le"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.wait`` (used for its side effect; no return value).
    """
    if isinstance(cmp_value, Scalar):
        cmp_expr: Expr = cmp_value.unwrap()
    elif isinstance(cmp_value, Expr):
        cmp_expr = cmp_value
    else:
        cmp_expr = ConstInt(int(cmp_value), DataType.INT32, Span.unknown())
    return _ir_tile_ops.wait(signal.unwrap(), cmp_expr, cmp=cmp, span=span)
Contributor

medium

When passing IntLike arguments (int, Scalar, or Expr) to C++ bindings, ensure that Scalar objects are explicitly unwrapped to Expr objects to match the expected C++ types. Additionally, validate user-provided arguments at the parser level to provide early and clear error messages. While logic duplication is sometimes preferred for specific error messages (Rule 1), extracting this conversion logic into a helper ensures that Rule 3 is consistently applied across both 'notify' and 'wait' functions.

def _value_to_int32_expr(value: int | Scalar | Expr, arg_name: str) -> Expr:
    if isinstance(value, Scalar):
        return value.unwrap()
    if isinstance(value, Expr):
        return value
    try:
        return ConstInt(int(value), DataType.INT32, Span.unknown())
    except (ValueError, TypeError):
        raise TypeError(f"Argument '{arg_name}' must be an int, Scalar, or Expr, got {type(value)}")


def notify(
    signal: Tensor,
    value: int | Scalar | Expr,
    *,
    op: str,
    span: Span | None = None,
) -> Call:
    """Send a flag notification to a remote rank's signal slot.

    Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The
    signal is a 1-element INT32 ``pl.Tensor`` (GM) that views the destination
    rank's signal location in its HCCL window — typically obtained via
    :func:`import_peer_buffer`.

    Args:
        signal: Destination signal tensor (1-element INT32) in remote rank's window.
        value: INT32 scalar value to write or atomic-add (Python int, Scalar, or Expr).
        op: Notify operation, ``"atomic_add"`` or ``"set"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.notify`` (used for its side effect; no return value).
    """
    value_expr = _value_to_int32_expr(value, "value")
    return _ir_tile_ops.notify(signal.unwrap(), value_expr, op=op, span=span)


def wait(
    signal: Tensor,
    cmp_value: int | Scalar | Expr,
    *,
    cmp: str,
    span: Span | None = None,
) -> Call:
    """Block until a local INT32 signal slot satisfies a comparison.

    Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal
    is a 1-element INT32 ``pl.Tensor`` (GM) in the local rank's window — the
    slot peers ``pl.tile.notify`` into.

    Args:
        signal: Local signal tensor (1-element INT32) to poll.
        cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr).
        cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` |
            ``"ge"`` | ``"lt"`` | ``"le"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.wait`` (used for its side effect; no return value).
    """
    cmp_expr = _value_to_int32_expr(cmp_value, "cmp_value")
    return _ir_tile_ops.wait(signal.unwrap(), cmp_expr, cmp=cmp, span=span)
References
  1. When passing a sequence of IntLike (int | Scalar | Expr) to C++ bindings, ensure any Scalar objects are unwrapped to Expr objects to match the expected C++ types.
  2. Validate user-provided arguments for DSL functions at the parser level to provide early and clear error messages, rather than relying solely on backend C++ validation.
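
One detail worth keeping in mind when validating the plain-Python branch: `int(value)` is permissive, so `int(2.7)` yields 2 and `int("3")` yields 3 instead of raising. If the project wants to reject such inputs early (an assumption; the thread does not say), a stricter literal check could look like the sketch below. `check_int32_literal` and the range constants are hypothetical names and are not part of pypto.

```python
import numbers

INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def check_int32_literal(value: object, arg_name: str) -> int:
    """Hypothetical check: accept only true integers that fit in INT32."""
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(value, bool) or not isinstance(value, numbers.Integral):
        raise TypeError(
            f"Argument '{arg_name}' must be an integer, got {type(value).__name__}"
        )
    ivalue = int(value)
    if not INT32_MIN <= ivalue <= INT32_MAX:
        raise ValueError(f"Argument '{arg_name}'={ivalue} does not fit in INT32")
    return ivalue


print(check_int32_literal(7, "value"))  # 7
for bad in (2.7, "3", True):
    try:
        check_int32_literal(bad, "value")
    except TypeError as err:
        print(err)
```

The checked integer would then be handed to `ConstInt(..., DataType.INT32, ...)` exactly as in the suggested helper above.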


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/ut/codegen/test_pto_codegen_ops.py (1)

1608-1673: 💤 Low value

TestTileNotifyPtoCodegen — LGTM with one optional improvement

The three tests cover the key paths (set, atomic_add, bad dtype). One gap worth noting: there is no rejection test for an unsupported op string (e.g. op="invalid"). If input validation is enforced at the IR construction layer rather than codegen, add the equivalent test to tests/ut/ir/operators/test_tile_ops.py; if it's enforced in codegen, a small pytest.raises case here would complete the contract coverage.

Optionally, test_tile_notify_set_codegen and test_tile_notify_atomic_add_codegen can be collapsed into a single @pytest.mark.parametrize("op,attr", [("set", "set"), ("atomic_add", "atomic_add")]) test to reduce duplication.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/ut/codegen/test_pto_codegen_ops.py` around lines 1608 - 1673, Add a
rejection test for unsupported op strings so tile.notify validates op values:
add a new test (e.g. in TestTileNotifyPtoCodegen or in
tests/ut/ir/operators/test_tile_ops.py depending on where validation lives) that
constructs a program using pl.tile.notify(signal, 1, op="invalid") and asserts
it raises (pytest.raises) with an appropriate message; reference the existing
helper _generate_mlir and the test names
test_tile_notify_set_codegen/test_tile_notify_atomic_add_codegen to locate
similar test patterns and mirror their structure (or convert the two positive
tests into a single parametric `@pytest.mark.parametrize` if you prefer to reduce
duplication).
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2a67f8d0-4681-494b-8c02-d82da2702789

📥 Commits

Reviewing files that changed from the base of the PR and between 774775d and de39e6c.

📒 Files selected for processing (10)
  • docs/en/dev/ir/05-operators.md
  • docs/zh-cn/dev/ir/05-operators.md
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/op/system_ops.py
  • python/pypto/language/op/tile_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/cross_core.cpp
  • tests/st/runtime/test_notify_wait.py
  • tests/ut/codegen/test_pto_codegen_ops.py
  • tests/ut/ir/operators/test_tile_ops.py

Comment thread src/ir/op/tile_ops/cross_core.cpp Outdated
Comment on lines +90 to +114
REGISTER_OP("tile.notify")
    .set_description(
        "Send a flag notification to a remote rank: write or atomic-add an INT32 value "
        "into the destination signal slot")
    .set_op_category("CrossCoreOp")
    .set_core_affinity(core_affinity::CoreAffinity::VECTOR)
    .add_argument("signal", "Destination signal tensor (1-element INT32, GM, remote-rank window)")
    .add_argument("value", "INT32 scalar value to write or atomic-add")
    .set_attr<std::string>("op")  // "atomic_add" | "set"
    .no_memory_spec()
    .f_deduce_type(DeduceUnknownType);

// Block until the local rank's signal slot satisfies `signal <cmp> cmp_value`.
// The signal operand is a 1-element INT32 Tensor in local-rank GM (the slot
// peers atomic-add or set into via tile.notify); codegen lowers this to
// `pto::comm::TWAIT` via PTOAS `pto.comm.twait`.
REGISTER_OP("tile.wait")
    .set_description("Block until a local INT32 signal slot satisfies the given comparison against cmp_value")
    .set_op_category("CrossCoreOp")
    .set_core_affinity(core_affinity::CoreAffinity::VECTOR)
    .add_argument("signal", "Local signal tensor (1-element INT32, GM) to poll")
    .add_argument("cmp_value", "INT32 scalar comparison value")
    .set_attr<std::string>("cmp")  // "eq" | "ne" | "gt" | "ge" | "lt" | "le"
    .no_memory_spec()
    .f_deduce_type(DeduceUnknownType);

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add IR-level operand validation for tile.notify / tile.wait.

Right now, lines 90–114 only declare the signatures; they don't enforce the documented contract (1-element INT32 signal, INT32 scalar arg). Invalid calls can pass registration and fail later during backend lowering instead of failing early at IR construction.

Suggested fix
+TypePtr DeduceTileNotifyType(const std::vector<ExprPtr>& args,
+                             const std::vector<std::pair<std::string, std::any>>& kwargs) {
+  CHECK(args.size() == 2) << "tile.notify requires 2 arguments";
+  auto sig_ty = As<TensorType>(args[0]->GetType());
+  CHECK(sig_ty) << "tile.notify signal must be TensorType";
+  CHECK(sig_ty->dtype_ == DataType::INT32) << "tile.notify signal must be INT32";
+  CHECK(sig_ty->shape_.size() == 1);
+  auto n = As<ConstInt>(sig_ty->shape_[0]);
+  CHECK(n && n->value_ == 1) << "tile.notify signal must be shape [1]";
+  auto v_ty = As<ScalarType>(args[1]->GetType());
+  CHECK(v_ty && v_ty->dtype_ == DataType::INT32) << "tile.notify value must be INT32 scalar";
+  return GetUnknownType();
+}
+
+TypePtr DeduceTileWaitType(const std::vector<ExprPtr>& args,
+                           const std::vector<std::pair<std::string, std::any>>& kwargs) {
+  CHECK(args.size() == 2) << "tile.wait requires 2 arguments";
+  auto sig_ty = As<TensorType>(args[0]->GetType());
+  CHECK(sig_ty) << "tile.wait signal must be TensorType";
+  CHECK(sig_ty->dtype_ == DataType::INT32) << "tile.wait signal must be INT32";
+  CHECK(sig_ty->shape_.size() == 1);
+  auto n = As<ConstInt>(sig_ty->shape_[0]);
+  CHECK(n && n->value_ == 1) << "tile.wait signal must be shape [1]";
+  auto v_ty = As<ScalarType>(args[1]->GetType());
+  CHECK(v_ty && v_ty->dtype_ == DataType::INT32) << "tile.wait cmp_value must be INT32 scalar";
+  return GetUnknownType();
+}
...
 REGISTER_OP("tile.notify")
 ...
-    .f_deduce_type(DeduceUnknownType);
+    .f_deduce_type(DeduceTileNotifyType);
...
 REGISTER_OP("tile.wait")
 ...
-    .f_deduce_type(DeduceUnknownType);
+    .f_deduce_type(DeduceTileWaitType);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/ir/op/tile_ops/cross_core.cpp` around lines 90 - 114, Add IR-level
operand validation to the REGISTER_OP declarations for "tile.notify" and
"tile.wait": implement .f_validate handlers that check the "signal" operand is
an INT32 tensor with exactly one element (shape == 1) and that the secondary
operand ("value" for tile.notify, "cmp_value" for tile.wait) is an INT32 scalar;
emit a clear validation error when these conditions fail so invalid uses fail
during IR construction rather than backend lowering. Ensure the validators
reference the op names ("tile.notify", "tile.wait") and the operand names
("signal", "value", "cmp_value") so reviewers can locate the checks.

Comment thread tests/st/runtime/test_notify_wait.py Outdated
Comment on lines +268 to +297
class TestNotifyWait:
    """On-board verification of pl.tile.notify + pl.tile.wait (single-rank loopback)."""

    def test_notify_set_wait_eq(self, test_runner):
        result = test_runner.run(NotifySetWaitEqTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_add_wait_ge(self, test_runner):
        result = test_runner.run(NotifyAddWaitGeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_gt(self, test_runner):
        result = test_runner.run(NotifySetWaitGtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_lt(self, test_runner):
        result = test_runner.run(NotifySetWaitLtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_le(self, test_runner):
        result = test_runner.run(NotifySetWaitLeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_ne(self, test_runner):
        result = test_runner.run(NotifySetWaitNeTestCase())
        assert result.passed, f"Test failed: {result.error}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Gate this ST suite on PTOAS capability availability.

Given the staged dependency on a PTOAS build exposing pto.comm.tnotify / pto.comm.twait, these tests should be conditionally skipped when that capability is absent; otherwise you’ll get infrastructure-driven red failures unrelated to this PR’s code.

Suggested fix
+import os
 import pytest
 ...
+pytestmark = pytest.mark.skipif(
+    os.getenv("PTOAS_HAS_COMM_NOTIFY_WAIT") != "1",
+    reason="Requires PTOAS build with pto.comm.tnotify / pto.comm.twait support",
+)
+
 class TestNotifyWait:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (replaces the TestNotifyWait block quoted above)
import os
import pytest

pytestmark = pytest.mark.skipif(
    os.getenv("PTOAS_HAS_COMM_NOTIFY_WAIT") != "1",
    reason="Requires PTOAS build with pto.comm.tnotify / pto.comm.twait support",
)


class TestNotifyWait:
    """On-board verification of pl.tile.notify + pl.tile.wait (single-rank loopback)."""

    def test_notify_set_wait_eq(self, test_runner):
        result = test_runner.run(NotifySetWaitEqTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_add_wait_ge(self, test_runner):
        result = test_runner.run(NotifyAddWaitGeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_gt(self, test_runner):
        result = test_runner.run(NotifySetWaitGtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_lt(self, test_runner):
        result = test_runner.run(NotifySetWaitLtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_le(self, test_runner):
        result = test_runner.run(NotifySetWaitLeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_ne(self, test_runner):
        result = test_runner.run(NotifySetWaitNeTestCase())
        assert result.passed, f"Test failed: {result.error}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/st/runtime/test_notify_wait.py` around lines 268 - 297, The test suite
unconditionally exercises PTOAS-only APIs pto.comm.tnotify / pto.comm.twait
causing infra-driven failures when PTOAS is not present; modify the
TestNotifyWait tests to be skipped when the capability is absent by checking the
PTOAS capability at import/runtime (e.g., a helper like has_ptoas_capability()
or checking pto.comm for tnotify/twait) and applying pytest.skip or
pytest.mark.skipif to the whole TestNotifyWait class or individual test methods
(referencing TestNotifyWait, test_notify_* methods, and pto.comm.tnotify/twait)
so the suite only runs when those APIs are available.
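
The prompt above mentions a capability helper without defining one. A minimal sketch of such a probe is shown here, assuming the capability is advertised through the `PTOAS_HAS_COMM_NOTIFY_WAIT` environment variable proposed in the suggested fix. That variable is only the name used in this review (nothing in the PR sets it), and `ptoas_has_comm_notify_wait` is a hypothetical helper name.

```python
import os


def ptoas_has_comm_notify_wait(env=None) -> bool:
    """Hypothetical probe: True when the environment advertises tnotify/twait support."""
    # Accept an explicit mapping so the probe can be exercised without
    # touching the real process environment.
    env = os.environ if env is None else env
    return env.get("PTOAS_HAS_COMM_NOTIFY_WAIT") == "1"


print(ptoas_has_comm_notify_wait({"PTOAS_HAS_COMM_NOTIFY_WAIT": "1"}))  # True
print(ptoas_has_comm_notify_wait({}))                                   # False
```

In the test module this would plug into the module-level marker as `pytest.mark.skipif(not ptoas_has_comm_notify_wait(), reason=...)`. Querying the PTOAS toolchain directly would be more robust, but that interface is not shown in this thread.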


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9b35e237-2d6f-4661-af36-5784836ec7cb

📥 Commits

Reviewing files that changed from the base of the PR and between de39e6c and 96f2d6d.

📒 Files selected for processing (10)
  • docs/en/dev/ir/05-operators.md
  • docs/zh-cn/dev/ir/05-operators.md
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/op/system_ops.py
  • python/pypto/language/op/tile_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/cross_core.cpp
  • tests/st/runtime/test_notify_wait.py
  • tests/ut/codegen/test_pto_codegen_ops.py
  • tests/ut/ir/operators/test_tile_ops.py
✅ Files skipped from review due to trivial changes (2)
  • docs/en/dev/ir/05-operators.md
  • docs/zh-cn/dev/ir/05-operators.md

Comment on lines +1924 to +1937
// Validate signal tensor type for comm.tnotify / comm.twait. PTOAS spec
// (PTO_IR_manual.md §pto.comm.tnotify) requires GM-shaped INT32, rank >= 1.
static ir::TensorTypePtr CheckCommSignalType(const ir::ExprPtr& signal_arg, codegen::PTOCodegen& codegen,
                                             const ir::Span& span, const char* op_name) {
  auto signal_tensor_type = As<ir::TensorType>(signal_arg->GetType());
  INTERNAL_CHECK_SPAN(signal_tensor_type, span)
      << op_name
      << " signal must be a TensorType (GM signal slot) — PTOAS requires !pto.partition_tensor_view<Nxi32>";
  CHECK(!signal_tensor_type->shape_.empty())
      << op_name << " signal must be a ranked tensor (rank >= 1), got rank-0 (scalar)";
  CHECK(signal_tensor_type->dtype_ == DataType::INT32)
      << op_name << " signal must be INT32, got element type "
      << codegen.GetTypeString(signal_tensor_type->dtype_);
  return signal_tensor_type;

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce the single-slot signal contract.

Line 1932 only rejects rank-0 tensors, so shapes like [2] or [1, 8] still pass even though this API is documented as operating on a single INT32 signal slot. Please validate that the signal tensor contains exactly one element, or at least reject statically-known non-singleton shapes here, before lowering to pto.comm.tnotify / pto.comm.twait. Otherwise malformed inputs survive until PTO emission and fail late or target the wrong region.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/backend/common/pto_ops_common.cpp` around lines 1924 - 1937,
CheckCommSignalType currently only rejects rank-0 tensors but must enforce the
single-slot contract: verify the tensor contains exactly one element or reject
statically-known non-singleton shapes before lowering. In CheckCommSignalType
(and using span/op_name for diagnostics) keep the rank>=1 check, then inspect
signal_tensor_type->shape_: if all extents are statically-known, compute the
product and REQUIRE it equals 1 (emit a clear CHECK/INTERNAL_CHECK_SPAN failure
referencing op_name and the shape); if any extent is dynamic/unknown, allow it
(since it could be singleton at runtime) but still reject any statically-known
extent >1 early. Return the same signal_tensor_type on success.
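
The shape rule requested above can be modeled in a few lines. This is a minimal Python sketch of the logic only, not the C++ change itself: `is_single_slot` and `check_single_slot` are hypothetical names, and a dynamic extent is represented as `None`, which is an assumption about how unknown dimensions would be encoded. For non-negative integer extents, "product equals 1" and "every statically-known extent equals 1" are the same condition, so one test covers both the fully static and the partially dynamic case.

```python
from typing import Optional, Sequence


def is_single_slot(shape: Sequence[Optional[int]]) -> bool:
    """Return True if the shape may describe exactly one element.

    None stands for a dynamic (unknown) extent; this encoding is an
    assumption made for the sketch.
    """
    if len(shape) == 0:
        # Rank-0 is rejected, matching the existing rank >= 1 check.
        return False
    # Every statically-known extent must be 1. A dynamic extent is allowed
    # because it could still be 1 at runtime.
    return all(extent == 1 for extent in shape if extent is not None)


def check_single_slot(shape: Sequence[Optional[int]], op_name: str) -> None:
    """Raise ValueError with a diagnostic naming the op and the shape."""
    if not is_single_slot(shape):
        raise ValueError(
            f"{op_name} signal must be a single INT32 slot, got shape {list(shape)}"
        )
```

With this rule, `[1]`, `[1, 1]`, and `[None]` are accepted, while `[]`, `[2]`, `[1, 8]`, and `[None, 4]` are rejected before lowering.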

Youhezhen added 2 commits May 12, 2026 16:03
Introduces a pair of cross-rank signaling operations on AIV:

- tile.notify(signal, value, op): write or atomic-add an INT32 value to
  a remote rank's signal slot (1-element INT32 GM tensor). Lowers to
  pto.comm.tnotify with notifyOp = #pto.notify_op<set|atomic_add>.
- tile.wait(signal, cmp_value, cmp): block until a local INT32 signal
  slot satisfies a comparison. Lowers to pto.comm.twait with
  cmp = #pto.wait_cmp<eq|ne|gt|ge|lt|le>.

All layers updated:
- C++ op registrations in src/ir/op/tile_ops/cross_core.cpp
- PTO codegen in src/backend/common/pto_ops_common.cpp
- Python IR wrappers in python/pypto/ir/op/tile_ops.py
- DSL wrappers in python/pypto/language/op/system_ops.py with
  re-export through python/pypto/language/op/tile_ops.py
- Tests: UT for IR + PTO codegen, ST loopback covering all six cmp
  variants and both notify ops
- Docs: Cross-Rank Signal Operations sections in
  docs/en/dev/ir/05-operators.md and docs/zh-cn/dev/ir/05-operators.md

Note: pto.comm.tnotify / pto.comm.twait require a PTOAS build that
exposes those custom ops; the on-board ST will only run on a PTOAS
that has the comm dialect enabled.
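
The notify and wait semantics described in this commit message can be illustrated with a small host-side model. This is a sketch for reading purposes only: it does not call any pypto API, the function names `notify` and `wait_satisfied` are hypothetical, the slot is a one-element Python list standing in for the 1-element INT32 GM tensor, and INT32 overflow behavior is not modeled.

```python
import operator

# The two notify ops named in the commit message: overwrite or accumulate.
NOTIFY_OPS = {
    "set": lambda old, value: value,
    "atomic_add": lambda old, value: old + value,
}

# The six wait comparisons named in the commit message.
WAIT_CMPS = {
    "eq": operator.eq,
    "ne": operator.ne,
    "gt": operator.gt,
    "ge": operator.ge,
    "lt": operator.lt,
    "le": operator.le,
}


def notify(slot: list, value: int, op: str) -> None:
    """Apply a notify op to the single-element slot in place."""
    slot[0] = NOTIFY_OPS[op](slot[0], value)


def wait_satisfied(slot: list, cmp_value: int, cmp: str) -> bool:
    """Return True if a wait with this comparison would stop blocking."""
    return WAIT_CMPS[cmp](slot[0], cmp_value)
```

For example, starting from `slot = [0]`, a `set` of 5 followed by an `atomic_add` of 2 leaves the slot at 7, at which point a wait with `ge` against 7 would be satisfied and a wait with `lt` against 7 would keep blocking.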
…ps to tile.comm_{notify,wait}

The previous codegen for tile.notify/tile.wait was broken — PTOAS rejected
the emitted MLIR. Two bugs:

1. Wrong operand type. Codegen emitted the signal as !pto.ptr<i32> (or a
   raw tensor_view), but pto.comm.tnotify / pto.comm.twait require
   !pto.partition_tensor_view<Nxi32>. Fix: lower the signal Var through
   make_tensor_view → partition_view to build a partition view covering
   the full signal shape.

2. Wrong assembly syntax. Codegen used the custom format
   "pto.comm.tnotify %sig, %v {...} : <type>, i32", but PTOAS's TNotifyOp /
   TWaitOp have no custom assemblyFormat — only generic MLIR op syntax is
   accepted. Fix: emit "pto.comm.tnotify"(%sig, %v) {...} : (<type>, i32) -> ().
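
The difference between the two syntaxes is easiest to see by building the generic form as a string. The helper below is a hypothetical Python sketch, not the repository's C++ emitter. The attribute spelling `notifyOp = #pto.notify_op<atomic_add>` is taken from the PR summary, and `1xi32` is one assumed instance of the `Nxi32` view type.

```python
from typing import Dict, Sequence


def emit_generic_op(
    op_name: str,
    operands: Sequence[str],
    attrs: Dict[str, str],
    operand_types: Sequence[str],
) -> str:
    """Format a zero-result op in generic MLIR syntax.

    Generic form: "op.name"(operands) {attrs} : (operand types) -> ()
    """
    args = ", ".join(operands)
    attr_text = ", ".join(f"{key} = {value}" for key, value in attrs.items())
    types = ", ".join(operand_types)
    # The op name is quoted and the type is a full function type, which is
    # what distinguishes the generic form from a custom assembly format.
    return f'"{op_name}"({args}) {{{attr_text}}} : ({types}) -> ()'


line = emit_generic_op(
    "pto.comm.tnotify",
    ["%sig", "%v"],
    {"notifyOp": "#pto.notify_op<atomic_add>"},
    ["!pto.partition_tensor_view<1xi32>", "i32"],
)
```

Here `line` holds `"pto.comm.tnotify"(%sig, %v) {notifyOp = #pto.notify_op<atomic_add>} : (!pto.partition_tensor_view<1xi32>, i32) -> ()`, which follows the shape of the fixed output quoted in item 2.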

Also rename the ops from tile.notify/tile.wait to tile.comm_notify/tile.comm_wait
for namespace consistency with the pto.comm.* MLIR ops and to keep cross-rank
signaling ops grouped under a comm_* prefix.

ST tests reshaped to mirror the two real usage patterns from simpler's
ep_dispatch_combine kernels (count exchange via atomic_add, done barrier
via atomic_add + wait ge), instead of exhaustively covering every cmp op.
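
The done-barrier pattern mentioned above (each peer does an atomic add, the waiter blocks until the count is reached) can be mimicked with host threads. This is only an analogy under stated assumptions: the actual kernels are not shown in this PR, `SignalSlot` and `run_done_barrier` are hypothetical names, and a `threading.Condition` stands in for the hardware signal slot.

```python
import threading


class SignalSlot:
    """Thread-based stand-in for a single INT32 signal slot."""

    def __init__(self) -> None:
        self._value = 0
        self._cond = threading.Condition()

    def notify_atomic_add(self, value: int) -> None:
        # Models notify with the atomic_add op: accumulate, then wake waiters.
        with self._cond:
            self._value += value
            self._cond.notify_all()

    def wait_ge(self, cmp_value: int, timeout: float = 5.0) -> bool:
        # Models wait with the ge comparison; returns False on timeout.
        with self._cond:
            return self._cond.wait_for(
                lambda: self._value >= cmp_value, timeout=timeout
            )

    @property
    def value(self) -> int:
        with self._cond:
            return self._value


def run_done_barrier(num_ranks: int):
    """Each simulated rank adds 1; the waiter blocks until all have arrived."""
    slot = SignalSlot()
    workers = [
        threading.Thread(target=slot.notify_atomic_add, args=(1,))
        for _ in range(num_ranks)
    ]
    for worker in workers:
        worker.start()
    reached = slot.wait_ge(num_ranks)
    for worker in workers:
        worker.join()
    return reached, slot.value
```

Calling `run_done_barrier(4)` starts four notifier threads and returns once the slot has reached 4, which is the same "add one per peer, wait for the total" shape as the barrier described in the commit message.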