feat(ir): Add tile.notify and tile.wait cross-rank signal ops #1301
Little-oil wants to merge 2 commits into
📝 Walkthrough
Adds tile.comm_notify and tile.comm_wait: IR ops, IR-level Python bindings, language DSL wrappers, PTO backend lowering to pto.comm.tnotify/twait, English/Chinese docs, and unit + codegen + runtime tests validating INT32 signal semantics and attributes.
Changes
Tile Signal Operations (tile.comm_notify / tile.comm_wait)
Sequence Diagram(s)
sequenceDiagram
participant Program as User Program
participant Lang as language.system_ops
participant IR as ir.tile_ops
participant PTO as PTO Codegen
participant GM as Device GM
Program->>Lang: comm_notify(signal, value, op)
Lang->>IR: normalized signal/value -> tile.comm_notify Call
IR->>PTO: tile.comm_notify
PTO->>GM: emit pto.comm.tnotify (partition_view, i32 value, notifyOp)
Program->>Lang: comm_wait(signal, cmp_value, cmp)
Lang->>IR: normalized cmp_value -> tile.comm_wait Call
IR->>PTO: tile.comm_wait
PTO->>GM: emit pto.comm.twait (partition_view, i32 cmp, cmp attr)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Wait for PTOAS's new version.
Code Review
This pull request introduces cross-rank signal operations, specifically tile.notify and tile.wait, to support synchronization between different ranks. The changes include documentation in both English and Chinese, Python IR and language-level API definitions, C++ backend codegen for PTO operations, and comprehensive unit and system tests. The feedback suggests refactoring the argument conversion logic in python/pypto/language/op/system_ops.py into a shared helper function to improve maintainability and ensure consistent validation of IntLike arguments across both operations.
def notify(
    signal: Tensor,
    value: int | Scalar | Expr,
    *,
    op: str,
    span: Span | None = None,
) -> Call:
    """Send a flag notification to a remote rank's signal slot.

    Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The
    signal is a 1-element INT32 ``pl.Tensor`` (GM) that views the destination
    rank's signal location in its HCCL window, typically obtained via
    :func:`import_peer_buffer`.

    Args:
        signal: Destination signal tensor (1-element INT32) in remote rank's window.
        value: INT32 scalar value to write or atomic-add (Python int, Scalar, or Expr).
        op: Notify operation, ``"atomic_add"`` or ``"set"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.notify`` (used for its side effect; no return value).
    """
    if isinstance(value, Scalar):
        value_expr: Expr = value.unwrap()
    elif isinstance(value, Expr):
        value_expr = value
    else:
        value_expr = ConstInt(int(value), DataType.INT32, Span.unknown())
    return _ir_tile_ops.notify(signal.unwrap(), value_expr, op=op, span=span)


def wait(
    signal: Tensor,
    cmp_value: int | Scalar | Expr,
    *,
    cmp: str,
    span: Span | None = None,
) -> Call:
    """Block until a local INT32 signal slot satisfies a comparison.

    Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal
    is a 1-element INT32 ``pl.Tensor`` (GM) in the local rank's window: the
    slot peers ``pl.tile.notify`` into.

    Args:
        signal: Local signal tensor (1-element INT32) to poll.
        cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr).
        cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` |
            ``"ge"`` | ``"lt"`` | ``"le"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.wait`` (used for its side effect; no return value).
    """
    if isinstance(cmp_value, Scalar):
        cmp_expr: Expr = cmp_value.unwrap()
    elif isinstance(cmp_value, Expr):
        cmp_expr = cmp_value
    else:
        cmp_expr = ConstInt(int(cmp_value), DataType.INT32, Span.unknown())
    return _ir_tile_ops.wait(signal.unwrap(), cmp_expr, cmp=cmp, span=span)
When passing IntLike arguments (int, Scalar, or Expr) to C++ bindings, ensure that Scalar objects are explicitly unwrapped to Expr objects to match the expected C++ types. Additionally, validate user-provided arguments at the parser level to provide early and clear error messages. While logic duplication is sometimes preferred for specific error messages (Rule 1), extracting this conversion logic into a helper ensures that Rule 3 is consistently applied across both 'notify' and 'wait' functions.
def _value_to_int32_expr(value: int | Scalar | Expr, arg_name: str) -> Expr:
    if isinstance(value, Scalar):
        return value.unwrap()
    if isinstance(value, Expr):
        return value
    try:
        return ConstInt(int(value), DataType.INT32, Span.unknown())
    except (ValueError, TypeError):
        raise TypeError(f"Argument '{arg_name}' must be an int, Scalar, or Expr, got {type(value)}")


def notify(
    signal: Tensor,
    value: int | Scalar | Expr,
    *,
    op: str,
    span: Span | None = None,
) -> Call:
    """Send a flag notification to a remote rank's signal slot.

    Lowers to ``pto::comm::TNOTIFY`` via PTOAS ``pto.comm.tnotify``. The
    signal is a 1-element INT32 ``pl.Tensor`` (GM) that views the destination
    rank's signal location in its HCCL window, typically obtained via
    :func:`import_peer_buffer`.

    Args:
        signal: Destination signal tensor (1-element INT32) in remote rank's window.
        value: INT32 scalar value to write or atomic-add (Python int, Scalar, or Expr).
        op: Notify operation, ``"atomic_add"`` or ``"set"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.notify`` (used for its side effect; no return value).
    """
    value_expr = _value_to_int32_expr(value, "value")
    return _ir_tile_ops.notify(signal.unwrap(), value_expr, op=op, span=span)


def wait(
    signal: Tensor,
    cmp_value: int | Scalar | Expr,
    *,
    cmp: str,
    span: Span | None = None,
) -> Call:
    """Block until a local INT32 signal slot satisfies a comparison.

    Lowers to ``pto::comm::TWAIT`` via PTOAS ``pto.comm.twait``. The signal
    is a 1-element INT32 ``pl.Tensor`` (GM) in the local rank's window: the
    slot peers ``pl.tile.notify`` into.

    Args:
        signal: Local signal tensor (1-element INT32) to poll.
        cmp_value: INT32 scalar comparison value (Python int, Scalar, or Expr).
        cmp: Comparison predicate, one of ``"eq"`` | ``"ne"`` | ``"gt"`` |
            ``"ge"`` | ``"lt"`` | ``"le"``.
        span: Optional source span.

    Returns:
        The IR ``Call`` for ``tile.wait`` (used for its side effect; no return value).
    """
    cmp_expr = _value_to_int32_expr(cmp_value, "cmp_value")
    return _ir_tile_ops.wait(signal.unwrap(), cmp_expr, cmp=cmp, span=span)
References
- When passing a sequence of IntLike (int | Scalar | Expr) to C++ bindings, ensure any Scalar objects are unwrapped to Expr objects to match the expected C++ types.
- Validate user-provided arguments for DSL functions at the parser level to provide early and clear error messages, rather than relying solely on backend C++ validation.
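One caveat on the helper suggested above: Python's `int()` accepts floats and numeric strings (`int(3.7)` is `3`, `int("5")` is `5`), so a `try`/`except` around `int(value)` would not reject them. The sketch below shows a stricter variant. `Expr` and `Scalar` here are hypothetical stand-ins rather than the pypto classes, and rejecting `bool` and checking the INT32 range are assumptions about what the DSL would want, not behavior confirmed by this PR.

```python
from dataclasses import dataclass

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


@dataclass(frozen=True)
class Expr:
    """Hypothetical stand-in for an IR expression node."""
    value: int
    dtype: str = "int32"


@dataclass(frozen=True)
class Scalar:
    """Hypothetical stand-in for the DSL-level scalar wrapper."""
    expr: Expr

    def unwrap(self) -> Expr:
        return self.expr


def to_int32_expr(value, arg_name: str) -> Expr:
    """Normalize int | Scalar | Expr to an Expr, rejecting everything else."""
    if isinstance(value, Scalar):
        return value.unwrap()
    if isinstance(value, Expr):
        return value
    # bool is a subclass of int; reject it explicitly, along with float and
    # str, which a bare int(value) call would silently convert.
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(
            f"Argument '{arg_name}' must be an int, Scalar, or Expr, "
            f"got {type(value).__name__}"
        )
    if not INT32_MIN <= value <= INT32_MAX:
        raise ValueError(f"Argument '{arg_name}'={value} does not fit in INT32")
    return Expr(value)
```

With this variant, `to_int32_expr(3.7, "value")` raises `TypeError` instead of silently truncating to `3`.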
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/ut/codegen/test_pto_codegen_ops.py (1)
1608-1673: 💤 Low value
TestTileNotifyPtoCodegen: LGTM with one optional improvement
The three tests cover the key paths (set, atomic_add, bad dtype). One gap worth noting: there is no rejection test for an unsupported op string (e.g. op="invalid"). If input validation is enforced at the IR construction layer rather than codegen, add the equivalent test to tests/ut/ir/operators/test_tile_ops.py; if it's enforced in codegen, a small pytest.raises case here would complete the contract coverage.
Optionally, test_tile_notify_set_codegen and test_tile_notify_atomic_add_codegen can be collapsed into a single @pytest.mark.parametrize("op,attr", [("set", "set"), ("atomic_add", "atomic_add")]) test to reduce duplication.
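As a rough illustration of the rejection path this nitpick asks to test, the sketch below validates the `op` and `cmp` strings against the values listed in the docstrings. `check_choice`, `NOTIFY_OPS`, and `WAIT_CMPS` are hypothetical names; where the real validation lives (IR construction or codegen) is not settled by this comment.

```python
NOTIFY_OPS = frozenset({"atomic_add", "set"})
WAIT_CMPS = frozenset({"eq", "ne", "gt", "ge", "lt", "le"})


def check_choice(arg_name: str, value: str, allowed: frozenset) -> str:
    """Return value unchanged if it is an allowed string, else raise ValueError."""
    if value not in allowed:
        raise ValueError(
            f"{arg_name} must be one of {sorted(allowed)}, got {value!r}"
        )
    return value


def rejects(arg_name: str, value: str, allowed: frozenset) -> bool:
    """Small helper for tests: True when the value is rejected."""
    try:
        check_choice(arg_name, value, allowed)
    except ValueError:
        return True
    return False
```

A rejection test would then reduce to asserting `rejects("op", "invalid", NOTIFY_OPS)`, with the positive cases (`"set"`, `"atomic_add"`) covered by a parametrized test.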
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/ir/op/tile_ops/cross_core.cpp`:
- Around line 90-114: Add IR-level operand validation to the REGISTER_OP
declarations for "tile.notify" and "tile.wait": implement .f_validate handlers
that check the "signal" operand is an INT32 tensor with exactly one element
(shape == 1) and that the secondary operand ("value" for tile.notify,
"cmp_value" for tile.wait) is an INT32 scalar; emit a clear validation error
when these conditions fail so invalid uses fail during IR construction rather
than backend lowering. Ensure the validators reference the op names
("tile.notify", "tile.wait") and the operand names ("signal", "value",
"cmp_value") so reviewers can locate the checks.
In `@tests/st/runtime/test_notify_wait.py`:
- Around line 268-297: The test suite unconditionally exercises PTOAS-only APIs
pto.comm.tnotify / pto.comm.twait causing infra-driven failures when PTOAS is
not present; modify the TestNotifyWait tests to be skipped when the capability
is absent by checking the PTOAS capability at import/runtime (e.g., a helper
like has_ptoas_capability() or checking pto.comm for tnotify/twait) and applying
pytest.skip or pytest.mark.skipif to the whole TestNotifyWait class or
individual test methods (referencing TestNotifyWait, test_notify_* methods, and
pto.comm.tnotify/twait) so the suite only runs when those APIs are available.
---
Nitpick comments:
In `@tests/ut/codegen/test_pto_codegen_ops.py`:
- Around line 1608-1673: Add a rejection test for unsupported op strings so
tile.notify validates op values: add a new test (e.g. in
TestTileNotifyPtoCodegen or in tests/ut/ir/operators/test_tile_ops.py depending
on where validation lives) that constructs a program using
pl.tile.notify(signal, 1, op="invalid") and asserts it raises (pytest.raises)
with an appropriate message; reference the existing helper _generate_mlir and
the test names test_tile_notify_set_codegen/test_tile_notify_atomic_add_codegen
to locate similar test patterns and mirror their structure (or convert the two
positive tests into a single parametric `@pytest.mark.parametrize` if you prefer
to reduce duplication).
📒 Files selected for processing (10)
- docs/en/dev/ir/05-operators.md
- docs/zh-cn/dev/ir/05-operators.md
- python/pypto/ir/op/tile_ops.py
- python/pypto/language/op/system_ops.py
- python/pypto/language/op/tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/cross_core.cpp
- tests/st/runtime/test_notify_wait.py
- tests/ut/codegen/test_pto_codegen_ops.py
- tests/ut/ir/operators/test_tile_ops.py
REGISTER_OP("tile.notify")
    .set_description(
        "Send a flag notification to a remote rank: write or atomic-add an INT32 value "
        "into the destination signal slot")
    .set_op_category("CrossCoreOp")
    .set_core_affinity(core_affinity::CoreAffinity::VECTOR)
    .add_argument("signal", "Destination signal tensor (1-element INT32, GM, remote-rank window)")
    .add_argument("value", "INT32 scalar value to write or atomic-add")
    .set_attr<std::string>("op")  // "atomic_add" | "set"
    .no_memory_spec()
    .f_deduce_type(DeduceUnknownType);

// Block until the local rank's signal slot satisfies `signal <cmp> cmp_value`.
// The signal operand is a 1-element INT32 Tensor in local-rank GM (the slot
// peers atomic-add or set into via tile.notify); codegen lowers this to
// `pto::comm::TWAIT` via PTOAS `pto.comm.twait`.
REGISTER_OP("tile.wait")
    .set_description("Block until a local INT32 signal slot satisfies the given comparison against cmp_value")
    .set_op_category("CrossCoreOp")
    .set_core_affinity(core_affinity::CoreAffinity::VECTOR)
    .add_argument("signal", "Local signal tensor (1-element INT32, GM) to poll")
    .add_argument("cmp_value", "INT32 scalar comparison value")
    .set_attr<std::string>("cmp")  // "eq" | "ne" | "gt" | "ge" | "lt" | "le"
    .no_memory_spec()
    .f_deduce_type(DeduceUnknownType);
Add IR-level operand validation for tile.notify / tile.wait.
Right now, lines 90–114 only declare signatures; they don't enforce the documented contract (1-element INT32 signal, INT32 scalar arg). Invalid calls can pass registration and fail later during backend lowering instead of failing early at IR construction.
Suggested fix
+TypePtr DeduceTileNotifyType(const std::vector<ExprPtr>& args,
+                             const std::vector<std::pair<std::string, std::any>>& kwargs) {
+  CHECK(args.size() == 2) << "tile.notify requires 2 arguments";
+  auto sig_ty = As<TensorType>(args[0]->GetType());
+  CHECK(sig_ty) << "tile.notify signal must be TensorType";
+  CHECK(sig_ty->dtype_ == DataType::INT32) << "tile.notify signal must be INT32";
+  CHECK(sig_ty->shape_.size() == 1);
+  auto n = As<ConstInt>(sig_ty->shape_[0]);
+  CHECK(n && n->value_ == 1) << "tile.notify signal must be shape [1]";
+  auto v_ty = As<ScalarType>(args[1]->GetType());
+  CHECK(v_ty && v_ty->dtype_ == DataType::INT32) << "tile.notify value must be INT32 scalar";
+  return GetUnknownType();
+}
+
+TypePtr DeduceTileWaitType(const std::vector<ExprPtr>& args,
+                           const std::vector<std::pair<std::string, std::any>>& kwargs) {
+  CHECK(args.size() == 2) << "tile.wait requires 2 arguments";
+  auto sig_ty = As<TensorType>(args[0]->GetType());
+  CHECK(sig_ty) << "tile.wait signal must be TensorType";
+  CHECK(sig_ty->dtype_ == DataType::INT32) << "tile.wait signal must be INT32";
+  CHECK(sig_ty->shape_.size() == 1);
+  auto n = As<ConstInt>(sig_ty->shape_[0]);
+  CHECK(n && n->value_ == 1) << "tile.wait signal must be shape [1]";
+  auto v_ty = As<ScalarType>(args[1]->GetType());
+  CHECK(v_ty && v_ty->dtype_ == DataType::INT32) << "tile.wait cmp_value must be INT32 scalar";
+  return GetUnknownType();
+}
 ...
 REGISTER_OP("tile.notify")
 ...
-    .f_deduce_type(DeduceUnknownType);
+    .f_deduce_type(DeduceTileNotifyType);
 ...
 REGISTER_OP("tile.wait")
 ...
-    .f_deduce_type(DeduceUnknownType);
+    .f_deduce_type(DeduceTileWaitType);
class TestNotifyWait:
    """On-board verification of pl.tile.notify + pl.tile.wait (single-rank loopback)."""

    def test_notify_set_wait_eq(self, test_runner):
        result = test_runner.run(NotifySetWaitEqTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_add_wait_ge(self, test_runner):
        result = test_runner.run(NotifyAddWaitGeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_gt(self, test_runner):
        result = test_runner.run(NotifySetWaitGtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_lt(self, test_runner):
        result = test_runner.run(NotifySetWaitLtTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_le(self, test_runner):
        result = test_runner.run(NotifySetWaitLeTestCase())
        assert result.passed, f"Test failed: {result.error}"

    def test_notify_set_wait_ne(self, test_runner):
        result = test_runner.run(NotifySetWaitNeTestCase())
        assert result.passed, f"Test failed: {result.error}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Gate this ST suite on PTOAS capability availability.
Given the staged dependency on a PTOAS build exposing pto.comm.tnotify / pto.comm.twait, these tests should be conditionally skipped when that capability is absent; otherwise you’ll get infrastructure-driven red failures unrelated to this PR’s code.
Suggested fix
+import os
import pytest
...
+pytestmark = pytest.mark.skipif(
+ os.getenv("PTOAS_HAS_COMM_NOTIFY_WAIT") != "1",
+ reason="Requires PTOAS build with pto.comm.tnotify / pto.comm.twait support",
+)
+
 class TestNotifyWait:
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 1924-1937: CheckCommSignalType currently only rejects rank-0
tensors but must enforce the single-slot contract: verify the tensor contains
exactly one element or reject statically-known non-singleton shapes before
lowering. In CheckCommSignalType (and using span/op_name for diagnostics) keep
the rank>=1 check, then inspect signal_tensor_type->shape_: if all extents are
statically-known, compute the product and REQUIRE it equals 1 (emit a clear
CHECK/INTERNAL_CHECK_SPAN failure referencing op_name and the shape); if any
extent is dynamic/unknown, allow it (since it could be singleton at runtime) but
still reject any statically-known extent >1 early. Return the same
signal_tensor_type on success.
✅ Files skipped from review due to trivial changes (2)
- docs/en/dev/ir/05-operators.md
- docs/zh-cn/dev/ir/05-operators.md
// Validate signal tensor type for comm.tnotify / comm.twait. PTOAS spec
// (PTO_IR_manual.md §pto.comm.tnotify) requires GM-shaped INT32, rank >= 1.
static ir::TensorTypePtr CheckCommSignalType(const ir::ExprPtr& signal_arg, codegen::PTOCodegen& codegen,
                                             const ir::Span& span, const char* op_name) {
  auto signal_tensor_type = As<ir::TensorType>(signal_arg->GetType());
  INTERNAL_CHECK_SPAN(signal_tensor_type, span)
      << op_name
      << " signal must be a TensorType (GM signal slot) - PTOAS requires !pto.partition_tensor_view<Nxi32>";
  CHECK(!signal_tensor_type->shape_.empty())
      << op_name << " signal must be a ranked tensor (rank >= 1), got rank-0 (scalar)";
  CHECK(signal_tensor_type->dtype_ == DataType::INT32)
      << op_name << " signal must be INT32, got element type "
      << codegen.GetTypeString(signal_tensor_type->dtype_);
  return signal_tensor_type;
Enforce the single-slot signal contract.
Line 1932 only rejects rank-0 tensors, so shapes like [2] or [1, 8] still pass even though this API is documented as operating on a single INT32 signal slot. Please validate that the signal tensor contains exactly one element, or at least reject statically-known non-singleton shapes here, before lowering to pto.comm.tnotify / pto.comm.twait. Otherwise malformed inputs survive until PTO emission and fail late or target the wrong region.
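The shape rule described here can be sketched in a few lines. This is an illustrative Python model, not the C++ change itself: `check_single_slot` is a hypothetical name, and using `None` to mark a dynamic extent is an assumption of the sketch. Because extents are non-negative, "the product of all static extents equals 1" is the same as "every static extent equals 1", so one check covers both the fully static and the partially dynamic case.

```python
from typing import Optional, Sequence


def check_single_slot(shape: Sequence[Optional[int]], op_name: str) -> None:
    """Raise ValueError unless the shape could describe exactly one element.

    None marks a dynamic extent (an assumption of this sketch).
    """
    if len(shape) == 0:
        raise ValueError(f"{op_name} signal must have rank >= 1, got rank-0")
    static_extents = [d for d in shape if d is not None]
    # Any statically-known extent other than 1 rules out a singleton tensor,
    # even when the remaining extents are dynamic.
    if any(d != 1 for d in static_extents):
        raise ValueError(
            f"{op_name} signal must hold exactly one element, "
            f"got shape {list(shape)}"
        )


def is_accepted(shape: Sequence[Optional[int]]) -> bool:
    """Convenience wrapper for tests: True when the shape passes the check."""
    try:
        check_single_slot(shape, "tile.comm_notify")
    except ValueError:
        return False
    return True
```

Under these assumptions `[1]`, `[1, 1]`, and `[None]` are accepted, while `[2]`, `[1, 8]`, and `[None, 8]` are rejected before lowering.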
Introduces a pair of cross-rank signaling operations on AIV:
- tile.notify(signal, value, op): write or atomic-add an INT32 value to a remote rank's signal slot (1-element INT32 GM tensor). Lowers to pto.comm.tnotify with notifyOp = #pto.notify_op<set|atomic_add>.
- tile.wait(signal, cmp_value, cmp): block until a local INT32 signal slot satisfies a comparison. Lowers to pto.comm.twait with cmp = #pto.wait_cmp<eq|ne|gt|ge|lt|le>.
All five layers updated:
- C++ op registrations in src/ir/op/tile_ops/cross_core.cpp
- PTO codegen in src/backend/common/pto_ops_common.cpp
- Python IR wrappers in python/pypto/ir/op/tile_ops.py
- DSL wrappers in python/pypto/language/op/system_ops.py with re-export through python/pypto/language/op/tile_ops.py
- Tests: UT for IR + PTO codegen, ST loopback covering all six cmp variants and both notify ops
- Docs: Cross-Rank Signal Operations sections in docs/en/dev/ir/05-operators.md and docs/zh-cn/dev/ir/05-operators.md
Note: pto.comm.tnotify / pto.comm.twait require a PTOAS build that exposes those custom ops; the on-board ST will only run on a PTOAS that has the comm dialect enabled.
…ps to tile.comm_{notify,wait}
The previous codegen for tile.notify/tile.wait was broken — PTOAS rejected
the emitted MLIR. Two bugs:
1. Wrong operand type. Codegen emitted the signal as !pto.ptr<i32> (or a
raw tensor_view), but pto.comm.tnotify / pto.comm.twait require
!pto.partition_tensor_view<Nxi32>. Fix: lower the signal Var through
make_tensor_view → partition_view to build a partition view covering
the full signal shape.
2. Wrong assembly syntax. Codegen used the custom format
"pto.comm.tnotify %sig, %v {...} : <type>, i32", but PTOAS's TNotifyOp /
TWaitOp have no custom assemblyFormat — only generic MLIR op syntax is
accepted. Fix: emit "pto.comm.tnotify"(%sig, %v) {...} : (<type>, i32) -> ().
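To make the second fix concrete, here is a small sketch of a string emitter for the generic MLIR form of a zero-result op. `emit_generic_op` is a hypothetical helper, not the codegen function from this PR, and the SSA names `%sig` and `%v` are placeholders; the attribute and type spellings are taken from the commit messages above.

```python
from typing import Mapping, Sequence


def emit_generic_op(
    name: str,
    operands: Sequence[str],
    operand_types: Sequence[str],
    attrs: Mapping[str, str],
) -> str:
    """Format a zero-result op in generic MLIR syntax.

    Shape of the output: "op.name"(%a, %b) {attr = value} : (type_a, type_b) -> ()
    """
    operand_text = ", ".join(operands)
    type_text = ", ".join(operand_types)
    attr_text = ", ".join(f"{key} = {val}" for key, val in attrs.items())
    # Doubled braces in the f-string produce the literal { } of the attr dict.
    return f'"{name}"({operand_text}) {{{attr_text}}} : ({type_text}) -> ()'


line = emit_generic_op(
    "pto.comm.tnotify",
    ["%sig", "%v"],
    ["!pto.partition_tensor_view<1xi32>", "i32"],
    {"notifyOp": "#pto.notify_op<set>"},
)
print(line)
```

The op name is quoted, the operand types are wrapped in parentheses, and the empty result list is spelled `-> ()`, which is what distinguishes the generic form from the custom format that PTOAS rejected.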
Also rename the ops from tile.notify/tile.wait to tile.comm_notify/tile.comm_wait
for namespace consistency with the pto.comm.* MLIR ops and to keep cross-rank
signaling ops grouped under a comm_* prefix.
ST tests reshaped to mirror the two real usage patterns from simpler's
ep_dispatch_combine kernels (count exchange via atomic_add, done barrier
via atomic_add + wait ge), instead of exhaustively covering every cmp op.
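The "done barrier" pattern mentioned above (peers `atomic_add` into a slot, the owner waits with `ge`) can be modeled on the host with plain threads. This is only a behavioral sketch under stated assumptions: `SignalSlot` and `done_barrier` are hypothetical names, a `threading.Condition` stands in for the device-side polling, and INT32 wraparound is not modeled.

```python
import operator
import threading

_CMPS = {
    "eq": operator.eq, "ne": operator.ne, "gt": operator.gt,
    "ge": operator.ge, "lt": operator.lt, "le": operator.le,
}


class SignalSlot:
    """Host-side model of one INT32 signal slot (not the device implementation)."""

    def __init__(self) -> None:
        self._value = 0
        self._cond = threading.Condition()

    def notify(self, value: int, op: str) -> None:
        with self._cond:
            if op == "set":
                self._value = value
            elif op == "atomic_add":
                self._value += value
            else:
                raise ValueError(f"unsupported notify op: {op!r}")
            self._cond.notify_all()

    def wait(self, cmp_value: int, cmp: str, timeout: float = 5.0) -> int:
        predicate = _CMPS[cmp]
        with self._cond:
            met = self._cond.wait_for(
                lambda: predicate(self._value, cmp_value), timeout
            )
            if not met:
                raise TimeoutError("signal condition not met before timeout")
            return self._value


def done_barrier(num_peers: int) -> int:
    """Each peer adds 1 to the slot; the owner waits until the count >= num_peers."""
    slot = SignalSlot()
    peers = [
        threading.Thread(target=slot.notify, args=(1, "atomic_add"))
        for _ in range(num_peers)
    ]
    for peer in peers:
        peer.start()
    seen = slot.wait(num_peers, "ge")
    for peer in peers:
        peer.join()
    return seen
```

Waiting with `ge` rather than `eq` is the forgiving choice for a counter: the waiter still wakes up if it first observes the slot after more increments have landed than it asked for.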
Summary
Introduces a pair of cross-rank signaling operations on AIV:
- tile.notify(signal, value, op): write or atomic-add an INT32 value to a remote rank's signal slot (1-element INT32 GM tensor). Lowers to pto.comm.tnotify with notifyOp = #pto.notify_op<set|atomic_add>.
- tile.wait(signal, cmp_value, cmp): block until a local INT32 signal slot satisfies a comparison. Lowers to pto.comm.twait with cmp = #pto.wait_cmp<eq|ne|gt|ge|lt|le>.
Both ops take a 1-element INT32 pl.Tensor viewing a slot in the rank's HCCL window and an INT32 scalar (Python int, Scalar, or Expr).
Layers updated
- C++ op registrations (src/ir/op/tile_ops/cross_core.cpp)
- PTO codegen (src/backend/common/pto_ops_common.cpp)
- Python IR wrappers (python/pypto/ir/op/tile_ops.py)
- DSL wrappers (python/pypto/language/op/system_ops.py) re-exported via python/pypto/language/op/tile_ops.py
- UT (tests/ut/ir/operators/test_tile_ops.py, tests/ut/codegen/test_pto_codegen_ops.py)
- ST (tests/st/runtime/test_notify_wait.py)
- Docs (docs/en/dev/ir/05-operators.md, docs/zh-cn/dev/ir/05-operators.md)
Testing
- TestTileNotifyOp, TestTileWaitOp, TestTileNotifyPtoCodegen, TestTileWaitPtoCodegen (22/22).
- tests/st/runtime/test_notify_wait.py requires a PTOAS build that exposes the pto.comm.tnotify / pto.comm.twait custom ops; the currently deployed PTOAS does not have them yet, so on-board ST is staged to enable later.
Notes
tpush_*/tpop_*/tfree_*).