feat(distributed): DistributedTensorType + CommGroup IRNode + AOT comm-manifest sidecar #1297
📝 Walkthrough
Adds IR nodes WindowBuffer and CommGroup, extends TensorType with a window_name field, exposes DSL declarations and parser lifting/validation, implements manifest emission (orchestration/comm_manifest.json) during compile, and enables runtime manifest-driven bootstrap/config construction.
Changes: Distributed Communication Groups with Window Buffers
Code Review
This pull request introduces support for communication groups (CommGroup) and window buffers (WindowBuffer) across the IR, DSL, and runtime. It enables TensorType to carry a window_name for window-bound views, facilitating cross-rank operations. Key updates include C++ IR extensions, Python binding enhancements, AST parser logic for CommGroup declarations, and an AOT manifest system for runtime configuration. Feedback highlights a security risk in distributed_runner.py where predictable temporary filenames could lead to race conditions, suggesting the use of tempfile.mkstemp() for better robustness.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/pypto/language/distributed/comm_group.py`:
- Around line 58-65: The current validation allows non-string truthy names and
permits bool for nranks; update CommGroup input checks so self.name must be an
instance of str and not empty, and self.nranks must be either an int (but not a
bool) or a DynVar. Concretely, replace the truthy name check with an explicit
isinstance(self.name, str) and non-empty check, and change the nranks type check
to reject bool by ensuring nranks is an int and not a bool or a DynVar; keep the
existing positive-value validation for integer nranks (self.nranks > 0). This
touches the CommGroup initializer validations referencing self.name,
self.nranks, and DynVar.
In `@python/pypto/language/parser/decorator.py`:
- Around line 134-141: The loop currently re-evaluates class assignment RHS via
_eval_class_body_assign when collecting CommGroup instances, which can duplicate
side effects; instead, after assignments have been applied to the class
namespace (class_local_ns), obtain the bound value from the created class
object's dict (use c.__dict__[tgt_name] or the actual class object variable used
in this scope) for each target name and test isinstance(..., CommGroup) on that
retrieved value, removing the call to _eval_class_body_assign for the CommGroup
collection step.
- Around line 95-99: The cached DynVar IR vars are created with DataType.INT32
causing dtype mismatch with ints normalized as DataType.INDEX; change the
creation of cached in the dyn_var_cache path (the ir.Var instantiation named
cached) to use ir.ScalarType(DataType.INDEX) instead of INT32 so it matches the
behavior of _normalize_expr(value, ..., int_dtype=DataType.INDEX) and prevents
mixed-type sources for fields like nranks/size.
In `@python/pypto/pypto_core/ir.pyi`:
- Around line 492-510: The TensorType __init__ overloads in the .pyi currently
only include window_name forms with shape as Sequence[Expr] or Sequence[int];
add overloads that match the bound C++ constructors which accept memref and
tensor_view variants and also accept window_name as a keyword-only argument.
Specifically, add overload signatures for __init__(self, memref: "MemRefType",
*, window_name: str) -> None and __init__(self, tensor_view: "TensorViewType",
*, window_name: str) -> None (or equivalent names used in the bindings) so the
pyi matches the bound constructors (the C++-bound symbols referencing
memref/tensor_view and window_name), ensuring type-checkers see the same ctor
shapes as the runtime bindings.
In `@python/pypto/runtime/distributed_runner.py`:
- Around line 327-339: The entry wrapper orch_fn should validate that the entry
signature's inject_contexts agrees with the worker's bootstrap manifest before
accessing w.chip_contexts: add a pre-flight guard (in the scope where
inject_contexts is computed/used, e.g., around orch_fn) that checks for
mismatches using hasattr/getattr (e.g., hasattr(w, "chip_bootstrap_configs") or
getattr(w, "chip_bootstrap_configs", False)) and if inject_contexts !=
bool(getattr(w, "chip_bootstrap_configs", False)) raise a clear ValueError
indicating the mismatch between the entry function requiring "contexts" and the
missing/incorrect chip bootstrap comm config in the manifest; ensure you only
access w.chip_contexts after the check to avoid AttributeError.
In `@src/ir/transforms/python_printer.cpp`:
- Around line 401-403: The printer emits tensor_type->window_name_ directly into
a quoted Python literal which breaks if the string contains " or \; update the
Python printer (in python_printer.cpp, inside the block handling
tensor_type->window_name_) to escape backslashes and double quotes (and other
Python-meaningful characters like newlines) before writing the value to oss —
either call an existing helper (e.g., a string-escape or quote helper) or add a
small escape routine (e.g., replace '\' with '\\' and '"' with '\"', handle
'\n', '\t' etc.) and then emit the escaped result wrapped in quotes so the
produced Python string is syntactically correct and round-trippable.
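The input checks requested above for `comm_group.py` can be sketched roughly as follows. This is a minimal illustration, not the PR's actual class: `CommGroupSketch` and the `DynVar` stand-in are hypothetical names, and raising `ValueError` (rather than `TypeError`) for an empty name follows the later follow-up in this thread.

```python
from dataclasses import dataclass
from typing import Union


class DynVar:
    """Hypothetical stand-in for the DSL's dynamic variable type."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass(frozen=True)
class CommGroupSketch:
    """Illustrative only; not the PR's actual CommGroup class."""

    name: str
    nranks: Union[int, DynVar]

    def __post_init__(self) -> None:
        # A wrong type and a wrong value are reported with different exceptions.
        if not isinstance(self.name, str):
            raise TypeError(f"name must be str, got {type(self.name).__name__}")
        if not self.name:
            raise ValueError("name must be a non-empty string")
        if isinstance(self.nranks, DynVar):
            return  # Dynamic rank counts are resolved later.
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(self.nranks, bool) or not isinstance(self.nranks, int):
            raise TypeError("nranks must be an int (not bool) or a DynVar")
        if self.nranks <= 0:
            raise ValueError(f"nranks must be positive, got {self.nranks}")
```

The explicit `bool` check matters because `isinstance(True, int)` is true in Python, so a plain `isinstance(nranks, int)` test would accept `True` as a rank count of 1.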
📒 Files selected for processing (33)
- docs/en/dev/ir/02-types.md
- docs/zh-cn/dev/ir/02-types.md
- include/pypto/ir/core.h
- include/pypto/ir/kind_traits.h
- include/pypto/ir/program.h
- include/pypto/ir/serialization/type_registry.h
- include/pypto/ir/type.h
- python/bindings/modules/ir.cpp
- python/pypto/ir/comm_manifest.py
- python/pypto/ir/compile.py
- python/pypto/ir/type.py
- python/pypto/language/distributed/__init__.py
- python/pypto/language/distributed/comm_group.py
- python/pypto/language/distributed/window_buffer.py
- python/pypto/language/parser/ast_parser.py
- python/pypto/language/parser/decorator.py
- python/pypto/language/parser/type_resolver.py
- python/pypto/language/typing/tensor.py
- python/pypto/pypto_core/ir.pyi
- python/pypto/runtime/distributed_runner.py
- src/ir/program.cpp
- src/ir/serialization/deserializer.cpp
- src/ir/serialization/serializer.cpp
- src/ir/serialization/type_deserializers.cpp
- src/ir/transforms/python_printer.cpp
- src/ir/transforms/structural_equal.cpp
- src/ir/transforms/structural_hash.cpp
- tests/ut/ir/core/test_tensor_type_window_name.py
- tests/ut/ir/parser/test_comm_group_program.py
- tests/ut/ir/parser/test_window_name_annotation.py
- tests/ut/language/test_comm_group_dsl.py
- tests/ut/language/test_window_buffer.py
- tests/ut/runtime/test_chip_bootstrap_configs.py
- CI: skip ChipBootstrapConfig build tests when simpler is unavailable (importorskip, matching test_worker_reuse.py convention).
- distributed_runner: use tempfile.mkstemp for rootinfo_path so the file is unique and atomically created (no PID-collision race).
- distributed_runner: pre-flight guard when entry_fn expects ``contexts=`` but the comm manifest is missing; surface a clear error instead of letting w.chip_contexts AttributeError leak.
- CommGroup/WindowBuffer: reject bool values for nranks/size (bool is an int subclass) and require name to be str, not just truthy.
- decorator: read CommGroup attributes from the already-constructed class via vars(c) instead of re-eval'ing class-body AST. Avoids duplicating side effects and accepts CommGroup under any attribute name.
- decorator: lift DynVar to ir.Var with INDEX dtype so it matches the ConstInt path (parser already uses INDEX for literal sizes).
- ir.pyi: add TensorType overloads with memref/tensor_view + window_name to mirror the bound C++ ctors so type checkers don't drift from runtime.
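The `tempfile.mkstemp` change mentioned in this commit can be illustrated with a small sketch. The helper name `make_rootinfo_path` and the `rootinfo_` / `.bin` naming are assumptions for illustration; only `tempfile.mkstemp` and `os.close` are real standard-library calls.

```python
import os
import tempfile


def make_rootinfo_path(prefix: str = "rootinfo_") -> str:
    """Create a unique, empty file and return its path (hypothetical helper)."""
    # mkstemp creates the file exclusively, so two concurrent callers can never
    # be handed the same path, unlike a name derived from the process ID.
    fd, path = tempfile.mkstemp(prefix=prefix, suffix=".bin")
    os.close(fd)  # Only the path is needed; the file itself stays on disk.
    return path
```

The caller remains responsible for deleting the file (for example with `os.unlink`) once the run is finished, since `mkstemp` does not clean up after itself.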
f5451a8 to fa0fbc5
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/ir/transforms/python_printer.cpp (1)
401-403: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
`window_name_` still not escaped when emitting the Python string literal.
The fix from the previous review round was not applied. `window_name_` is still written raw into a double-quoted literal; a name containing `"` or `\` will produce invalid or semantically different output on reparse. `std::quoted` (already included via `<iomanip>`) is the correct fix, consistent with its use at lines 1034 and 1046.
🐛 Proposed fix
- if (!tensor_type->window_name_.empty()) {
-   oss << ", \"" << tensor_type->window_name_ << "\"";
- }
+ if (!tensor_type->window_name_.empty()) {
+   oss << ", " << std::quoted(tensor_type->window_name_);
+ }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/ir/transforms/python_printer.cpp` around lines 401 - 403, The code emits tensor_type->window_name_ raw into a double-quoted Python literal which breaks on quotes/backslashes; update the emission site (where tensor_type->window_name_ is written) to use std::quoted to escape the string (consistent with other uses at lines where std::quoted is used), i.e. replace the raw insertion of tensor_type->window_name_ with insertion using std::quoted(tensor_type->window_name_) so the generated Python string is properly escaped.
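To see why the raw emission is a problem, here is a small Python-side sketch of the round-trip property the printer is expected to satisfy. The helper names are hypothetical, and `quote_escaped` only mimics what `std::quoted` does by default (escaping the backslash and the double quote); it is not the printer's actual code.

```python
import ast


def quote_naive(text: str) -> str:
    """Wrap text in double quotes without escaping (the behavior being flagged)."""
    return '"' + text + '"'


def quote_escaped(text: str) -> str:
    """Escape backslashes first, then double quotes, then wrap in quotes."""
    escaped = text.replace("\\", "\\\\").replace('"', '\\"')
    return '"' + escaped + '"'


def round_trips(literal: str, expected: str) -> bool:
    """Return True if the literal parses as a Python string equal to expected."""
    try:
        return ast.literal_eval(literal) == expected
    except (SyntaxError, ValueError):
        return False
```

With a name such as `win"0`, the naive form fails to parse, and with a name containing a backslash followed by `t` it parses to a different string (a tab). Escaping only quotes and backslashes still leaves names with embedded newlines unhandled, which may or may not matter depending on what names the IR allows.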
🧹 Nitpick comments (1)
tests/ut/ir/parser/test_comm_group_program.py (1)
126-127: ⚡ Quick win
Use specific exception types in negative tests
`pytest.raises(Exception, ...)` is too broad and can let unrelated failures pass as expected behavior. Please assert the concrete parser/runtime exception type for each case.
Also applies to: 174-175, 187-188
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/ut/ir/parser/test_comm_group_program.py` around lines 126 - 127, Replace the broad pytest.raises(Exception, ...) in test_program_window_name_validation_rejects_unknown_name with the concrete parser/runtime exception class raised by the parser (for example ParserError or ValidationError): import that exception at the top of the test file and use pytest.raises(TheConcreteException, match="not declared in any CommGroup") so the test only passes for the intended failure; make the same replacement for the two other negative tests in this file that currently use pytest.raises(Exception, ...).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/pypto/language/distributed/comm_group.py`:
- Around line 57-59: The __post_init__ in CommGroup currently raises TypeError
for both a non-str and an empty str; change the validation to first check
isinstance(self.name, str) and raise TypeError if not, then separately check if
self.name is empty and raise ValueError for an empty string so callers catching
ValueError can handle invalid values correctly.
✅ Files skipped from review due to trivial changes (4)
- include/pypto/ir/serialization/type_registry.h
- docs/zh-cn/dev/ir/02-types.md
- python/bindings/modules/ir.cpp
- python/pypto/pypto_core/ir.pyi
🚧 Files skipped from review as they are similar to previous changes (20)
- python/pypto/ir/compile.py
- src/ir/transforms/structural_hash.cpp
- src/ir/serialization/serializer.cpp
- include/pypto/ir/kind_traits.h
- python/pypto/language/distributed/__init__.py
- docs/en/dev/ir/02-types.md
- src/ir/serialization/deserializer.cpp
- src/ir/transforms/structural_equal.cpp
- tests/ut/language/test_window_buffer.py
- python/pypto/language/parser/type_resolver.py
- python/pypto/language/parser/decorator.py
- src/ir/program.cpp
- python/pypto/ir/type.py
- src/ir/serialization/type_deserializers.cpp
- tests/ut/ir/core/test_tensor_type_window_name.py
- python/pypto/language/distributed/window_buffer.py
- python/pypto/runtime/distributed_runner.py
- python/pypto/language/typing/tensor.py
- tests/ut/language/test_comm_group_dsl.py
- include/pypto/ir/program.h
Empty-string name is a value problem, not a type problem — raise ValueError as before (and as the existing tests expected). Non-str name still raises TypeError. Per CodeRabbit feedback on hw-native-sys#1297.
be26c5f to e48bf10
…CommGroup metadata (N1)
Per docs/_build/distributed/l3_distributed_implementation_plan.md milestone N1.
Self-contained baseline for the new L3 distributed DSL: cleans up the previous
window_name + class-attribute CommGroup approach and lays the IR + manifest
foundation for the alloc-op-driven design (N2+).
Surface changes
---------------
- Drop ``TensorType.window_name`` field (C++/bindings/stub/serializer/printer/
structural-eq) and the matching DSL third-slot string syntax.
- Drop user-declared ``pld.CommGroup`` / ``pld.WindowBuffer`` Python dataclasses
and the ``_lift_comm_group`` / ``_collect_comm_groups`` parser path. CommGroup
metadata is now pass-inferred (``CollectCommGroups``, N4).
- Add ``ir.DistributedTensorType``: a precise-ObjectKind subclass of
``TensorType`` for cross-rank op verifiers (N6) to dispatch on. Inherits
TensorType fields; serializer/structural-eq/python-printer share the
TensorType comparison code via static_cast.
- Add ``pld.DistributedTensor[[shape], dtype]`` DSL annotation; parser routes
the annotation to ``ir.DistributedTensorType``.
IR schema reshape
-----------------
- ``ir.WindowBuffer``: keeps ``size: ExprPtr`` (allocation = address space, not
shape) and ``load_from_host`` / ``store_to_host`` as bool flags. The actual
host tensor binding lives on the alloc op (N2), not on this allocation spec.
- ``ir.CommGroup``: replaces ``name`` / ``nranks`` with
``devices: vector<int64_t>`` (empty list = all devices) and renames
``buffers`` to ``slots``.
Manifest v2
-----------
- ``COMM_MANIFEST_VERSION = 2``. Schema:
``{"devices": [...], "slots": [{name, dtype, size, bits_per_element,
load_from_host, store_to_host}, ...]}``.
- ``_build_chip_bootstrap_configs_from_manifest`` consumes the new schema:
empty ``devices`` ⇒ all entries of ``DistributedConfig.device_ids`` get a
comm config; explicit list ⇒ those device-ids covered, the rest stay
comm-less ``ChipBootstrapConfig()``.
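
The device-resolution rule described above can be sketched in a few lines. This is an illustration under assumptions, not the runner's code: the function name ``resolve_slots_per_device`` is hypothetical, the example manifest values (slot name, ``"FP32"``, sizes) are made up, and plain dicts stand in for ``ChipBootstrapConfig``.

```python
import json
from typing import Dict, List, Optional


def resolve_slots_per_device(
    manifest: dict, device_ids: List[int]
) -> Dict[int, Optional[List[dict]]]:
    """Map each device id to its slot list, or None if it stays comm-less."""
    devices = manifest["devices"]
    slots = manifest["slots"]
    # An empty device list means "all devices"; otherwise only the listed ids.
    covered = set(devices) if devices else set(device_ids)
    return {dev: (slots if dev in covered else None) for dev in device_ids}


# Made-up manifest in the v2 shape: {"devices": [...], "slots": [...]}.
example_manifest = json.loads(
    """
    {
      "devices": [0, 2],
      "slots": [
        {"name": "win0", "dtype": "FP32", "size": 1024, "bits_per_element": 32,
         "load_from_host": false, "store_to_host": true}
      ]
    }
    """
)
```

Calling ``resolve_slots_per_device(example_manifest, [0, 1, 2, 3])`` would attach the slot list to devices 0 and 2 and leave 1 and 3 without a comm config; with ``"devices": []`` every configured device is covered.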
Tests
-----
- New: ``tests/ut/ir/core/test_distributed_tensor_type.py``,
``tests/ut/ir/core/test_comm_group_schema.py``,
``tests/ut/ir/parser/test_distributed_tensor_annotation.py``.
- Rewritten: ``tests/ut/runtime/test_chip_bootstrap_configs.py`` for v2 schema.
- Removed: legacy window_name / pld.CommGroup / pld.WindowBuffer test files.
- 4432 unit tests pass; comm-less ``test_l3_distributed.py`` /
``test_l3_parallel_reduce.py`` system tests unchanged.
e48bf10 to e5f6490
Summary
L3 distributed DSL milestone N1: IR + DSL + runtime infrastructure that later milestones (alloc op / dispatch `device=` / cross-rank ops / `CollectCommGroups` pass) build on. Single squashed commit; review-friendly diff against `upstream/main`.
Three layers (one squashed commit, the layers map cleanly to file groups):
- `ir.DistributedTensorType` IR subclass + `pld.DistributedTensor` DSL. Precise-`ObjectKind` subclass of `TensorType` so cross-rank op verifiers (later milestone) can reject plain `Tensor` arguments via `As<DistributedTensorType>`. DSL surface mirrors `pl.Tensor` (same 2/3/4-slot subscript: shape, dtype, [layout|memref|view], [memref]); only the IR ObjectKind differs. Reflection / serializer / structural-eq / printer share TensorType comparison via `static_pointer_cast`; printer head differs (`pld.DistributedTensor` vs `pl.Tensor`).
- `ir.WindowBuffer` / `ir.CommGroup` IRNodes + `Program.comm_groups_` field.
  - `WindowBuffer{name, size: ExprPtr, dtype, load_from_host: bool, store_to_host: bool}`: allocation-spec only, 1:1 to `simpler.task_interface.ChipBufferSpec`.
  - `CommGroup{devices: vector<int64_t>, slots: vector<WindowBufferPtr>}`: empty `devices` ⇒ "all devices" (resolved by driver against `DistributedConfig.device_ids`); non-empty ⇒ explicit subset.
  - `Program` gains a `(functions, comm_groups, name, span)` ctor overload, registered as `UsualField` (participates in structural eq / hash). Reflection visitors get a new `VisitLeafField(const std::vector<int64_t>&)` overload to handle the device list.
- AOT comm-manifest sidecar. `python/pypto/ir/comm_manifest.py` lifts `program.comm_groups` into `output_dir/orchestration/comm_manifest.json` (v2 schema: `{"devices": [...], "slots": [...]}`). The runner reads the file at submit time and constructs `ChipBootstrapConfig` directly; no `Program` object required. Comm-less programs skip the manifest entirely so `multi_chip_dispatch` / `parallel_reduce` are unaffected.
Scope (NOT in this PR)
- `pld.alloc_window_buffer` op + parser (next milestone)
- `pld.world_size()` host-only op + dispatch `device=r` kwarg
- `CollectCommGroups` pass (the producer of `program.comm_groups` from alloc ops)
- `pld.tile.remote_load` / `pld.system.notify` / `pld.system.wait` ops + verifier
- `pld.CommCtx` / `dist_t.comm.rank` desugar
- `AllReduceProgram` system test
Test Plan
- `tests/ut/ir/core/test_distributed_tensor_type.py`, `tests/ut/ir/core/test_comm_group_schema.py`, `tests/ut/ir/parser/test_distributed_tensor_annotation.py`, `tests/ut/runtime/test_chip_bootstrap_configs.py`: cover `As<>` precise-match semantics, structural-eq, parser dispatch (`pl.Tensor` vs `pld.DistributedTensor` annotations including 3-slot layout/memref/view), manifest lift/build/AOT-roundtrip + sub-byte dtype byte calculation.
- `test_l3_distributed.py` / `test_l3_parallel_reduce.py` (comm-less paths) unchanged.