diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b2d3dfff..509214f82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,7 @@ set(PYPTO_SOURCES src/ir/op/sync_ops/sync.cpp src/ir/op/sync_ops/cross_core.cpp src/ir/op/sync_ops/task.cpp + src/ir/op/distributed/memory.cpp src/ir/op/tensor_ops/broadcast.cpp src/ir/op/tensor_ops/elementwise.cpp src/ir/op/tensor_ops/matmul.cpp diff --git a/docs/en/dev/ir/02-types.md b/docs/en/dev/ir/02-types.md index d32276c36..0c4f4e2b1 100644 --- a/docs/en/dev/ir/02-types.md +++ b/docs/en/dev/ir/02-types.md @@ -69,9 +69,13 @@ assert isinstance(t, ir.TensorType) # C++ inheritance preserved # As(t) → null; As(t) → cast. ``` -Allocation-side metadata (the buffer name, host-staging flags) lives on the -allocation op (`pld.alloc_window_buffer`, added in a later milestone) and on -the `ir.WindowBuffer` slot in `program.comm_groups`, not on the type itself. +Allocation-side metadata (per-rank size, host-staging flags) lives on the +`ir.WindowBuffer` `Var` subclass that the `pld.alloc_window_buffer` op binds. +Slices materialised through `pld.window(buf, [shape], dtype=...)` carry an +optional back-reference (`DistributedTensorType.window_buffer`) to the source +`WindowBuffer`, so two same-shape / same-dtype slices of different +allocations stay structurally distinct. User-declared parameter annotations +like `pld.DistributedTensor[[shape], dtype]` leave this field as `None`. Tile types do not have a distributed variant; cross-rank ops always operate on `DistributedTensor`. diff --git a/docs/zh-cn/dev/ir/02-types.md b/docs/zh-cn/dev/ir/02-types.md index 7f557c863..262a2b7b0 100644 --- a/docs/zh-cn/dev/ir/02-types.md +++ b/docs/zh-cn/dev/ir/02-types.md @@ -68,10 +68,14 @@ assert isinstance(t, ir.TensorType) # C++ 继承关系保留 # As(t) → null;As(t) → 转型成功 ``` -分配侧的元数据(buffer 名字、host staging 标志)挂在 alloc op -(`pld.alloc_window_buffer`,后续 milestone 引入)和 `program.comm_groups` -中的 `ir.WindowBuffer` slot 上,**不**在类型本身。Tile 类型没有 distributed -变体;跨 rank op 始终作用在 `DistributedTensor` 上。 +分配侧的元数据(每 rank 大小、host staging 标志)挂在 `pld.alloc_window_buffer` +op 所绑定的 `ir.WindowBuffer`(`Var` 子类)上。通过 +`pld.window(buf, [shape], dtype=...)` 物化的切片在 +`DistributedTensorType.window_buffer` 上保留指向源 `WindowBuffer` 的可选反向 +引用,从而让两个 shape/dtype 相同但分配来源不同的切片在结构上保持不同。 +用户在签名中写的 `pld.DistributedTensor[[shape], dtype]` 不填该字段(为 +`None`)。Tile 类型没有 distributed 变体;跨 rank op 始终作用在 +`DistributedTensor` 上。 ### 带 TensorView 的 TensorType diff --git a/include/pypto/ir/core.h b/include/pypto/ir/core.h index 70acfdf8c..44c0797dd 100644 --- a/include/pypto/ir/core.h +++ b/include/pypto/ir/core.h @@ -113,6 +113,7 @@ enum class ObjectKind { TileType, ArrayType, TupleType, + WindowBufferType, // Other IR node kinds Function, diff --git a/include/pypto/ir/kind_traits.h b/include/pypto/ir/kind_traits.h index 078f1093c..9aae5fcdf 100644 --- a/include/pypto/ir/kind_traits.h +++ b/include/pypto/ir/kind_traits.h @@ -114,6 +114,7 @@ DEFINE_KIND_TRAIT(ArrayType, ObjectKind::ArrayType) DEFINE_KIND_TRAIT(TupleType, ObjectKind::TupleType) DEFINE_KIND_TRAIT(MemRefType, ObjectKind::MemRefType) DEFINE_KIND_TRAIT(PtrType, ObjectKind::PtrType) +DEFINE_KIND_TRAIT(WindowBufferType, ObjectKind::WindowBufferType) // Other IR node types DEFINE_KIND_TRAIT(Function, ObjectKind::Function) @@ -203,7 +204,8 @@ struct KindTrait { ObjectKind::DistributedTensorType, ObjectKind::TileType, ObjectKind::ArrayType, - ObjectKind::TupleType}; + ObjectKind::TupleType, + ObjectKind::WindowBufferType}; static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind); }; @@ -278,7 +280,10 @@ std::shared_ptr As(const std::shared_ptr& base) { * * As() uses exact ObjectKind matching and won't match IterArg. * This utility matches both Var and IterArg (which inherits from Var). - * MemRef is intentionally excluded — use As() for that. + * MemRef and WindowBuffer are intentionally excluded — they are Var + * subclasses that carry allocation-source / window-slot semantics rather + * than the plain bound-name model AsVarLike's callers assume. Use + * As() / As() when you specifically want them. */ inline VarPtr AsVarLike(const ExprPtr& expr) { if (!expr) return nullptr; diff --git a/include/pypto/ir/program.h b/include/pypto/ir/program.h index e2bafca68..9141be860 100644 --- a/include/pypto/ir/program.h +++ b/include/pypto/ir/program.h @@ -19,7 +19,6 @@ #include #include -#include "pypto/core/dtype.h" #include "pypto/ir/core.h" #include "pypto/ir/expr.h" #include "pypto/ir/function.h" @@ -30,39 +29,34 @@ namespace pypto { namespace ir { /** - * @brief Per-rank allocation spec for one named CommGroup HCCL window buffer. + * @brief Per-rank CommGroup HCCL window-buffer allocation, modelled as a Var. * - * Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. Pure - * allocation metadata: does NOT describe how the buffer is used in code. - * Code-level use is expressed in the function signature via - * ``pld.DistributedTensor[[shape], dtype]``; the alloc op - * (``pld.alloc_window_buffer``, added in N2) materialises one of these slots - * into the program's :class:`CommGroup`. + * A specialised :class:`Var` subclass whose SSA-edge type is the singleton + * :class:`WindowBufferType`. The buffer's runtime-unique identifier flows + * through the inherited ``Var::name_hint_``; there is no separate ``name_`` + * field so structural equality does not depend on the chosen variable name. * - * ``size_`` is the **element count** of one rank's slice (a single scalar; this - * struct is allocation-only and intentionally does not carry a multi-dim - * shape). ``size_`` may be a ``ConstInt`` (compile-time known) or a symbolic - * expression referring to the world size. - * - * ``load_from_host_`` / ``store_to_host_`` are simple boolean flags marking - * whether the slot participates in pre-fork H2D / post-task D2H staging. The - * specific host tensor that supplies / receives the staged data is recorded - * on the alloc op, not on this allocation spec. + * Fields: + * * ``base_`` — :class:`Var` holding the underlying ``Ptr`` allocation + * identity. Multiple ``WindowBuffer`` instances built from the same alloc + * Var share allocation identity through this field. + * * ``size_`` — per-rank allocation size in **bytes**; ``ConstInt`` or + * symbolic :class:`ExprPtr`. + * * ``load_from_host_`` / ``store_to_host_`` — pre-fork H2D / post-task + * D2H staging flags. */ -class WindowBuffer : public IRNode { +class WindowBuffer : public Var { public: - std::string name_; ///< Buffer name (parser-extracted from alloc-op LHS) - ExprPtr size_; ///< Per-rank element count (ConstInt or symbolic Expr) - DataType dtype_; ///< Element data type - bool load_from_host_ = false; ///< Pre-fork H2D copy from a host staging tensor - bool store_to_host_ = false; ///< Post-task D2H copy back into a host staging tensor - - WindowBuffer(std::string name, ExprPtr size, DataType dtype, bool load_from_host = false, - bool store_to_host = false, Span span = Span::unknown()) - : IRNode(std::move(span)), - name_(std::move(name)), + VarPtr base_; ///< Ptr Var from the alloc op (allocation identity) + ExprPtr size_; ///< Per-rank allocation size in bytes + bool load_from_host_ = false; ///< Pre-fork H2D staging flag + bool store_to_host_ = false; ///< Post-task D2H staging flag + + WindowBuffer(VarPtr base, ExprPtr size, bool load_from_host = false, bool store_to_host = false, + Span span = Span::unknown()) + : Var(base->name_hint_, GetWindowBufferType(), std::move(span)), + base_(std::move(base)), size_(std::move(size)), - dtype_(dtype), load_from_host_(load_from_host), store_to_host_(store_to_host) {} @@ -71,15 +65,17 @@ class WindowBuffer : public IRNode { static constexpr auto GetFieldDescriptors() { return std::tuple_cat( - IRNode::GetFieldDescriptors(), - std::make_tuple(reflection::UsualField(&WindowBuffer::name_, "name"), + Var::GetFieldDescriptors(), + std::make_tuple(reflection::UsualField(&WindowBuffer::base_, "base"), reflection::UsualField(&WindowBuffer::size_, "size"), - reflection::UsualField(&WindowBuffer::dtype_, "dtype"), reflection::UsualField(&WindowBuffer::load_from_host_, "load_from_host"), reflection::UsualField(&WindowBuffer::store_to_host_, "store_to_host"))); } }; +// WindowBufferPtr is forward-declared in include/pypto/ir/type.h so that +// DistributedTensorType::window_buffer_ can hold it without a circular +// include. using WindowBufferPtr = std::shared_ptr; /** diff --git a/include/pypto/ir/type.h b/include/pypto/ir/type.h index 248bff2a5..ed80fdb58 100644 --- a/include/pypto/ir/type.h +++ b/include/pypto/ir/type.h @@ -36,6 +36,9 @@ using ExprPtr = std::shared_ptr; class MemRef; using MemRefPtr = std::shared_ptr; +class WindowBuffer; +using WindowBufferPtr = std::shared_ptr; + /** * @brief Base class for type representations in the IR * @@ -500,32 +503,60 @@ using TensorTypePtr = std::shared_ptr; */ class DistributedTensorType : public TensorType { public: - DistributedTensorType(std::vector shape, DataType dtype) : TensorType(std::move(shape), dtype) {} + /// Optional back-reference to the :class:`WindowBuffer` whose allocation this + /// tensor is a view of. Populated by ``pld.window``'s type deducer; + /// ``std::nullopt`` for user-declared parameter annotations like + /// ``pld.DistributedTensor[[shape], dtype]``. Two DistributedTensorTypes with + /// the same shape / dtype but different ``window_buffer_`` values are + /// structurally distinct, so passes can tell apart slices of different + /// CommGroup window buffers. + std::optional window_buffer_; + + DistributedTensorType(std::vector shape, DataType dtype) + : TensorType(std::move(shape), dtype), window_buffer_(std::nullopt) {} DistributedTensorType(std::vector shape, DataType dtype, MemRefPtr memref) - : TensorType(std::move(shape), dtype, std::move(memref)) {} + : TensorType(std::move(shape), dtype, std::move(memref)), window_buffer_(std::nullopt) {} DistributedTensorType(std::vector shape, DataType dtype, std::optional memref) - : TensorType(std::move(shape), dtype, std::move(memref)) {} + : TensorType(std::move(shape), dtype, std::move(memref)), window_buffer_(std::nullopt) {} DistributedTensorType(std::vector shape, DataType dtype, std::optional memref, std::optional tensor_view) - : TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)) {} + : TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)), + window_buffer_(std::nullopt) {} DistributedTensorType(const std::vector& shape, DataType dtype) - : TensorType(shape, dtype, std::nullopt) {} + : TensorType(shape, dtype, std::nullopt), window_buffer_(std::nullopt) {} DistributedTensorType(const std::vector& shape, DataType dtype, std::optional memref) - : TensorType(shape, dtype, std::move(memref)) {} + : TensorType(shape, dtype, std::move(memref)), window_buffer_(std::nullopt) {} DistributedTensorType(const std::vector& shape, DataType dtype, std::optional memref, std::optional tensor_view) - : TensorType(shape, dtype, std::move(memref), std::move(tensor_view)) {} + : TensorType(shape, dtype, std::move(memref), std::move(tensor_view)), window_buffer_(std::nullopt) {} + + /// Construct a DistributedTensorType produced by ``pld.window``: the result + /// is paired with the originating :class:`WindowBuffer` so passes can recover + /// the comm-group / slot identity later. + DistributedTensorType(std::vector shape, DataType dtype, WindowBufferPtr window_buffer) + : TensorType(std::move(shape), dtype), window_buffer_(std::move(window_buffer)) {} + + /// Full-fields constructor used by deserialization to faithfully restore + /// every optional field (memref, tensor_view, window_buffer) in one shot. + DistributedTensorType(std::vector shape, DataType dtype, std::optional memref, + std::optional tensor_view, std::optional window_buffer) + : TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)), + window_buffer_(std::move(window_buffer)) {} [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::DistributedTensorType; } [[nodiscard]] std::string TypeName() const override { return "DistributedTensorType"; } - static constexpr auto GetFieldDescriptors() { return TensorType::GetFieldDescriptors(); } + static constexpr auto GetFieldDescriptors() { + return std::tuple_cat( + TensorType::GetFieldDescriptors(), + std::make_tuple(reflection::UsualField(&DistributedTensorType::window_buffer_, "window_buffer"))); + } }; using DistributedTensorTypePtr = std::shared_ptr; @@ -711,6 +742,32 @@ inline PtrTypePtr GetPtrType() { return ptr_type; } +/** + * @brief Singleton marker type for ``pld.alloc_window_buffer`` results. + * + * Carries no per-instance fields; all allocation metadata (size, host-staging + * flags, etc.) lives on the :class:`WindowBuffer` Var subclass that the alloc + * op binds. Cross-rank op verifiers dispatch on this marker + * (``As``) to reject non-window arguments. + */ +class WindowBufferType : public Type { + public: + WindowBufferType() = default; + + [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::WindowBufferType; } + [[nodiscard]] std::string TypeName() const override { return "WindowBufferType"; } + + static constexpr auto GetFieldDescriptors() { return Type::GetFieldDescriptors(); } +}; + +using WindowBufferTypePtr = std::shared_ptr; + +/// Get the shared singleton WindowBufferType instance. +inline WindowBufferTypePtr GetWindowBufferType() { + static const auto window_buffer_type = std::make_shared(); + return window_buffer_type; +} + } // namespace ir } // namespace pypto diff --git a/python/bindings/modules/ir.cpp b/python/bindings/modules/ir.cpp index 80bbd58c0..1dbd623f6 100644 --- a/python/bindings/modules/ir.cpp +++ b/python/bindings/modules/ir.cpp @@ -454,6 +454,11 @@ void BindIR(nb::module_& m) { nb::init&, DataType, std::optional, std::optional>(), nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), nb::arg("tensor_view") = nb::none(), "Create a distributed tensor type with constant shape, optional memref and tensor_view"); + dist_tensor_type_class.def( + nb::init, DataType, WindowBufferPtr>(), nb::arg("shape"), nb::arg("dtype"), + nb::arg("window_buffer"), + "Create a distributed tensor type produced by pld.window; window_buffer is the back-" + "reference to the source WindowBuffer allocation."); BindFields(dist_tensor_type_class); // TileType - const shared_ptr @@ -507,6 +512,18 @@ void BindIR(nb::module_& m) { ptr_type_class.def_static("get", &GetPtrType, "Get the singleton PtrType instance"); BindFields(ptr_type_class); + // WindowBufferType - singleton marker type for pld.alloc_window_buffer outputs. + // Mirrors MemRefType: no per-instance fields. The WindowBuffer Var subclass + // (bindings below) carries the allocation metadata. + auto window_buffer_type_class = nb::class_( + ir, "WindowBufferType", + "Singleton marker type for pld.alloc_window_buffer outputs. The companion WindowBuffer " + "Var subclass carries the allocation metadata (name, size, dtype, host flags)."); + window_buffer_type_class.def(nb::init<>(), "Create the singleton WindowBufferType instance."); + window_buffer_type_class.def_static("get", &GetWindowBufferType, + "Get the shared singleton WindowBufferType instance."); + BindFields(window_buffer_type_class); + // MemorySpace enum nb::enum_(ir, "MemorySpace", "Memory space enumeration") .value("DDR", MemorySpace::DDR, "DDR memory (off-chip)") @@ -1487,18 +1504,23 @@ void BindIR(nb::module_& m) { "List of CommGroups declared on the program. CommGroups participate " "in structural equality / hashing through reflection."); - // CommGroup / WindowBuffer — IRNode-typed host-side metadata attached to a Program. - auto window_buffer_class = nb::class_( + // WindowBuffer — a specialised Var subclass that carries CommGroup window-buffer + // allocation metadata. Mirrors MemRef's Var-subclass shape; the inherited + // ``name_hint`` is mirrored from ``name_`` (UsualField, unique-id role). + auto window_buffer_class = nb::class_( ir, "WindowBuffer", - "Per-rank allocation spec for a named CommGroup HCCL window buffer. " - "Maps 1:1 to simpler.task_interface.ChipBufferSpec at submit-time. " - "size_ is the per-rank element count; load_from_host_ / store_to_host_ " - "are bool flags marking host-staging participation (the actual host " - "tensor binding is recorded on the alloc op, not here)."); - window_buffer_class.def(nb::init(), nb::arg("name"), - nb::arg("size"), nb::arg("dtype"), nb::arg("load_from_host") = false, - nb::arg("store_to_host") = false, nb::arg("span") = Span::unknown(), - "Create a WindowBuffer."); + "Per-rank CommGroup window-buffer allocation, modelled as a specialised Var. " + "Its SSA-edge type is the singleton WindowBufferType; the allocation metadata " + "(name, size, dtype, host-staging flags) lives on the Var subclass directly — " + "the exact mirror of how MemRef carries (base, byte_offset, size) under MemRefType. " + "Constructed by the comm-collection pass; the alloc op's LHS at parse time is a " + "plain Var(PtrType)."); + window_buffer_class.def(nb::init(), nb::arg("base"), nb::arg("size"), + nb::arg("load_from_host") = false, nb::arg("store_to_host") = false, + nb::arg("span") = Span::unknown(), + "Create a WindowBuffer wrapping the given Ptr Var. The buffer's " + "runtime-unique identifier flows through the inherited " + "Var.name_hint (taken from base.name_hint)."); BindFields(window_buffer_class); auto comm_group_class = diff --git a/python/pypto/ir/comm_manifest.py b/python/pypto/ir/comm_manifest.py index a5652fb85..4e9547d05 100644 --- a/python/pypto/ir/comm_manifest.py +++ b/python/pypto/ir/comm_manifest.py @@ -27,10 +27,13 @@ * ``devices``: list of physical device ids covered by the group; **empty list = all devices** (resolved by the driver against ``DistributedConfig.device_ids``). -* ``slots``: list of allocation specs, each a 1:1 image of - ``simpler.task_interface.ChipBufferSpec``. ``load_from_host`` / - ``store_to_host`` are bool flags here — the specific host tensor binding - lives on the alloc op, not on this allocation spec. +* ``slots``: list of allocation specs. Each slot carries only ``name`` (the + buffer's runtime-unique identifier, taken from the underlying Var's + ``name_hint``), ``nbytes`` (per-rank allocation size in bytes — exact mirror + of ``ChipBufferSpec.nbytes``), and the host-staging flags. + ``WindowBuffer`` is intentionally dtype-agnostic (matching how ``MemRef`` + carries no dtype): the runtime's ``ChipBufferSpec.dtype`` field is unused + and we pass an opaque placeholder. """ from __future__ import annotations @@ -39,7 +42,6 @@ from pathlib import Path from typing import Any -from pypto.pypto_core import DataType from pypto.pypto_core.ir import ConstInt, Program # Manifest filename under output_dir/orchestration/. Keep in sync with the @@ -48,33 +50,6 @@ COMM_MANIFEST_VERSION = 2 -# Maps PyPTO DataType to the dtype-string convention simpler's ChipBufferSpec -# expects (numpy/torch style, e.g. "float32"). DataType.to_string() returns -# "fp32" / "bfloat16" which simpler does not understand; this table is the -# explicit single source of truth. -_SIMPLER_DTYPE_STR: dict[DataType, str] = { - DataType.FP32: "float32", - DataType.FP16: "float16", - DataType.BF16: "bfloat16", - DataType.INT8: "int8", - DataType.INT16: "int16", - DataType.INT32: "int32", - DataType.INT64: "int64", - DataType.UINT8: "uint8", -} - - -def _simpler_dtype_str(dtype: DataType) -> str: - try: - return _SIMPLER_DTYPE_STR[dtype] - except KeyError as exc: - raise RuntimeError( - f"Unsupported WindowBuffer dtype {dtype!r} for ChipBufferSpec; " - f"add an entry to _SIMPLER_DTYPE_STR. " - f"Known dtypes: {sorted(d.to_string() for d in _SIMPLER_DTYPE_STR)}" - ) from exc - - def lift_comm_manifest(program: Program) -> dict[str, Any] | None: """Lift ``program.comm_groups`` into a JSON-safe dict for AOT serialization. @@ -83,9 +58,9 @@ def lift_comm_manifest(program: Program) -> dict[str, Any] | None: path used by multi_chip_dispatch and parallel_reduce). Current v2 supports a single CommGroup with literal ``int`` per-slot - ``size``. Symbolic sizes raise ``RuntimeError`` at compile time — that's - a better failure point than runtime, since the program author can fix it - without redeploying. + ``size`` (in bytes). Symbolic sizes raise ``RuntimeError`` at compile + time — that's a better failure point than runtime, since the program + author can fix it without redeploying. """ comm_groups = list(program.comm_groups) if not comm_groups: @@ -103,14 +78,15 @@ def lift_comm_manifest(program: Program) -> dict[str, Any] | None: if size_const is None: raise RuntimeError( f"dynamic WindowBuffer size is not supported yet " - f"(slot {slot.name!r}); declare size as a literal int" + f"(slot {slot.name_hint!r}); declare size as a literal int (bytes)" ) slots_data.append( { - "name": slot.name, - "dtype": _simpler_dtype_str(slot.dtype), - "size": int(size_const.value), - "bits_per_element": slot.dtype.get_bit(), + # The buffer's runtime-unique identifier flows through the + # inherited Var.name_hint (no separate name field on + # WindowBuffer — mirrors MemRef). + "name": slot.name_hint, + "nbytes": int(size_const.value), "load_from_host": bool(slot.load_from_host), "store_to_host": bool(slot.store_to_host), } diff --git a/python/pypto/language/distributed/__init__.py b/python/pypto/language/distributed/__init__.py index 173f1f400..6dc526661 100644 --- a/python/pypto/language/distributed/__init__.py +++ b/python/pypto/language/distributed/__init__.py @@ -21,6 +21,7 @@ ``pld.tile.*``, ``pld.system.*``) are added in subsequent milestones (N1.2+). """ +from .alloc import alloc_window_buffer, window from .distributed_tensor import DistributedTensor -__all__ = ["DistributedTensor"] +__all__ = ["DistributedTensor", "alloc_window_buffer", "window"] diff --git a/python/pypto/language/distributed/alloc.py b/python/pypto/language/distributed/alloc.py new file mode 100644 index 000000000..7eb62744f --- /dev/null +++ b/python/pypto/language/distributed/alloc.py @@ -0,0 +1,109 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""``pld.alloc_window_buffer`` / ``pld.window`` — DSL sentinels for CommGroup windows. + +These functions are parser sentinels — calling them at Python runtime always +raises. They exist so that source code like:: + + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(256 * 4) # 256 FP32 elements + data = pld.window(buf, [256], dtype=pl.FP32) + ... + +is syntactically valid Python that the AST parser can intercept and lift into +``ir.OpExpr(pld.alloc_window_buffer)`` / ``ir.OpExpr(pld.window)`` IR nodes. + +The layout mirrors the ``tile.alloc`` / ``MemRef`` / ``TileType`` triple: + +* ``alloc_window_buffer`` is **pure address-space allocation** — it takes a + per-rank ``size`` in **bytes** (matching ``tile.alloc(memspace, size)``) + and returns the singleton :class:`ir.PtrType` (allocation-identity token). + At parse time the LHS is a plain ``Var(PtrType)``; the comm-collection + pass later wraps the Ptr in an :class:`ir.WindowBuffer` Var subclass and + registers it on the program's CommGroup metadata. +* ``window`` lifts that Ptr handle into a :class:`ir.DistributedTensorType` + view by specifying the per-rank ``shape`` and ``dtype`` at materialisation + time. The result type's ``window_buffer`` back-reference is filled in by + the same comm-collection pass. +""" + +from collections.abc import Sequence + +from pypto.language.typing import IntLike +from pypto.pypto_core import DataType +from pypto.pypto_core.ir import Var + +from .distributed_tensor import DistributedTensor + + +def alloc_window_buffer(size: IntLike) -> Var: # noqa: ARG001 + """Declare a per-rank CommGroup window-buffer slot of ``size`` bytes. + + Mirrors ``tile.alloc(memory_space, size)``: pure allocation semantics, + no shape / dtype concept on the buffer itself. The result is the + allocation-identity token that ``pld.window`` consumes. + + Args: + size: Per-rank allocation size in **bytes**. Accepts an ``int`` + literal, a DSL ``Scalar``, or a raw ``ir.Expr`` (e.g. a + symbolic expression containing ``pld.world_size()``). + + Returns: + A plain :class:`ir.Var` of type :class:`ir.PtrType` bound at the + assignment LHS (the buffer's runtime-unique name comes from the + LHS variable identifier, captured automatically by the parser). + Pass the result through :func:`window` to materialise a + :class:`DistributedTensor` view. + + Raises: + RuntimeError: Always — this function is a parser sentinel. The parser + intercepts the call before Python ever invokes the body. + """ + raise RuntimeError( + "pld.alloc_window_buffer must be called inside a @pl.function " + "(level=Level.HOST, role=Role.Orchestrator)" + ) + + +def window( + buf: Var, # noqa: ARG001 + shape: Sequence[IntLike], # noqa: ARG001 + *, + dtype: DataType, # noqa: ARG001 +) -> DistributedTensor: + """Materialise a window-buffer Ptr handle as a DistributedTensor view. + + Shape and dtype enter the type system here; the result type + (:class:`ir.DistributedTensorType`) carries an optional back-reference + to the source :class:`ir.WindowBuffer` that the comm-collection pass + fills in later. + + Args: + buf: A :class:`ir.Var` of type :class:`ir.PtrType` produced by + :func:`alloc_window_buffer`. + shape: Per-rank shape (list / tuple of ints, DSL ``Scalar``s, or raw + ``ir.Expr``s — anything :data:`IntLike` accepts). + dtype: Element data type. Kwarg-only. + + Returns: + A :class:`DistributedTensor` view (IR-level + :class:`ir.DistributedTensorType`) of the given shape and dtype that + represents the local rank's slice of the window. + + Raises: + RuntimeError: Always — this function is a parser sentinel. + """ + raise RuntimeError( + "pld.window must be called inside a @pl.function (level=Level.HOST, role=Role.Orchestrator)" + ) + + +__all__ = ["alloc_window_buffer", "window"] diff --git a/python/pypto/language/parser/ast_parser.py b/python/pypto/language/parser/ast_parser.py index 76764decf..844d91503 100644 --- a/python/pypto/language/parser/ast_parser.py +++ b/python/pypto/language/parser/ast_parser.py @@ -15,7 +15,7 @@ from collections.abc import Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeGuard, cast from pypto.ir import IRBuilder from pypto.ir import op as ir_op @@ -51,6 +51,25 @@ from .decorator import InlineFunction +def _is_pld_call(node: object, attr_name: str) -> TypeGuard[ast.Call]: + """Return True when ``node`` is the AST for a ``pld.(...)`` call. + + Recognises only the dotted ``pld.`` form. Aliasing under another + name (``from pypto.language.distributed import alloc_window_buffer``) is + intentionally not matched — the parser anchors distributed-DSL detection + on the ``pld.`` prefix. + """ + if not isinstance(node, ast.Call): + return False + func = node.func + return ( + isinstance(func, ast.Attribute) + and func.attr == attr_name + and isinstance(func.value, ast.Name) + and func.value.id == "pld" + ) + + def _is_const_int(value: object) -> bool: """Check if a value is a compile-time constant integer. @@ -252,6 +271,7 @@ def __init__( # noqa: PLR0913 buffer_name_meta: dict[tuple[str, str], dict[str, Any]] | None = None, dyn_var_cache: dict[str, ir.Var] | None = None, pending_comments: dict[int, list[tuple[int, str]]] | None = None, + alloc_window_buffer_names: set[str] | None = None, ): """Initialize AST parser. @@ -273,6 +293,10 @@ def __init__( # noqa: PLR0913 pending_comments: Map from 1-based line number to ``#``-stripped comment lines (produced by :func:`extract_line_comments`). Drained in source order and attached as ``leading_comments`` metadata to the stmt that follows. + alloc_window_buffer_names: Optional shared set of buffer names that have already + been declared by ``pld.alloc_window_buffer`` within this program. Multiple + functions in a ``@pl.program`` share this set so the name-uniqueness check + spans the whole program rather than a single function. """ self.span_tracker = SpanTracker(source_file, source_lines, line_offset, col_offset) self.scope_manager = ScopeManager(strict_ssa=strict_ssa) @@ -330,6 +354,13 @@ def __init__( # noqa: PLR0913 # tail-of-block comments (inside body indent) from outer-scope comments. self._pending_comments: dict[int, list[tuple[int, str]]] = pending_comments or {} + # Names declared by pld.alloc_window_buffer in this program. Globally unique + # across all functions in a single @pl.program (the decorator passes a shared + # set when constructing per-function parsers). + self._alloc_window_buffer_names: set[str] = ( + alloc_window_buffer_names if alloc_window_buffer_names is not None else set() + ) + # Cached arithmetic analyzer used to simplify symbolic slice extents at # construction time. One instance per parser amortises the sub-analyzer # setup across the many subscripts found in a typical function body. @@ -979,6 +1010,15 @@ def parse_assignment(self, stmt: ast.Assign) -> None: Args: stmt: Assign AST node """ + # Intercept ``buf = pld.alloc_window_buffer(...)`` before the generic + # path: the alloc op needs the LHS name as a kwarg, and we enforce a + # single-Name target + program-global name uniqueness here so the + # error surfaces at the original assignment site rather than from + # deep inside type deduction. + if len(stmt.targets) == 1 and _is_pld_call(stmt.value, "alloc_window_buffer"): + self._parse_alloc_window_buffer_assignment(stmt.targets[0], stmt.value) + return + # Handle tuple unpacking for yields if len(stmt.targets) == 1: target = stmt.targets[0] @@ -1063,6 +1103,74 @@ def parse_assignment(self, stmt: ast.Assign) -> None: hint="Use simple variable assignments or tuple unpacking with pl.yield_()", ) + def _parse_alloc_window_buffer_assignment(self, target: ast.expr, value: ast.Call) -> None: + """Parse ``buf = pld.alloc_window_buffer(size)``. + + Takes one positional ``size`` (per-rank byte count) and binds the LHS + to a plain ``Var(PtrType)``. The LHS must be a single ``ast.Name``, + the chosen name must be program-globally unique, and no user kwargs + are accepted (``name`` is parser-injected from the LHS). + """ + span = self.span_tracker.get_span(value) + + if not isinstance(target, ast.Name): + raise ParserSyntaxError( + "pld.alloc_window_buffer must be assigned to a single variable name " + f"(got '{ast.unparse(target)}')", + span=span, + hint="Use 'buf = pld.alloc_window_buffer(...)' (no tuple unpacking, no subscripts)", + ) + + name = target.id + + if name in self._alloc_window_buffer_names: + raise ParserSyntaxError( + f"pld.alloc_window_buffer name '{name}' is already declared in this program", + span=span, + hint="Each window buffer must have a globally unique name across all functions", + ) + + if len(value.args) != 1: + raise ParserSyntaxError( + "pld.alloc_window_buffer takes exactly 1 positional argument (size in bytes); " + f"got {len(value.args)}", + span=span, + hint="Use 'pld.alloc_window_buffer(N)' where N is the per-rank size in bytes", + ) + + if value.keywords: + kw_names = [kw.arg for kw in value.keywords if kw.arg is not None] + raise ParserSyntaxError( + f"pld.alloc_window_buffer does not accept user-supplied kwargs; got {kw_names}", + span=span, + hint="Pass only the size as a positional argument; the buffer's name is " + "auto-derived from the assignment LHS", + ) + + size_expr = self._parse_op_positional_arg(value.args[0]) + if isinstance(size_expr, int): + size_expr = ir.ConstInt(int(size_expr), DataType.INT64, span) + elif isinstance(size_expr, (list, tuple, ir.MakeTuple)): + raise ParserSyntaxError( + "pld.alloc_window_buffer size must be a scalar (int / Expr in bytes), not a list/tuple", + span=span, + hint="The redesigned alloc is pure-allocation — pass a single byte count " + "(e.g. 256 * 4). Shape lives on pld.window(buf, [shape], dtype=...) instead.", + ) + elif not isinstance(size_expr, ir.Expr): + raise ParserSyntaxError( + "pld.alloc_window_buffer size must be an int / Expr in bytes " + f"(got {type(size_expr).__name__})", + span=span, + hint="Use a literal like '1024' or a symbolic expression (pld.world_size() * N, ...)", + ) + + alloc_call = ir.create_op_call("pld.alloc_window_buffer", [size_expr], {"name": name}, span) + + self._alloc_window_buffer_names.add(name) + var = self._assign_or_let(name, alloc_call, span) + self.scope_manager.define_var(name, var, span=span) + def _parse_subscript_assignment(self, target: ast.Subscript, value_node: ast.expr) -> None: """Desugar ``dst[] = src`` to ``dst = pl.assemble(dst, src, offsets)``. @@ -3589,6 +3697,10 @@ def parse_op_call(self, call: ast.Call) -> ir.Expr: if isinstance(node, ast.Name): attrs.insert(0, node.id) + # pld.{operation} (2-segment) — distributed DSL ops + if len(attrs) == 2 and attrs[0] == "pld": + return self._parse_pld_op(attrs[1], call) + # pl.tensor.{operation} (3-segment) if len(attrs) >= 3 and attrs[0] == "pl" and attrs[1] == "tensor": op_name = attrs[2] @@ -4322,6 +4434,81 @@ def _parse_array_op(self, op_name: str, call: ast.Call) -> ir.Expr: """Parse array operation (create / get_element / update_element).""" return self._dispatch_op(_dsl_array, "pl.array", op_name, call) + def _parse_pld_op(self, op_name: str, call: ast.Call) -> ir.Expr: + """Parse a ``pld.(...)`` distributed-DSL operation. + + Only ``pld.window`` is dispatched as a normal expression here. + ``pld.alloc_window_buffer`` is intercepted in + :meth:`parse_assignment` (it needs the LHS name) and reaches this + method only when it appears outside an assignment — in which case + we emit a targeted error rather than silently constructing a + nameless alloc op. + """ + span = self.span_tracker.get_span(call) + + if op_name == "alloc_window_buffer": + raise ParserSyntaxError( + "pld.alloc_window_buffer must appear as the RHS of a simple assignment " + "(its result must be bound to a named variable)", + span=span, + hint="Write 'buf = pld.alloc_window_buffer(N)'", + ) + + if op_name == "window": + if len(call.args) != 2: + raise ParserSyntaxError( + f"pld.window takes exactly 2 positional arguments (buf, shape); got {len(call.args)}", + span=span, + hint="Use 'pld.window(buf, [shape], dtype=pl.
)'", + ) + + allowed_kwargs = {"dtype"} + for kw in call.keywords: + if kw.arg not in allowed_kwargs: + raise ParserSyntaxError( + f"pld.window does not accept kwarg '{kw.arg}'", + span=span, + hint="Only 'dtype' is accepted as a kwarg", + ) + + buf_expr = self.parse_expression(call.args[0]) + if not isinstance(buf_expr, ir.Expr): + raise ParserSyntaxError( + "pld.window first argument must be an IR expression", + span=span, + ) + if not isinstance(buf_expr.type, ir.PtrType): + raise ParserTypeError( + "pld.window expects a Ptr handle (output of pld.alloc_window_buffer); " + f"got {ir.python_print_type(buf_expr.type)}", + span=span, + hint="Pass a variable assigned from 'pld.alloc_window_buffer(...)'", + ) + + shape_raw = self._parse_op_positional_arg(call.args[1]) + if not isinstance(shape_raw, ir.MakeTuple): + raise ParserSyntaxError( + f"pld.window shape must be a list/tuple literal (got {type(shape_raw).__name__})", + span=span, + hint="Use a list literal like '[N]' or '[pld.world_size(), N]'", + ) + shape_arg: ir.Expr = shape_raw + + kwargs = self._parse_op_kwargs(call) + if "dtype" not in kwargs: + raise ParserSyntaxError( + "pld.window requires a 'dtype' kwarg", + span=span, + hint="Use 'pld.window(buf, [shape], dtype=pl.FP32)'", + ) + return ir.create_op_call("pld.window", [buf_expr, shape_arg], kwargs, span) + + raise InvalidOperationError( + f"Unknown distributed operation 'pld.{op_name}'", + span=span, + hint="Available: pld.alloc_window_buffer, pld.window", + ) + # Maps iterator type name to ForKind enum value. _ITERATOR_TO_KIND = { "range": ir.ForKind.Sequential, diff --git a/python/pypto/language/parser/decorator.py b/python/pypto/language/parser/decorator.py index e99ac91e0..69aaab44f 100644 --- a/python/pypto/language/parser/decorator.py +++ b/python/pypto/language/parser/decorator.py @@ -895,6 +895,10 @@ def _parse_program_body( # noqa: PLR0912 # ir.Var objects for dynamic dimension variables (issue #618). dyn_var_cache: dict[str, ir.Var] = {} + # Shared set of pld.alloc_window_buffer names so name-uniqueness + # checks span every function in the program. + alloc_window_buffer_names: set[str] = set() + # Compute per-method line-range boundaries. Each method owns comments from # its first-line up to the line just before the next method (or end of # class for the last one). This captures tail-of-block comments inside @@ -948,6 +952,7 @@ def _parse_program_body( # noqa: PLR0912 buffer_name_meta=buffer_name_meta, dyn_var_cache=dyn_var_cache, pending_comments=method_comments, + alloc_window_buffer_names=alloc_window_buffer_names, ) try: diff --git a/python/pypto/pypto_core/ir.pyi b/python/pypto/pypto_core/ir.pyi index 2ab59598f..cf714d29e 100644 --- a/python/pypto/pypto_core/ir.pyi +++ b/python/pypto/pypto_core/ir.pyi @@ -542,9 +542,17 @@ class DistributedTensorType(TensorType): ``As`` to dispatch on the distributed variant. Mirrors :class:`TensorType`'s constructor surface — memref / tensor_view - are equally valid on the distributed flavour. + are equally valid on the distributed flavour. Additionally carries an + optional ``window_buffer`` back-reference (populated by ``pld.window``) + so that two same-shape / same-dtype slices of different + :class:`WindowBuffer` allocations stay structurally distinct. """ + window_buffer: Final[WindowBuffer | None] + """Source :class:`WindowBuffer` for slices materialised via ``pld.window``; + ``None`` for user-declared parameter annotations such as + ``pld.DistributedTensor[[N], pl.FP32]``.""" + @overload def __init__(self, shape: Sequence[Expr], dtype: DataType) -> None: """Create a distributed tensor type.""" @@ -581,6 +589,10 @@ class DistributedTensorType(TensorType): ) -> None: """Create a distributed tensor type with constant shape, optional memref and tensor_view.""" + @overload + def __init__(self, shape: Sequence[Expr], dtype: DataType, window_buffer: WindowBuffer) -> None: + """Create a distributed tensor type bound to a specific WindowBuffer (produced by ``pld.window``).""" + class TileView: """Tile view: read-only representation of valid shape, stride, start offset, layouts, fractal, and pad. Construct with all values; fields cannot be mutated @@ -1054,6 +1066,20 @@ class PtrType(Type): """Get the singleton PtrType instance.""" ... +class WindowBufferType(Type): + """Singleton marker type for ``pld.alloc_window_buffer`` outputs. + + Mirrors :class:`MemRefType`: the type carries no per-instance fields. All + allocation metadata (name, size, dtype, host-staging flags) lives on the + companion :class:`WindowBuffer` Var subclass. + """ + + def __init__(self) -> None: ... + @staticmethod + def get() -> WindowBufferType: + """Get the shared singleton WindowBufferType instance.""" + ... + class MemRef(Var): """Memory reference variable for shaped types (inherits from Var).""" @@ -2414,41 +2440,46 @@ class Program(IRNode): Program with type information """ -class WindowBuffer(IRNode): - """Per-rank allocation spec for a named CommGroup HCCL window buffer. +class WindowBuffer(Var): + """Per-rank CommGroup HCCL window-buffer allocation, modelled as a specialised Var. + + Mirrors :class:`MemRef`: its SSA-edge type is the singleton + :class:`WindowBufferType`; allocation metadata (the underlying Ptr Var, + per-rank size in bytes, host-staging flags) lives on the Var subclass + itself. The buffer's runtime-unique identifier flows through the + inherited :attr:`Var.name_hint` (taken from ``base.name_hint``) — there + is no separate ``name`` field, because two structurally-identical + allocations should compare equal regardless of the chosen variable name. - Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. - Participates in structural equality / hashing via reflection. + ``WindowBuffer`` instances are constructed by the comm-collection pass; + at parse time the :func:`alloc_window_buffer` op's LHS is a plain + :class:`Var` of type :class:`PtrType`. """ - name: Final[str] - """Buffer name (parser-extracted from the alloc-op LHS, globally unique).""" + base: Final[Var] + """The underlying ``Ptr`` Var produced by ``pld.alloc_window_buffer``. + Same role as :attr:`MemRef.base`.""" size: Final[Expr] - """Per-rank element count (ConstInt or symbolic Expr). Allocation-only — - does not carry a multi-dim shape.""" - - dtype: Final[DataType] - """Element data type.""" + """Per-rank allocation size in **bytes** (matches + :attr:`tile.alloc`'s size argument and the runtime's + ``ChipBufferSpec.nbytes`` field).""" load_from_host: Final[bool] - """``True`` if this slot participates in pre-fork H2D staging. The specific - host tensor that supplies the staged data is recorded on the alloc op, not - on this allocation spec.""" + """``True`` if this slot participates in pre-fork H2D staging.""" store_to_host: Final[bool] """``True`` if this slot participates in post-task D2H staging.""" def __init__( self, - name: str, + base: Var, size: Expr, - dtype: DataType, load_from_host: bool = False, store_to_host: bool = False, span: Span = ..., ) -> None: - """Create a WindowBuffer.""" + """Create a WindowBuffer wrapping the given Ptr ``base`` Var.""" class CommGroup(IRNode): """A communication group inferred for a ``@pl.program``. diff --git a/python/pypto/runtime/distributed_runner.py b/python/pypto/runtime/distributed_runner.py index 547a11de9..3b5f06d36 100644 --- a/python/pypto/runtime/distributed_runner.py +++ b/python/pypto/runtime/distributed_runner.py @@ -29,6 +29,11 @@ from pypto.ir.distributed_compiled_program import DistributedCompiledProgram, DistributedConfig +# Placeholder dtype for ChipBufferSpec: WindowBuffer carries no dtype (matching +# MemRef), and simpler does not consume this field at runtime. +_OPAQUE_DTYPE = "opaque" + + # --------------------------------------------------------------------------- # ContinuousTensor → torch.Tensor conversion # --------------------------------------------------------------------------- @@ -163,14 +168,12 @@ def _build_chip_bootstrap_configs_from_manifest( specs: list[Any] = [] for slot in group["slots"]: - count = int(slot["size"]) - # Round up to whole bytes — works uniformly for FP4/INT4 sub-byte dtypes. - nbytes = (count * int(slot["bits_per_element"]) + 7) // 8 + nbytes = int(slot["nbytes"]) specs.append( ChipBufferSpec( name=slot["name"], - dtype=slot["dtype"], - count=count, + dtype=_OPAQUE_DTYPE, + count=nbytes, nbytes=nbytes, load_from_host=bool(slot.get("load_from_host", False)), store_to_host=bool(slot.get("store_to_host", False)), diff --git a/src/ir/op/distributed/memory.cpp b/src/ir/op/distributed/memory.cpp new file mode 100644 index 000000000..daeaba01c --- /dev/null +++ b/src/ir/op/distributed/memory.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +/** + * @file memory.cpp + * @brief Distributed-memory ops — CommGroup window-buffer allocation & materialization. + * + * Mirrors :file:`src/ir/op/tile_ops/memory.cpp` in structure: this translation + * unit owns every op that touches CommGroup window-buffer address-space life + * cycle. Two ops are registered: + * + * * ``pld.alloc_window_buffer(size, *, name)`` — pure address-space allocation. + * Takes a per-rank scalar ``size`` (in **bytes**, matching ``tile.alloc``) + * plus a ``name`` kwarg, and returns the singleton :class:`PtrType` (the + * same allocation-identity token ``tile.alloc`` produces). The parser binds + * the result to a plain ``Var(PtrType, name_hint=name)``; the comm-collection + * pass later wraps the Ptr in an :class:`ir.WindowBuffer` Var subclass and + * threads it through ``DistributedTensorType.window_buffer_``. + * + * * ``pld.window(buf, shape, *, dtype)`` — materialises a Ptr handle as a + * :class:`DistributedTensorType` view with the supplied shape and dtype. + * The result type's ``window_buffer_`` field is left ``nullopt`` at parse + * time; the comm-collection pass populates it from the def-use chain back + * to the alloc. Shape & dtype enter the type system only here. + */ + +#include +#include +#include +#include +#include + +#include "pypto/core/any_cast.h" +#include "pypto/core/dtype.h" +#include "pypto/core/error.h" +#include "pypto/ir/expr.h" +#include "pypto/ir/kind_traits.h" +#include "pypto/ir/op_registry.h" +#include "pypto/ir/type.h" + +namespace pypto { +namespace ir { + +namespace { + +template +T GetKwarg(const std::vector>& kwargs, const std::string& key, + const std::string& op_name) { + for (const auto& [k, v] : kwargs) { + if (k == key) { + return AnyCast(v, "kwarg key: " + key); + } + } + throw ValueError("Missing kwarg '" + key + "' on " + op_name); +} + +TypePtr DeduceAllocWindowBufferType(const std::vector& args, + const std::vector>& kwargs) { + CHECK(args.size() == 1) << "pld.alloc_window_buffer requires exactly 1 positional argument " + "(size: ScalarType expression in bytes), but got " + << args.size(); + CHECK(args[0]) << "pld.alloc_window_buffer size argument must not be null"; + + auto name = GetKwarg(kwargs, "name", "pld.alloc_window_buffer"); + CHECK(!name.empty()) << "pld.alloc_window_buffer requires a non-empty 'name' kwarg"; + + // The op produces a Ptr — exact mirror of tile.alloc / tensor.alloc. + return GetPtrType(); +} + +TypePtr DeduceWindowType(const std::vector& args, + const std::vector>& kwargs) { + CHECK(args.size() == 2) << "pld.window requires 2 positional args (buf, shape), but got " << args.size(); + CHECK(args[0]) << "pld.window 'buf' argument must not be null"; + + // First arg is the allocation-identity token from pld.alloc_window_buffer + // (or, in principle, any Ptr-typed Var the parser routes here). The + // back-reference to the actual ``WindowBuffer`` is filled in later by the + // comm-collection pass; until then ``window_buffer_`` is ``std::nullopt``. + CHECK(IsA(args[0]->GetType())) + << "pld.window 'buf' must have type Ptr (output of pld.alloc_window_buffer), got " + << args[0]->GetType()->TypeName(); + + auto shape_tuple = As(args[1]); + CHECK(shape_tuple) << "pld.window second argument must be a shape tuple (MakeTuple of ints / Exprs), got " + << args[1]->TypeName(); + + auto dtype = GetKwarg(kwargs, "dtype", "pld.window"); + + // Shape & dtype enter the type system here. window_buffer_ stays nullopt + // until CollectCommGroups (N4) wires it to the originating WindowBuffer. + return std::make_shared(shape_tuple->elements_, dtype); +} + +} // namespace + +// ============================================================================ +// pld.alloc_window_buffer — per-rank CommGroup window-buffer allocation +// ============================================================================ + +REGISTER_OP("pld.alloc_window_buffer") + .set_description( + "Declare a per-rank CommGroup window-buffer slot of `size` bytes. Returns a Ptr " + "(allocation-identity token, exactly like tile.alloc / tensor.alloc). The " + "comm-collection pass later wraps the Ptr into an ir.WindowBuffer Var subclass and " + "registers it on the program's CommGroup metadata.") + .set_op_category("DistributedOp") + .add_argument("size", "Per-rank allocation size in bytes (ScalarType expression)") + .set_attr("name") + .no_memory_spec() + .f_deduce_type(DeduceAllocWindowBufferType); + +// ============================================================================ +// pld.window — materialise a window-buffer Ptr as a DistributedTensor view +// ============================================================================ + +REGISTER_OP("pld.window") + .set_description( + "Materialise a window-buffer Ptr as a DistributedTensor view with the given shape " + "and dtype. The result type's window_buffer back-reference is left unset at parse " + "time; the comm-collection pass populates it from the def-use chain back to the " + "alloc op.") + .set_op_category("DistributedOp") + .add_argument("buf", "Ptr handle (output of pld.alloc_window_buffer)") + .add_argument("shape", "Per-rank shape (MakeTuple of ExprPtr / ConstInt)") + .set_attr("dtype") + .no_memory_spec() + .f_deduce_type(DeduceWindowType); + +} // namespace ir +} // namespace pypto diff --git a/src/ir/serialization/deserializer.cpp b/src/ir/serialization/deserializer.cpp index de730fdf4..6e6b7c320 100644 --- a/src/ir/serialization/deserializer.cpp +++ b/src/ir/serialization/deserializer.cpp @@ -31,6 +31,7 @@ #include "pypto/ir/core.h" #include "pypto/ir/expr.h" #include "pypto/ir/memref.h" +#include "pypto/ir/program.h" #include "pypto/ir/scalar_expr.h" #include "pypto/ir/serialization/type_registry.h" #include "pypto/ir/span.h" @@ -323,10 +324,12 @@ class IRDeserializer::Impl : public detail::DeserializerContext { msgpack::object memref_obj; msgpack::object tile_view_obj; msgpack::object tensor_view_obj; + msgpack::object window_buffer_obj; uint8_t memory_space_code = 0; bool has_memref = false; bool has_tile_view = false; bool has_tensor_view = false; + bool has_window_buffer = false; bool has_memory_space = false; msgpack::object_kv* p = obj.via.map.ptr; @@ -360,6 +363,9 @@ class IRDeserializer::Impl : public detail::DeserializerContext { } else if (key == "tensor_view") { tensor_view_obj = p->val; has_tensor_view = true; + } else if (key == "window_buffer") { + window_buffer_obj = p->val; + has_window_buffer = true; } else if (key == "memory_space") { p->val.convert(memory_space_code); has_memory_space = true; @@ -381,9 +387,21 @@ class IRDeserializer::Impl : public detail::DeserializerContext { return std::make_shared(shape, DataType(dtype_code), memref, tensor_view); } else if (type_kind == "DistributedTensorType") { - // DistributedTensorType currently carries only shape + dtype (no memref / - // tensor_view at the surface; alloc-side metadata lives on the alloc op). - return std::make_shared(shape, DataType(dtype_code)); + std::optional memref; + std::optional tensor_view; + std::optional window_buffer; + if (has_memref) { + memref = DeserializeMemRef(memref_obj, zone); + } + if (has_tensor_view) { + tensor_view = DeserializeTensorView(tensor_view_obj, zone); + } + if (has_window_buffer) { + window_buffer = + std::static_pointer_cast(DeserializeNode(window_buffer_obj, zone)); + } + return std::make_shared(shape, DataType(dtype_code), memref, tensor_view, + window_buffer); } else if (type_kind == "TileType") { std::optional memref; std::optional tile_view; @@ -405,6 +423,8 @@ class IRDeserializer::Impl : public detail::DeserializerContext { return std::make_shared(DataType(dtype_code), shape[0]); } else if (type_kind == "TupleType") { return std::make_shared(types); + } else if (type_kind == "WindowBufferType") { + return GetWindowBufferType(); } else if (type_kind == "MemRefType") { return GetMemRefType(); } else if (type_kind == "Ptr") { diff --git a/src/ir/serialization/serializer.cpp b/src/ir/serialization/serializer.cpp index 0033dd1c5..e878804bf 100644 --- a/src/ir/serialization/serializer.cpp +++ b/src/ir/serialization/serializer.cpp @@ -414,6 +414,16 @@ class IRSerializer::Impl { if (tensor_type->tensor_view_.has_value()) { type_map["tensor_view"] = SerializeTensorView(tensor_type->tensor_view_, zone); } + + // DistributedTensorType-only: serialise the optional WindowBuffer + // back-reference as a node reference (WindowBuffer is an IRNode and + // shares the ref-table with the rest of the IR). + if (type->GetKind() == ObjectKind::DistributedTensorType) { + auto dt = std::static_pointer_cast(type); + if (dt->window_buffer_.has_value()) { + type_map["window_buffer"] = SerializeNode(*dt->window_buffer_, zone); + } + } } else if (auto tile_type = As(type)) { type_map["dtype"] = msgpack::object(tile_type->dtype_.Code(), zone); @@ -447,8 +457,9 @@ class IRSerializer::Impl { types_vec.push_back(SerializeType(t, zone)); } type_map["types"] = msgpack::object(types_vec, zone); - } else if (IsA(type) || IsA(type) || IsA(type)) { - // MemRefType, PtrType, and UnknownType have no additional fields + } else if (IsA(type) || IsA(type) || IsA(type) || + IsA(type)) { + // Singleton marker types (no extra fields beyond the type_kind key). } else { INTERNAL_UNREACHABLE << "Unknown Type subclass: " << type->TypeName(); } diff --git a/src/ir/serialization/type_deserializers.cpp b/src/ir/serialization/type_deserializers.cpp index a9d393770..4979eee7c 100644 --- a/src/ir/serialization/type_deserializers.cpp +++ b/src/ir/serialization/type_deserializers.cpp @@ -918,13 +918,14 @@ static IRNodePtr DeserializeProgram(const msgpack::object& fields_obj, msgpack:: static IRNodePtr DeserializeWindowBuffer(const msgpack::object& fields_obj, msgpack::zone& zone, DeserializerContext& ctx) { auto span = ctx.DeserializeSpan(GET_FIELD_OBJ("span")); - std::string name = GET_FIELD(std::string, "name"); + // ``name_hint`` is the inherited Var field carrying the runtime-unique + // identifier (mirrors MemRef). The base Ptr Var that this WindowBuffer + // wraps is reconstructed from the ``base`` field below. + auto base = std::static_pointer_cast(ctx.DeserializeNode(GET_FIELD_OBJ("base"), zone)); auto size = std::static_pointer_cast(ctx.DeserializeNode(GET_FIELD_OBJ("size"), zone)); - uint8_t dtype_code = GET_FIELD(uint8_t, "dtype"); bool load_from_host = GET_FIELD(bool, "load_from_host"); bool store_to_host = GET_FIELD(bool, "store_to_host"); - return std::make_shared(std::move(name), size, DataType(dtype_code), load_from_host, - store_to_host, span); + return std::make_shared(base, size, load_from_host, store_to_host, span); } // Deserialize CommGroup diff --git a/src/ir/transforms/python_printer.cpp b/src/ir/transforms/python_printer.cpp index 3d39a5378..f158059c6 100644 --- a/src/ir/transforms/python_printer.cpp +++ b/src/ir/transforms/python_printer.cpp @@ -496,6 +496,13 @@ std::string IRPythonPrinter::Print(const TypePtr& type) { return prefix_ + ".Ptr"; } + if (As(type)) { + // Singleton marker — no per-instance fields. Render as a bare attribute + // so it round-trips through the parser via the same path as ``pld.`` + // namespace lookups. + return "pld.WindowBufferType"; + } + return prefix_ + ".UnknownType"; } diff --git a/src/ir/transforms/structural_equal.cpp b/src/ir/transforms/structural_equal.cpp index 0a4a97f65..d0be3e6b8 100644 --- a/src/ir/transforms/structural_equal.cpp +++ b/src/ir/transforms/structural_equal.cpp @@ -1082,6 +1082,24 @@ bool StructuralEqualImpl::EqualType(const TypePtr& lhs, const TypePt return false; } } + // DistributedTensorType-only field: window_buffer_ back-reference. Both + // sides must agree on presence, and the underlying WindowBuffer Vars must + // match through the regular Var-equality path (which falls into the Var + // identity map / auto-mapping logic — same as MemRef inside ShapedType). + if (lhs->GetKind() == ObjectKind::DistributedTensorType) { + auto lhs_dt = std::static_pointer_cast(lhs); + auto rhs_dt = std::static_pointer_cast(rhs); + if (lhs_dt->window_buffer_.has_value() != rhs_dt->window_buffer_.has_value()) { + if constexpr (AssertMode) { + ThrowMismatch("DistributedTensorType window_buffer presence mismatch", IRNodePtr(), IRNodePtr(), "", + ""); + } + return false; + } + if (lhs_dt->window_buffer_.has_value() && !EqualVar(*lhs_dt->window_buffer_, *rhs_dt->window_buffer_)) { + return false; + } + } return true; } else if (auto lhs_tile = As(lhs)) { auto rhs_tile = As(rhs); @@ -1244,7 +1262,8 @@ bool StructuralEqualImpl::EqualType(const TypePtr& lhs, const TypePt return false; } return true; - } else if (IsA(lhs) || IsA(lhs) || IsA(lhs)) { + } else if (IsA(lhs) || IsA(lhs) || IsA(lhs) || + IsA(lhs)) { return true; // Singleton type, both being same type kind is sufficient } diff --git a/src/ir/transforms/structural_hash.cpp b/src/ir/transforms/structural_hash.cpp index 0f742592c..18e9587cd 100644 --- a/src/ir/transforms/structural_hash.cpp +++ b/src/ir/transforms/structural_hash.cpp @@ -451,6 +451,21 @@ StructuralHasher::result_type StructuralHasher::HashType(const TypePtr& type) { } else { h = hash_combine(h, static_cast(0)); // indicate absence } + // DistributedTensorType-only back-reference to its source WindowBuffer. + // Mix in presence + Var identity (HashNode dispatches the WindowBuffer Var + // path) so two same-shape / same-dtype DistributedTensorTypes built from + // different WindowBuffers hash apart. + if (type->GetKind() == ObjectKind::DistributedTensorType) { + auto dt = std::static_pointer_cast(type); + if (dt->window_buffer_.has_value()) { + h = hash_combine(h, static_cast(1)); + INTERNAL_CHECK(*dt->window_buffer_) + << "structural_hash encountered null window_buffer in DistributedTensorType"; + h = hash_combine(h, HashNode(*dt->window_buffer_)); + } else { + h = hash_combine(h, static_cast(0)); + } + } } else if (auto tile_type = As(type)) { // Hash dtype h = hash_combine(h, static_cast(std::hash{}(tile_type->dtype_.Code()))); @@ -505,8 +520,9 @@ StructuralHasher::result_type StructuralHasher::HashType(const TypePtr& type) { INTERNAL_CHECK(t) << "structural_hash encountered null type in TupleType"; h = hash_combine(h, HashType(t)); } - } else if (IsA(type) || IsA(type) || IsA(type)) { - // MemRefType, PtrType, and UnknownType have no fields, only hash type name (already done above) + } else if (IsA(type) || IsA(type) || IsA(type) || + IsA(type)) { + // Singleton marker types (no fields beyond the type name hashed above). } else { INTERNAL_CHECK(false) << "HashType encountered unhandled Type: " << type->TypeName(); } @@ -592,7 +608,8 @@ StructuralHasher::result_type StructuralHasher::HashNode(const IRNodePtr& node) }; auto kind = node->GetKind(); - if (kind == ObjectKind::MemRef || kind == ObjectKind::IterArg || kind == ObjectKind::Var) { + if (kind == ObjectKind::MemRef || kind == ObjectKind::IterArg || kind == ObjectKind::Var || + kind == ObjectKind::WindowBuffer) { hash_var_identity(static_cast(node.get())->UniqueId()); } diff --git a/tests/ut/ir/core/test_comm_group_schema.py b/tests/ut/ir/core/test_comm_group_schema.py index 8f503a4be..61ae21855 100644 --- a/tests/ut/ir/core/test_comm_group_schema.py +++ b/tests/ut/ir/core/test_comm_group_schema.py @@ -7,14 +7,26 @@ # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""IR-level tests for the v2 ``WindowBuffer`` / ``CommGroup`` schema (N1.3).""" +"""IR-level tests for the v2 ``WindowBuffer`` / ``CommGroup`` schema. + +Post-redesign, ``WindowBuffer`` is a :class:`Var` subclass that mirrors +:class:`MemRef` exactly: + +* Wraps a :class:`Var` of type :class:`PtrType` (``base``) — the + allocation-identity token from ``pld.alloc_window_buffer``. +* Carries ``size`` (per-rank bytes), ``load_from_host`` / ``store_to_host`` + flags. **No** dtype, **no** name field — the runtime-unique identifier + comes from the inherited ``Var.name_hint`` (taken from ``base.name_hint``). +""" import pytest from pypto.pypto_core import DataType from pypto.pypto_core.ir import ( CommGroup, ConstInt, + PtrType, Span, + Var, WindowBuffer, structural_equal, ) @@ -24,15 +36,23 @@ def _const(value: int) -> ConstInt: return ConstInt(value, DataType.INT64, Span.unknown()) +def _ptr(name: str) -> Var: + return Var(name, PtrType(), Span.unknown()) + + # --------------------------------------------------------------------------- # WindowBuffer # --------------------------------------------------------------------------- def test_window_buffer_minimal_construction(): - wb = WindowBuffer("data", _const(256), DataType.FP32) - assert wb.name == "data" - assert wb.dtype == DataType.FP32 + base = _ptr("data") + wb = WindowBuffer(base, _const(256)) + # name_hint flows from the base Ptr Var (mirrors MemRef). + assert wb.name_hint == "data" + assert wb.base is base + assert isinstance(wb.size, ConstInt) + assert wb.size.value == 256 assert wb.load_from_host is False assert wb.store_to_host is False @@ -40,11 +60,20 @@ def test_window_buffer_minimal_construction(): def test_window_buffer_load_store_to_host_flags(): """v2 schema: load/store_to_host are bool flags; the actual host tensor binding (if any) is recorded on the alloc op, not here.""" - wb = WindowBuffer("lut", _const(64), DataType.FP32, load_from_host=True, store_to_host=True) + base = _ptr("lut") + wb = WindowBuffer(base, _const(64), load_from_host=True, store_to_host=True) assert wb.load_from_host is True assert wb.store_to_host is True +def test_window_buffer_is_var_subclass(): + """Mirror MemRef: WindowBuffer inherits from Var so visitor / mutator + machinery treats it the same as any other Var.""" + base = _ptr("data") + wb = WindowBuffer(base, _const(64)) + assert isinstance(wb, Var) + + # --------------------------------------------------------------------------- # CommGroup structural equality # --------------------------------------------------------------------------- @@ -52,42 +81,77 @@ def test_window_buffer_load_store_to_host_flags(): def test_comm_group_empty_devices_means_all(): """``devices == []`` is the convention for "covers all DistributedConfig.device_ids".""" - g = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)]) + g = CommGroup([], [WindowBuffer(_ptr("data"), _const(64))]) assert list(g.devices) == [] def test_comm_group_explicit_device_subset(): - g = CommGroup([0, 1, 2], [WindowBuffer("data", _const(64), DataType.FP32)]) + g = CommGroup([0, 1, 2], [WindowBuffer(_ptr("data"), _const(64))]) assert list(g.devices) == [0, 1, 2] -def test_comm_group_structural_equal(): - g1 = CommGroup([], [WindowBuffer("data", _const(256), DataType.FP32)]) - g2 = CommGroup([], [WindowBuffer("data", _const(256), DataType.FP32)]) +def test_comm_group_structural_equal_when_slot_var_is_shared(): + """Two CommGroups whose ``slots`` contain the **same** WindowBuffer Var + instances compare structurally equal. + + Mirrors :class:`MemRef` semantics: independently-constructed Var + instances each have a unique identity, so structural equality requires + sharing the underlying Var (same ``shared_ptr``). This is the form the + comm-collection pass produces — slots in ``Program.comm_groups`` and + the ``DistributedTensorType.window_buffer`` references both alias the + same WindowBuffer. + """ + wb = WindowBuffer(_ptr("data"), _const(256)) + g1 = CommGroup([], [wb]) + g2 = CommGroup([], [wb]) assert structural_equal(g1, g2) +def test_comm_group_structural_equal_for_independent_slot_vars_under_auto_mapping(): + """Two CommGroups whose ``slots`` are *separately constructed* compare + structurally equal under ``enable_auto_mapping=True`` — corresponding + Vars are matched by their position in the structure, so distinct + UniqueIds don't break equality when the surrounding shape, base + ``name_hint``, ``size``, and host-staging flags all match. + + Mirrors :class:`MemRef`'s auto-mapping semantics: the default identity + path requires shared Var instances (see + :func:`test_comm_group_structural_equal_when_slot_var_is_shared`); the + auto-mapping path is the structural-isomorphism check used when + comparing IR produced by independent runs. + """ + g1 = CommGroup([], [WindowBuffer(_ptr("data"), _const(256))]) + g2 = CommGroup([], [WindowBuffer(_ptr("data"), _const(256))]) + assert structural_equal(g1, g2, enable_auto_mapping=True) + + def test_comm_group_structural_not_equal_when_devices_differ(): - g_all = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)]) - g_subset = CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)]) + """Sharing the slot Var isolates the ``devices``-only difference.""" + wb = WindowBuffer(_ptr("data"), _const(64)) + g_all = CommGroup([], [wb]) + g_subset = CommGroup([0, 1], [wb]) assert not structural_equal(g_all, g_subset) def test_comm_group_structural_not_equal_when_subsets_differ(): - a = CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)]) - b = CommGroup([2, 3], [WindowBuffer("data", _const(64), DataType.FP32)]) + wb = WindowBuffer(_ptr("data"), _const(64)) + a = CommGroup([0, 1], [wb]) + b = CommGroup([2, 3], [wb]) assert not structural_equal(a, b) def test_comm_group_structural_not_equal_when_slots_differ(): - a = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)]) - b = CommGroup([], [WindowBuffer("data", _const(128), DataType.FP32)]) + """Same base Var but different ``size`` → distinct slot Vars → distinct groups.""" + base = _ptr("data") + a = CommGroup([], [WindowBuffer(base, _const(64))]) + b = CommGroup([], [WindowBuffer(base, _const(128))]) assert not structural_equal(a, b) def test_comm_group_structural_not_equal_when_load_flag_differs(): - a = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32, load_from_host=False)]) - b = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32, load_from_host=True)]) + base = _ptr("data") + a = CommGroup([], [WindowBuffer(base, _const(64), load_from_host=False)]) + b = CommGroup([], [WindowBuffer(base, _const(64), load_from_host=True)]) assert not structural_equal(a, b) diff --git a/tests/ut/ir/parser/test_alloc_window_buffer.py b/tests/ut/ir/parser/test_alloc_window_buffer.py new file mode 100644 index 000000000..d30a3ea54 --- /dev/null +++ b/tests/ut/ir/parser/test_alloc_window_buffer.py @@ -0,0 +1,192 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +# ruff: noqa: F722, F821 + +"""Parser tests for ``pld.alloc_window_buffer``. + +After the MemRef-mirror redesign, the alloc op is a pure-allocation primitive +(parallel to ``tile.alloc(memspace, size)``): + +* ``size`` is a scalar **byte** count; no shape, no dtype on the alloc. +* The op returns the singleton :class:`PtrType`; the parser binds the LHS as + a plain :class:`ir.Var` of type :class:`ir.PtrType`. The + comm-collection pass later wraps the Ptr in an :class:`ir.WindowBuffer` Var + subclass and registers it on ``Program.comm_groups``. +* The LHS variable name flows through ``Var.name_hint`` (and is also injected + as the op's ``name`` kwarg so the comm-collection pass can find it). +""" + +import pypto.language as pl +import pypto.language.distributed as pld +import pytest +from pypto.pypto_core import ir + + +def _get_host_orch(program: ir.Program, name: str = "host_orch") -> ir.Function: + gvar = program.get_global_var(name) + assert gvar is not None, f"Function '{name}' not found in program" + return program.functions[gvar] + + +def _find_alloc_assignment(func: ir.Function) -> ir.AssignStmt: + """Return the first AssignStmt whose RHS is a ``pld.alloc_window_buffer`` Call.""" + + def walk(stmt: ir.Stmt) -> ir.AssignStmt | None: + if isinstance(stmt, ir.AssignStmt): + if isinstance(stmt.value, ir.Call) and stmt.value.op.name == "pld.alloc_window_buffer": + return stmt + if isinstance(stmt, ir.SeqStmts): + for s in stmt.stmts: + hit = walk(s) + if hit is not None: + return hit + return None + + hit = walk(func.body) + assert hit is not None, "no pld.alloc_window_buffer assignment found in function body" + return hit + + +def test_alloc_window_buffer_lhs_is_plain_ptr_var(): + """The LHS variable is a plain ``ir.Var`` of type ``ir.PtrType`` — + no specialised ``WindowBuffer`` Var subclass at parse time.""" + + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(1024) + return buf # noqa: RET504 + + func = _get_host_orch(P) + stmt = _find_alloc_assignment(func) + var = stmt.var + # Plain Var with PtrType — exact mirror of `mem_vec_7: Ptr = tile.alloc(...)`. + assert isinstance(var, ir.Var) + assert not isinstance(var, ir.WindowBuffer) + assert isinstance(var.type, ir.PtrType) + # The buffer's runtime-unique identifier comes from the LHS variable name + # via Var.name_hint. + assert var.name_hint == "buf" + + +def test_alloc_window_buffer_call_carries_name_kwarg(): + """The op call's kwargs carry the LHS-injected name. No dtype kwarg.""" + + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + data_buf = pld.alloc_window_buffer(512) + return data_buf + + func = _get_host_orch(P) + stmt = _find_alloc_assignment(func) + assert isinstance(stmt.value, ir.Call) + call = stmt.value + assert call.kwargs["name"] == "data_buf" + assert "dtype" not in call.kwargs + assert len(call.args) == 1 + assert isinstance(call.args[0], ir.ConstInt) + assert call.args[0].value == 512 + + +def test_alloc_window_buffer_returns_singleton_ptr_type(): + """Different alloc sites all return the SAME singleton PtrType.""" + + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf_a = pld.alloc_window_buffer(8) + buf_b = pld.alloc_window_buffer(16) # noqa: F841 + return buf_a + + func = _get_host_orch(P) + allocs: list[ir.AssignStmt] = [] + + def walk(stmt: ir.Stmt) -> None: + if isinstance(stmt, ir.AssignStmt) and isinstance(stmt.value, ir.Call): + if stmt.value.op.name == "pld.alloc_window_buffer": + allocs.append(stmt) + if isinstance(stmt, ir.SeqStmts): + for s in stmt.stmts: + walk(s) + + walk(func.body) + assert len(allocs) == 2 + type_a = allocs[0].value.type + type_b = allocs[1].value.type + assert isinstance(type_a, ir.PtrType) + assert isinstance(type_b, ir.PtrType) + assert ir.structural_equal(type_a, type_b) + + +def test_alloc_window_buffer_rejects_non_name_lhs(): + """Tuple-unpacking / subscript / attribute LHS is rejected — name must be a bare identifier.""" + with pytest.raises(Exception, match="must appear as the RHS of a simple assignment"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + a, b = pld.alloc_window_buffer(8), 0 # noqa: F841 + return a + + +def test_alloc_window_buffer_rejects_duplicate_names(): + with pytest.raises(Exception, match="already declared"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8) + buf = pld.alloc_window_buffer(8) # noqa: F841 + return buf + + +def test_alloc_window_buffer_rejects_user_kwargs(): + """The user-facing alloc takes only a positional size — no kwargs are allowed.""" + with pytest.raises(Exception, match="does not accept user-supplied kwargs"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8, dtype=pl.FP32) # noqa: F841 + return buf + + +def test_alloc_window_buffer_rejects_bare_call_outside_assignment(): + """Without an assignment LHS there is no globally-unique name to bind to.""" + with pytest.raises(Exception, match="must appear as the RHS of a simple assignment"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + pld.alloc_window_buffer(8) + return 0 + + +def test_alloc_window_buffer_rejects_list_for_size(): + """The redesigned signature takes a scalar byte size, not a shape list.""" + with pytest.raises(Exception, match="size must be a scalar"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer([256]) # pyright: ignore[reportArgumentType] # noqa: F841 + return buf + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/parser/test_window_op.py b/tests/ut/ir/parser/test_window_op.py new file mode 100644 index 000000000..46060b1df --- /dev/null +++ b/tests/ut/ir/parser/test_window_op.py @@ -0,0 +1,225 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +# ruff: noqa: F722, F821 + +"""Parser tests for ``pld.window``. + +After the MemRef-mirror redesign: + +* ``pld.window(buf, shape, dtype=...)`` consumes a ``Ptr``-typed Var + (the LHS of ``pld.alloc_window_buffer``) plus an explicit ``shape`` list + and a ``dtype`` kwarg. +* Returns a :class:`ir.DistributedTensorType` carrying shape and dtype. + ``window_buffer`` back-reference is **None** at parse time — the + comm-collection pass populates it later. +""" + +from typing import cast + +import pypto.language as pl +import pypto.language.distributed as pld +import pytest +from pypto.pypto_core import ir + + +def _get_host_orch(program: ir.Program) -> ir.Function: + gvar = program.get_global_var("host_orch") + assert gvar is not None + return program.functions[gvar] + + +def _find_call(func: ir.Function, op_name: str) -> ir.Call: + found: list[ir.Call] = [] + + def walk(stmt: ir.Stmt) -> None: + if isinstance(stmt, ir.AssignStmt): + if isinstance(stmt.value, ir.Call) and stmt.value.op.name == op_name: + found.append(stmt.value) + if isinstance(stmt, ir.SeqStmts): + for s in stmt.stmts: + walk(s) + if isinstance(stmt, ir.ForStmt): + walk(stmt.body) + + walk(func.body) + assert found, f"no {op_name} call found in function body" + return found[0] + + +def _find_alloc_var(func: ir.Function) -> ir.Var: + """Return the LHS Var of the first pld.alloc_window_buffer assignment.""" + + def walk(stmt: ir.Stmt) -> ir.Var | None: + if ( + isinstance(stmt, ir.AssignStmt) + and isinstance(stmt.value, ir.Call) + and stmt.value.op.name == "pld.alloc_window_buffer" + ): + return stmt.var + if isinstance(stmt, ir.SeqStmts): + for s in stmt.stmts: + hit = walk(s) + if hit is not None: + return hit + return None + + hit = walk(func.body) + assert hit is not None + return hit + + +def test_window_returns_distributed_tensor_type_no_buffer_yet(): + """Parse-time: result type carries shape + dtype; window_buffer is None.""" + + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(1024) + data = pld.window(buf, [256], dtype=pl.FP32) + return data + + func = _get_host_orch(P) + win_call = _find_call(func, "pld.window") + assert isinstance(win_call.type, ir.DistributedTensorType) + assert win_call.type.dtype == pl.FP32 + shape = win_call.type.shape + assert len(shape) == 1 + assert isinstance(shape[0], ir.ConstInt) + assert shape[0].value == 256 + # window_buffer is filled in by the comm-collection pass; not yet at parse + # time. Mirrors how TensorType.memref starts as None until InitMemRef runs. + assert win_call.type.window_buffer is None + + +def test_window_input_is_alloc_ptr_var(): + """The window op's first input is the plain Ptr Var bound by alloc.""" + + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(64) + data = pld.window(buf, [16], dtype=pl.FP32) + return data + + func = _get_host_orch(P) + win_call = _find_call(func, "pld.window") + buf_var = _find_alloc_var(func) + assert len(win_call.args) == 2 + buf_arg = win_call.args[0] + # The window op receives the same Var instance bound by alloc — there's + # no second-binding indirection (this is the disconnect the redesign + # fixed). + assert buf_arg is buf_var + assert isinstance(buf_arg, ir.Var) + assert isinstance(buf_arg.type, ir.PtrType) + assert buf_arg.name_hint == "buf" + + +def test_window_propagates_multi_dim_shape(): + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(2048) + data = pld.window(buf, [8, 64], dtype=pl.FP16) + return data + + func = _get_host_orch(P) + call = _find_call(func, "pld.window") + dt = call.type + assert isinstance(dt, ir.DistributedTensorType) + assert dt.dtype == pl.FP16 + assert all(isinstance(d, ir.ConstInt) for d in dt.shape) + assert [int(cast(ir.ConstInt, d).value) for d in dt.shape] == [8, 64] + + +def test_window_rejects_non_ptr_arg(): + """A non-Ptr-typed Var cannot stand in for a Ptr handle.""" + with pytest.raises(Exception, match="Ptr handle"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self, x: pl.Tensor[[64], pl.FP32]): + data = pld.window(x, [16], dtype=pl.FP32) # type: ignore[arg-type] # noqa: F841 + return 0 + + +def test_window_rejects_unknown_kwarg(): + with pytest.raises(Exception, match="does not accept kwarg"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8) + data = pld.window(buf, [8], dtype=pl.FP32, target_memory=pl.Mem.DDR) # noqa: F841 + return 0 + + +def test_window_rejects_missing_shape_arg(): + with pytest.raises(Exception, match="2 positional"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8) + data = pld.window(buf, dtype=pl.FP32) # noqa: F841 + return 0 + + +def test_window_rejects_missing_dtype_kwarg(): + with pytest.raises(Exception, match="dtype"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8) + data = pld.window(buf, [8]) # noqa: F841 + return 0 + + +def test_window_can_be_called_inside_for_loop(): + @pl.program + class P: + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(64) + for _ in pl.range(0, 4): + data = pld.window(buf, [16], dtype=pl.FP32) # noqa: F841 + return 0 + + func = _get_host_orch(P) + win_call = _find_call(func, "pld.window") + assert isinstance(win_call.type, ir.DistributedTensorType) + + +def test_alloc_names_globally_unique_across_functions(): + """A second function in the same @pl.program cannot reuse a buffer name.""" + with pytest.raises(Exception, match="already declared"): + + @pl.program + class P: # noqa: F841 + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch(self): + buf = pld.alloc_window_buffer(8) + return buf + + @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator) + def host_orch_2(self): + buf = pld.alloc_window_buffer(8) + return buf + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/test_distributed_ops.py b/tests/ut/ir/test_distributed_ops.py new file mode 100644 index 000000000..e2c329819 --- /dev/null +++ b/tests/ut/ir/test_distributed_ops.py @@ -0,0 +1,179 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""Tests for distributed ops registered via OpRegistry. + +After the MemRef-mirror redesign: + +* ``WindowBufferType`` is a singleton (no per-instance fields). +* ``WindowBuffer`` is a :class:`Var` subclass with no ``name``/``dtype`` + fields; it wraps a base ``Var(PtrType)`` plus a per-rank byte size and + host-staging flags. Constructed by the comm-collection pass. +* ``pld.alloc_window_buffer(size, name=...)`` is pure-allocation and returns + the singleton :class:`PtrType` (same as ``tile.alloc``). +* ``pld.window(buf, shape, dtype=...)`` consumes a ``Ptr`` and returns + :class:`DistributedTensorType`; ``window_buffer`` back-reference is + ``None`` at parse time and filled in by the comm-collection pass later. +""" + +import pytest +from pypto import DataType, ir + + +def _make_shape_tuple(values: list[int], span: ir.Span) -> ir.MakeTuple: + return ir.MakeTuple([ir.ConstInt(v, DataType.INT64, span) for v in values], span) + + +# --------------------------------------------------------------------------- +# WindowBufferType singleton +# --------------------------------------------------------------------------- + + +def test_window_buffer_type_is_singleton(): + """``WindowBufferType.get()`` returns a structurally-equal instance every call.""" + a = ir.WindowBufferType.get() + b = ir.WindowBufferType.get() + assert a is b + assert ir.structural_equal(a, ir.WindowBufferType()) + + +# --------------------------------------------------------------------------- +# pld.alloc_window_buffer op +# --------------------------------------------------------------------------- + + +def test_alloc_window_buffer_returns_ptr_type(): + """Pure-allocation: alloc returns the singleton PtrType (mirrors tile.alloc).""" + span = ir.Span.unknown() + size = ir.ConstInt(1024, DataType.INT64, span) + call = ir.create_op_call( + "pld.alloc_window_buffer", + [size], + {"name": "buf"}, + span, + ) + assert isinstance(call.type, ir.PtrType) + # Op preserves the parser-injected name kwarg for downstream consumers. + assert call.kwargs["name"] == "buf" + # No dtype kwarg on the op surface — alloc is dtype-agnostic. + assert "dtype" not in call.kwargs + + +def test_alloc_window_buffer_requires_non_empty_name(): + span = ir.Span.unknown() + size = ir.ConstInt(4, DataType.INT64, span) + with pytest.raises(Exception, match="non-empty 'name'"): + ir.create_op_call( + "pld.alloc_window_buffer", + [size], + {"name": ""}, + span, + ) + + +# --------------------------------------------------------------------------- +# WindowBuffer Var subclass +# --------------------------------------------------------------------------- + + +def test_window_buffer_is_var_subclass_wrapping_ptr(): + """WindowBuffer is a Var whose type is the singleton WindowBufferType, + wrapping a base Ptr Var (mirrors MemRef wrapping a base Ptr).""" + span = ir.Span.unknown() + base = ir.Var("buf", ir.PtrType(), span) + size = ir.ConstInt(64, DataType.INT64, span) + wb = ir.WindowBuffer(base, size, span=span) + assert isinstance(wb, ir.Var) + assert isinstance(wb.type, ir.WindowBufferType) + # name_hint flows from base.name_hint — no separate name field on + # WindowBuffer (mirrors MemRef). + assert wb.name_hint == "buf" + assert wb.base is base + assert isinstance(wb.size, ir.ConstInt) + assert wb.size.value == 64 + assert wb.load_from_host is False + assert wb.store_to_host is False + + +# --------------------------------------------------------------------------- +# pld.window op +# --------------------------------------------------------------------------- + + +def test_window_returns_distributed_tensor_with_no_buffer_at_parse_time(): + """``pld.window(ptr, shape, dtype=...)`` returns DistributedTensorType + with shape + dtype set; ``window_buffer`` is None until the + comm-collection pass populates it.""" + span = ir.Span.unknown() + base = ir.Var("buf", ir.PtrType(), span) + shape = _make_shape_tuple([64], span) + call = ir.create_op_call("pld.window", [base, shape], {"dtype": DataType.FP16}, span) + assert isinstance(call.type, ir.DistributedTensorType) + assert call.type.dtype == DataType.FP16 + assert len(call.type.shape) == 1 + assert isinstance(call.type.shape[0], ir.ConstInt) + assert call.type.shape[0].value == 64 + # window_buffer back-reference is filled in by the comm-collection pass, + # not by the op deducer — at parse time it is None. + assert call.type.window_buffer is None + + +def test_window_rejects_non_ptr_arg(): + """A Var with a non-PtrType type cannot be passed to ``pld.window``.""" + span = ir.Span.unknown() + tensor_type = ir.TensorType([ir.ConstInt(64, DataType.INT64, span)], DataType.FP32) + bad = ir.Var("x", tensor_type, span) + shape = _make_shape_tuple([64], span) + with pytest.raises(Exception, match="Ptr"): + ir.create_op_call("pld.window", [bad, shape], {"dtype": DataType.FP32}, span) + + +def test_window_rejects_non_make_tuple_shape(): + span = ir.Span.unknown() + base = ir.Var("buf", ir.PtrType(), span) + bad_shape = ir.ConstInt(8, DataType.INT64, span) + with pytest.raises(Exception, match="shape tuple"): + ir.create_op_call("pld.window", [base, bad_shape], {"dtype": DataType.FP32}, span) + + +# --------------------------------------------------------------------------- +# DistributedTensorType.window_buffer back-reference +# --------------------------------------------------------------------------- + + +def test_distributed_tensor_type_distinguishes_distinct_window_buffers(): + """Same shape + dtype but different window_buffer ⇒ structurally distinct.""" + span = ir.Span.unknown() + base_a = ir.Var("buf_a", ir.PtrType(), span) + base_b = ir.Var("buf_b", ir.PtrType(), span) + wb_a = ir.WindowBuffer(base_a, ir.ConstInt(32, DataType.INT64, span), span=span) + wb_b = ir.WindowBuffer(base_b, ir.ConstInt(32, DataType.INT64, span), span=span) + shape = [ir.ConstInt(32, DataType.INT64, span)] + dt_a = ir.DistributedTensorType(shape, DataType.FP32, wb_a) + dt_b = ir.DistributedTensorType(shape, DataType.FP32, wb_b) + assert dt_a.window_buffer is wb_a + assert dt_b.window_buffer is wb_b + assert not ir.structural_equal(dt_a, dt_b) + + +def test_distributed_tensor_type_with_and_without_window_buffer_differ(): + """Param-annotation form (no buffer) and bound form (with buffer) differ.""" + span = ir.Span.unknown() + base = ir.Var("buf", ir.PtrType(), span) + wb = ir.WindowBuffer(base, ir.ConstInt(32, DataType.INT64, span), span=span) + shape = [ir.ConstInt(32, DataType.INT64, span)] + dt_param = ir.DistributedTensorType(shape, DataType.FP32) + dt_bound = ir.DistributedTensorType(shape, DataType.FP32, wb) + assert dt_param.window_buffer is None + assert dt_bound.window_buffer is wb + assert not ir.structural_equal(dt_param, dt_bound) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/runtime/test_chip_bootstrap_configs.py b/tests/ut/runtime/test_chip_bootstrap_configs.py index 0aba6bfb6..70209b785 100644 --- a/tests/ut/runtime/test_chip_bootstrap_configs.py +++ b/tests/ut/runtime/test_chip_bootstrap_configs.py @@ -7,12 +7,18 @@ # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- # ruff: noqa: F722, F821 -"""Tests for the AOT comm-group manifest pipeline (v2 schema, N1.4). +"""Tests for the AOT comm-group manifest pipeline (v2 schema). Compile-time: Program → ``lift_comm_manifest`` → JSON-safe dict On disk: ``output_dir/orchestration/comm_manifest.json`` Runtime: dict → ``_build_chip_bootstrap_configs_from_manifest`` → list[ChipBootstrapConfig] +Post-redesign, ``WindowBuffer`` is dtype-agnostic (mirrors ``MemRef``); the +manifest carries ``name`` (from ``Var.name_hint``) plus ``nbytes`` (per-rank +allocation size in bytes). The runtime passes the placeholder string +``"opaque"`` for ``ChipBufferSpec.dtype`` since simpler does not consume +that field. + The CollectCommGroups pass that *infers* CommGroups from ``pld.alloc_window_buffer`` ops is added in N4. These tests pre-stage the CommGroups directly on a hand-built ``ir.Program`` so the manifest pipeline @@ -29,6 +35,7 @@ CommGroup, ConstInt, Program, + PtrType, Span, Var, WindowBuffer, @@ -61,6 +68,10 @@ def _const(value: int) -> ConstInt: return ConstInt(value, DataType.INT64, Span.unknown()) +def _ptr(name: str) -> Var: + return Var(name, PtrType(), Span.unknown()) + + def _trivial_program(groups: list[CommGroup] | None = None) -> Program: """Build a minimal Program (optionally with CommGroups) via @pl.program + immediate Program(...) reconstruction. Until the CollectCommGroups pass @@ -91,10 +102,11 @@ def test_lift_no_comm_group_returns_none(): def test_lift_const_size_emits_json_safe_manifest(): - """All-devices group with literal sizes lifts to a JSON-safe v2 manifest.""" + """All-devices group with literal byte sizes lifts to a JSON-safe v2 manifest.""" slots = [ - WindowBuffer("data", _const(256), DataType.FP32), - WindowBuffer("signal", _const(2), DataType.INT32), + # 256 FP32 elements = 1024 bytes; 2 INT32 elements = 8 bytes. + WindowBuffer(_ptr("data"), _const(1024)), + WindowBuffer(_ptr("signal"), _const(8)), ] p = _trivial_program([CommGroup([], slots)]) # empty devices = all @@ -110,17 +122,13 @@ def test_lift_const_size_emits_json_safe_manifest(): assert g["slots"] == [ { "name": "data", - "dtype": "float32", - "size": 256, - "bits_per_element": 32, + "nbytes": 1024, "load_from_host": False, "store_to_host": False, }, { "name": "signal", - "dtype": "int32", - "size": 2, - "bits_per_element": 32, + "nbytes": 8, "load_from_host": False, "store_to_host": False, }, @@ -129,7 +137,7 @@ def test_lift_const_size_emits_json_safe_manifest(): def test_lift_explicit_device_list(): """A group with explicit devices serializes the literal list.""" - p = _trivial_program([CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)])]) + p = _trivial_program([CommGroup([0, 1], [WindowBuffer(_ptr("data"), _const(64))])]) manifest = _lift(p) assert manifest is not None @@ -139,7 +147,7 @@ def test_lift_explicit_device_list(): def test_lift_dynamic_size_unsupported_raises(): """Symbolic ``size`` is rejected at compile time so authors get a clear error.""" sym = Var("N", _const(0).type, Span.unknown()) - p = _trivial_program([CommGroup([], [WindowBuffer("signal", sym, DataType.INT32)])]) + p = _trivial_program([CommGroup([], [WindowBuffer(_ptr("signal"), sym)])]) with pytest.raises(RuntimeError, match="dynamic WindowBuffer size is not supported"): _lift(p) @@ -148,8 +156,8 @@ def test_lift_two_comm_groups_raises(): """Multi-group programs are not yet supported by the runner.""" p = _trivial_program( [ - CommGroup([], [WindowBuffer("a", _const(64), DataType.FP32)]), - CommGroup([], [WindowBuffer("b", _const(64), DataType.FP32)]), + CommGroup([], [WindowBuffer(_ptr("a"), _const(64))]), + CommGroup([], [WindowBuffer(_ptr("b"), _const(64))]), ] ) with pytest.raises(RuntimeError, match="at most one CommGroup"): @@ -159,8 +167,8 @@ def test_lift_two_comm_groups_raises(): def test_lift_load_store_to_host_flags_propagate(): """``load_from_host`` / ``store_to_host`` bool flags pass through 1:1.""" slots = [ - WindowBuffer("lut", _const(64), DataType.FP32, load_from_host=True), - WindowBuffer("met", _const(64), DataType.INT32, store_to_host=True), + WindowBuffer(_ptr("lut"), _const(64), load_from_host=True), + WindowBuffer(_ptr("met"), _const(64), store_to_host=True), ] p = _trivial_program([CommGroup([], slots)]) @@ -189,18 +197,14 @@ def _make_manifest(devices: list[int], slots: list[dict]) -> dict: def _slot( name: str, - size: int, - dtype: str = "float32", - bits: int = 32, + nbytes: int, *, load_from_host: bool = False, store_to_host: bool = False, ) -> dict: return { "name": name, - "dtype": dtype, - "size": size, - "bits_per_element": bits, + "nbytes": nbytes, "load_from_host": load_from_host, "store_to_host": store_to_host, } @@ -209,21 +213,21 @@ def _slot( def test_build_full_coverage_via_empty_devices(): """Empty ``devices`` list ⇒ every device in the dc gets a comm config.""" pytest.importorskip("simpler.task_interface") - manifest = _make_manifest([], [_slot("data", 256), _slot("signal", 2, "int32")]) + manifest = _make_manifest([], [_slot("data", 1024), _slot("signal", 8)]) cfgs = _build(manifest, [0, 1]) assert cfgs is not None assert len(cfgs) == 2 assert all(c.comm is not None for c in cfgs) assert cfgs[0].comm.rank == 0 and cfgs[1].comm.rank == 1 assert all(c.comm.nranks == 2 for c in cfgs) - assert all(c.comm.window_size == 256 * 4 + 2 * 4 for c in cfgs) + assert all(c.comm.window_size == 1024 + 8 for c in cfgs) assert [b.name for b in cfgs[0].buffers] == ["data", "signal"] def test_build_explicit_subset_keeps_extras_commless(): """Devices not in the list get bare ``ChipBootstrapConfig()`` (comm=None).""" pytest.importorskip("simpler.task_interface") - manifest = _make_manifest([0, 1], [_slot("data", 64)]) + manifest = _make_manifest([0, 1], [_slot("data", 256)]) cfgs = _build(manifest, [0, 1, 2, 3]) assert cfgs is not None assert len(cfgs) == 4 @@ -234,7 +238,7 @@ def test_build_explicit_subset_keeps_extras_commless(): def test_build_subset_out_of_range_raises(): pytest.importorskip("simpler.task_interface") - manifest = _make_manifest([3, 4], [_slot("data", 1)]) + manifest = _make_manifest([3, 4], [_slot("data", 4)]) with pytest.raises(RuntimeError, match="outside DistributedConfig.device_ids range"): _build(manifest, [0, 1]) @@ -259,15 +263,14 @@ def test_build_rejects_two_groups(): _build(manifest, [0, 1]) -def test_build_subbyte_dtype_byte_calculation(): - """nbytes rounds up for sub-byte dtypes (e.g. INT4).""" +def test_build_carries_nbytes_directly(): + """Manifest now carries ``nbytes`` directly — no dtype-based byte arithmetic.""" pytest.importorskip("simpler.task_interface") - # INT4: 4 bits per element, 7 elements → ceil(7*4/8) = 4 bytes. - manifest = _make_manifest([], [_slot("x", 7, "int4", bits=4)]) + manifest = _make_manifest([], [_slot("x", 7)]) cfgs = _build(manifest, [0, 1]) assert cfgs is not None - assert cfgs[0].buffers[0].nbytes == 4 - assert cfgs[0].comm.window_size == 4 + assert cfgs[0].buffers[0].nbytes == 7 + assert cfgs[0].comm.window_size == 7 def test_build_load_store_host_flags(): @@ -302,7 +305,7 @@ def test_aot_roundtrip_writes_and_loads_manifest(tmp_path): emit_comm_manifest, ) - p = _trivial_program([CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)])]) + p = _trivial_program([CommGroup([], [WindowBuffer(_ptr("data"), _const(64))])]) expected = _lift(p) out_path = emit_comm_manifest(p, tmp_path)