diff --git a/docs/en/dev/ir/02-types.md b/docs/en/dev/ir/02-types.md index 306ba3ed1..ae6710c16 100644 --- a/docs/en/dev/ir/02-types.md +++ b/docs/en/dev/ir/02-types.md @@ -38,6 +38,41 @@ tensor_with_memref = ir.TensorType(shape, DataType.FP32, memref) `TensorType.memory_space` is always `ir.Mem.DDR`. `MemRef` carries address, size, and id; memory space is not stored on `MemRef` itself. +### DistributedTensorType + +`DistributedTensorType` is a precise-`ObjectKind` subclass of `TensorType` +used as the function-signature type for chip orchestrator / InCore parameters +that slice a CommGroup HCCL window buffer. It exists so cross-rank op +verifiers (introduced in later milestones) can reject plain `Tensor` +arguments — `As<TensorType>` does NOT match a `DistributedTensorType` +(precise `ObjectKind` semantics; see +[ir-kind-traits.md](../../../../.claude/rules/ir-kind-traits.md)). Use +`As<DistributedTensorType>` to dispatch on the distributed variant. + +The DSL surface is `pld.DistributedTensor[[shape], dtype]`: + +```python +import pypto.language.distributed as pld +import pypto.language as pl + +@pl.function(type=pl.FunctionType.InCore) +def kernel(self, data: pld.DistributedTensor[[256], pl.FP32]): ... +``` + +At the IR level: + +```python +t = ir.DistributedTensorType([64], DataType.FP32) +assert isinstance(t, ir.TensorType) # C++ inheritance preserved +# As<TensorType>(t) → null; As<DistributedTensorType>(t) → cast. +``` + +Allocation-side metadata (the buffer name, host-staging flags) lives on the +allocation op (`pld.alloc_window_buffer`, added in a later milestone) and on +the `ir.WindowBuffer` slot in `program.comm_groups`, not on the type itself. +Tile types do not have a distributed variant; cross-rank ops always operate +on `DistributedTensor`. + ### TensorType with TensorView Tensor with layout and stride information for optimized memory access.
diff --git a/docs/zh-cn/dev/ir/02-types.md b/docs/zh-cn/dev/ir/02-types.md index e4fedb98b..649ba3acd 100644 --- a/docs/zh-cn/dev/ir/02-types.md +++ b/docs/zh-cn/dev/ir/02-types.md @@ -38,6 +38,39 @@ tensor_with_memref = ir.TensorType(shape, DataType.FP32, memref) `TensorType.memory_space` 始终是 `ir.Mem.DDR`。`MemRef` 只保存地址、大小和 id;内存空间不再存储在 `MemRef` 本身上。 +### DistributedTensorType + +`DistributedTensorType` 是 `TensorType` 的精确 `ObjectKind` 子类,作为 chip +orchestrator / InCore 形参的类型注解,用来切片 CommGroup HCCL window buffer。 +它的存在让跨 rank op 的 verifier(后续 milestone 引入)可以静态拒绝普通的 +`Tensor` 实参 —— `As<TensorType>` **不会**匹配 `DistributedTensorType` +(精确 `ObjectKind` 匹配语义,见 +[ir-kind-traits.md](../../../../.claude/rules/ir-kind-traits.md)),跨 rank op 用 +`As<DistributedTensorType>` 派生。 + +DSL 形式是 `pld.DistributedTensor[[shape], dtype]`: + +```python +import pypto.language.distributed as pld +import pypto.language as pl + +@pl.function(type=pl.FunctionType.InCore) +def kernel(self, data: pld.DistributedTensor[[256], pl.FP32]): ... +``` + +IR 层: + +```python +t = ir.DistributedTensorType([64], DataType.FP32) +assert isinstance(t, ir.TensorType) # C++ 继承关系保留 +# As<TensorType>(t) → null;As<DistributedTensorType>(t) → 转型成功 +``` + +分配侧的元数据(buffer 名字、host staging 标志)挂在 alloc op +(`pld.alloc_window_buffer`,后续 milestone 引入)和 `program.comm_groups` +中的 `ir.WindowBuffer` slot 上,**不**在类型本身。Tile 类型没有 distributed +变体;跨 rank op 始终作用在 `DistributedTensor` 上。 + ### 带 TensorView 的 TensorType 带有布局和步长信息的张量,用于优化内存访问。 diff --git a/include/pypto/ir/core.h b/include/pypto/ir/core.h index 3a446fa53..8d0b70d79 100644 --- a/include/pypto/ir/core.h +++ b/include/pypto/ir/core.h @@ -109,12 +109,15 @@ enum class ObjectKind { ScalarType, ShapedType, TensorType, + DistributedTensorType, TileType, TupleType, // Other IR node kinds Function, Program, + WindowBuffer, + CommGroup, // Op kinds Op, diff --git a/include/pypto/ir/kind_traits.h b/include/pypto/ir/kind_traits.h index ec8fe9f20..935212d4e 100644 --- a/include/pypto/ir/kind_traits.h +++ b/include/pypto/ir/kind_traits.h @@ -104,7
+104,11 @@ DEFINE_KIND_TRAIT(InlineStmt, ObjectKind::InlineStmt) DEFINE_KIND_TRAIT(UnknownType, ObjectKind::UnknownType) DEFINE_KIND_TRAIT(ScalarType, ObjectKind::ScalarType) // ShapedType is both a concrete type and a base class - handled separately below +// TensorType: precise-match (DistributedTensorType is a subclass with its own +// ObjectKind, so As<TensorType>(dt) returns nullptr by design — see +// .claude/rules/ir-kind-traits.md and the L3 distributed plan). DEFINE_KIND_TRAIT(TensorType, ObjectKind::TensorType) +DEFINE_KIND_TRAIT(DistributedTensorType, ObjectKind::DistributedTensorType) DEFINE_KIND_TRAIT(TileType, ObjectKind::TileType) DEFINE_KIND_TRAIT(TupleType, ObjectKind::TupleType) DEFINE_KIND_TRAIT(MemRefType, ObjectKind::MemRefType) @@ -113,6 +117,8 @@ DEFINE_KIND_TRAIT(PtrType, ObjectKind::PtrType) // Other IR node types DEFINE_KIND_TRAIT(Function, ObjectKind::Function) DEFINE_KIND_TRAIT(Program, ObjectKind::Program) +DEFINE_KIND_TRAIT(WindowBuffer, ObjectKind::WindowBuffer) +DEFINE_KIND_TRAIT(CommGroup, ObjectKind::CommGroup) // Op kinds DEFINE_KIND_TRAIT(Op, ObjectKind::Op) @@ -189,19 +195,22 @@ struct KindTrait { // Type base class - matches any type kind template <> struct KindTrait<Type> { - static constexpr ObjectKind kinds[] = {ObjectKind::UnknownType, ObjectKind::ScalarType, - ObjectKind::ShapedType, ObjectKind::TensorType, - ObjectKind::TileType, ObjectKind::TupleType}; + static constexpr ObjectKind kinds[] = {ObjectKind::UnknownType, + ObjectKind::ScalarType, + ObjectKind::ShapedType, + ObjectKind::TensorType, + ObjectKind::DistributedTensorType, + ObjectKind::TileType, + ObjectKind::TupleType}; static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind); }; // ShapedType can be used as both a concrete type and a base class -// It matches itself, TensorType, and TileType +// It matches itself, TensorType, DistributedTensorType, and TileType template <> struct KindTrait<ShapedType> { - // For base class matching: includes ShapedType, TensorType, TileType static
constexpr ObjectKind kinds[] = {ObjectKind::ShapedType, ObjectKind::TensorType, - ObjectKind::TileType}; + ObjectKind::DistributedTensorType, ObjectKind::TileType}; static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind); }; diff --git a/include/pypto/ir/program.h b/include/pypto/ir/program.h index 9e36b9c49..e2bafca68 100644 --- a/include/pypto/ir/program.h +++ b/include/pypto/ir/program.h @@ -19,6 +19,7 @@ #include #include +#include "pypto/core/dtype.h" #include "pypto/ir/core.h" #include "pypto/ir/expr.h" #include "pypto/ir/function.h" @@ -28,6 +29,91 @@ namespace pypto { namespace ir { +/** + * @brief Per-rank allocation spec for one named CommGroup HCCL window buffer. + * + * Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. Pure + * allocation metadata: does NOT describe how the buffer is used in code. + * Code-level use is expressed in the function signature via + * ``pld.DistributedTensor[[shape], dtype]``; the alloc op + * (``pld.alloc_window_buffer``, added in N2) materialises one of these slots + * into the program's :class:`CommGroup`. + * + * ``size_`` is the **element count** of one rank's slice (a single scalar; this + * struct is allocation-only and intentionally does not carry a multi-dim + * shape). ``size_`` may be a ``ConstInt`` (compile-time known) or a symbolic + * expression referring to the world size. + * + * ``load_from_host_`` / ``store_to_host_`` are simple boolean flags marking + * whether the slot participates in pre-fork H2D / post-task D2H staging. The + * specific host tensor that supplies / receives the staged data is recorded + * on the alloc op, not on this allocation spec. 
+ */ +class WindowBuffer : public IRNode { + public: + std::string name_; ///< Buffer name (parser-extracted from alloc-op LHS) + ExprPtr size_; ///< Per-rank element count (ConstInt or symbolic Expr) + DataType dtype_; ///< Element data type + bool load_from_host_ = false; ///< Pre-fork H2D copy from a host staging tensor + bool store_to_host_ = false; ///< Post-task D2H copy back into a host staging tensor + + WindowBuffer(std::string name, ExprPtr size, DataType dtype, bool load_from_host = false, + bool store_to_host = false, Span span = Span::unknown()) + : IRNode(std::move(span)), + name_(std::move(name)), + size_(std::move(size)), + dtype_(dtype), + load_from_host_(load_from_host), + store_to_host_(store_to_host) {} + + [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::WindowBuffer; } + [[nodiscard]] std::string TypeName() const override { return "WindowBuffer"; } + + static constexpr auto GetFieldDescriptors() { + return std::tuple_cat( + IRNode::GetFieldDescriptors(), + std::make_tuple(reflection::UsualField(&WindowBuffer::name_, "name"), + reflection::UsualField(&WindowBuffer::size_, "size"), + reflection::UsualField(&WindowBuffer::dtype_, "dtype"), + reflection::UsualField(&WindowBuffer::load_from_host_, "load_from_host"), + reflection::UsualField(&WindowBuffer::store_to_host_, "store_to_host"))); + } +}; + +using WindowBufferPtr = std::shared_ptr<const WindowBuffer>; + +/** + * @brief A communication group inferred for a ``@pl.program``. + * + * The ``CollectCommGroups`` pass (N4) builds these from + * ``pld.alloc_window_buffer`` ops and their dispatch coverage. The runtime + * (``distributed_runner``) uses this to compose ``ChipBootstrapConfig`` before + * bringing the workers up. + * + * ``devices_`` is the ascending-sorted set of physical device ids covered by + * the group. **An empty vector means "all devices"** (every entry of + * ``DistributedConfig.device_ids``, resolved by the driver at submit-time).
+ */ +class CommGroup : public IRNode { + public: + std::vector<int> devices_; ///< Covered device ids (ascending); empty = all devices + std::vector<WindowBufferPtr> slots_; ///< Allocation slots in this group (alloc-order) + + CommGroup(std::vector<int> devices, std::vector<WindowBufferPtr> slots, Span span = Span::unknown()) + : IRNode(std::move(span)), devices_(std::move(devices)), slots_(std::move(slots)) {} + + [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::CommGroup; } + [[nodiscard]] std::string TypeName() const override { return "CommGroup"; } + + static constexpr auto GetFieldDescriptors() { + return std::tuple_cat(IRNode::GetFieldDescriptors(), + std::make_tuple(reflection::UsualField(&CommGroup::devices_, "devices"), + reflection::UsualField(&CommGroup::slots_, "slots"))); + } +}; + +using CommGroupPtr = std::shared_ptr<const CommGroup>; + /** * @brief Program definition * @@ -52,6 +138,16 @@ class Program : public IRNode { Program(std::map<GlobalVarPtr, FunctionPtr> functions, std::string name, Span span) : IRNode(std::move(span)), functions_(std::move(functions)), name_(std::move(name)) {} + /** + * @brief Map-based ctor with CommGroup metadata (used by the deserializer). + */ + Program(std::map<GlobalVarPtr, FunctionPtr> functions, + std::vector<CommGroupPtr> comm_groups, std::string name, Span span) + : IRNode(std::move(span)), + functions_(std::move(functions)), + name_(std::move(name)), + comm_groups_(std::move(comm_groups)) {} + /** * @brief Create a program from a list of functions * + * @param functions List of functions + * @param name Program name (optional) + * @param span Source location */ @@ -64,6 +160,17 @@ class Program : public IRNode { */ Program(const std::vector<FunctionPtr>& functions, std::string name, Span span); + /** + * @brief Create a program from a list of functions and CommGroup metadata.
+ * + * @param functions List of functions + * @param comm_groups List of CommGroups declared on the program + * @param name Program name (optional) + * @param span Source location + */ + Program(const std::vector& functions, std::vector comm_groups, std::string name, + Span span); + [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Program; } [[nodiscard]] std::string TypeName() const override { return "Program"; } @@ -84,19 +191,23 @@ class Program : public IRNode { [[nodiscard]] GlobalVarPtr GetGlobalVar(const std::string& name) const; /** - * @brief Get field descriptors for reflection-based visitation + * @brief Get field descriptors for reflection-based visitation. * - * @return Tuple of field descriptors (name as IGNORE field, functions as USUAL field) + * ``comm_groups_`` participates in structural equality / hashing via + * ``UsualField``: two programs declaring the same CommGroups (same names, + * sizes, dtypes, host-staging flags) are structurally equivalent. 
*/ static constexpr auto GetFieldDescriptors() { return std::tuple_cat(IRNode::GetFieldDescriptors(), std::make_tuple(reflection::IgnoreField(&Program::name_, "name"), - reflection::UsualField(&Program::functions_, "functions"))); + reflection::UsualField(&Program::functions_, "functions"), + reflection::UsualField(&Program::comm_groups_, "comm_groups"))); } public: std::string name_; // Program name std::map functions_; // Map of GlobalVars to Functions + std::vector comm_groups_; // CommGroups (host-side metadata) }; using ProgramPtr = std::shared_ptr; diff --git a/include/pypto/ir/serialization/type_registry.h b/include/pypto/ir/serialization/type_registry.h index 475e2a778..d7b9e03c5 100644 --- a/include/pypto/ir/serialization/type_registry.h +++ b/include/pypto/ir/serialization/type_registry.h @@ -45,6 +45,14 @@ class DeserializerContext { virtual IRNodePtr DeserializeNode(const msgpack::object& obj, msgpack::zone& zone) = 0; virtual msgpack::object GetFieldObj(const msgpack::object& fields_obj, const std::string& field_name) = 0; + /** + * @brief Whether ``field_name`` is present in ``fields_obj``. + * + * Lets deserializers handle optional / forward-compatible fields without + * tripping ``GetFieldObj``'s missing-field exception. + */ + virtual bool HasField(const msgpack::object& fields_obj, const std::string& field_name) = 0; + template T GetField(const msgpack::object& fields_obj, const std::string& field_name) { msgpack::object field_obj = GetFieldObj(fields_obj, field_name); diff --git a/include/pypto/ir/type.h b/include/pypto/ir/type.h index bb1171bf8..ffdeb71ec 100644 --- a/include/pypto/ir/type.h +++ b/include/pypto/ir/type.h @@ -486,6 +486,50 @@ class TensorType : public ShapedType { using TensorTypePtr = std::shared_ptr; +/** + * @brief Distributed tensor type — a per-rank slice of a CommGroup HCCL window buffer. 
+ * + * Subclass of :class:`TensorType` distinguished only by ``ObjectKind`` so that + * verifiers can reject plain ``TensorType`` arguments to cross-rank ops + * (``pld.tile.remote_load`` / ``pld.system.notify`` / ``pld.system.wait``). + * + * Note ``As<TensorType>`` does NOT match ``DistributedTensorType`` (precise + * ObjectKind match). This is intentional — the cross-rank ops use + * ``As<DistributedTensorType>`` to enforce that only window-bound tensors flow + * through them. + */ +class DistributedTensorType : public TensorType { + public: + DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype) : TensorType(std::move(shape), dtype) {} + + DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, MemRefPtr memref) + : TensorType(std::move(shape), dtype, std::move(memref)) {} + + DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref) + : TensorType(std::move(shape), dtype, std::move(memref)) {} + + DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref, + std::optional<TensorView> tensor_view) + : TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)) {} + + DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype) + : TensorType(shape, dtype, std::nullopt) {} + + DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref) + : TensorType(shape, dtype, std::move(memref)) {} + + DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref, + std::optional<TensorView> tensor_view) + : TensorType(shape, dtype, std::move(memref), std::move(tensor_view)) {} + + [[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::DistributedTensorType; } + [[nodiscard]] std::string TypeName() const override { return "DistributedTensorType"; } + + static constexpr auto GetFieldDescriptors() { return TensorType::GetFieldDescriptors(); } +}; + +using DistributedTensorTypePtr = std::shared_ptr<const DistributedTensorType>; + /** * @brief Tile type representation * diff --git a/python/bindings/modules/ir.cpp
b/python/bindings/modules/ir.cpp index b00531efa..41b69615f 100644 --- a/python/bindings/modules/ir.cpp +++ b/python/bindings/modules/ir.cpp @@ -429,6 +429,33 @@ void BindIR(nb::module_& m) { "Create a tensor type with constant shape, optional memory reference and tensor view"); BindFields(tensor_type_class); + // DistributedTensorType - subclass of TensorType used as the param type for + // cross-rank ops (pld.tile.remote_load / pld.system.notify / pld.system.wait). + // Distinguished from TensorType only by ObjectKind so verifiers can reject + // plain Tensors. Otherwise mirrors TensorType's ctor overloads — memref / + // tensor_view variants are equally valid on the distributed flavour. + auto dist_tensor_type_class = nb::class_( + ir, "DistributedTensorType", "Tensor backed by a per-rank slice of a CommGroup HCCL window buffer"); + dist_tensor_type_class.def(nb::init&, DataType>(), nb::arg("shape"), + nb::arg("dtype"), "Create a distributed tensor type"); + dist_tensor_type_class.def(nb::init&, DataType>(), nb::arg("shape"), + nb::arg("dtype"), "Create a distributed tensor type with constant shape"); + dist_tensor_type_class.def(nb::init&, DataType, std::optional>(), + nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), + "Create a distributed tensor type with optional memref"); + dist_tensor_type_class.def(nb::init&, DataType, std::optional>(), + nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), + "Create a distributed tensor type with constant shape and optional memref"); + dist_tensor_type_class.def( + nb::init&, DataType, std::optional, std::optional>(), + nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), nb::arg("tensor_view") = nb::none(), + "Create a distributed tensor type with optional memref and tensor_view"); + dist_tensor_type_class.def( + nb::init&, DataType, std::optional, std::optional>(), + nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), nb::arg("tensor_view") = nb::none(), 
+ "Create a distributed tensor type with constant shape, optional memref and tensor_view"); + BindFields(dist_tensor_type_class); + // TileType - const shared_ptr auto tile_type_class = nb::class_(ir, "TileType", "Tile type representation (multi-dimensional tensor)"); @@ -1413,6 +1440,11 @@ void BindIR(nb::module_& m) { nb::arg("functions"), nb::arg("name"), nb::arg("span"), "Create a program from a list of functions. " "GlobalVar references are created automatically from function names."); + program_class.def( + nb::init&, std::vector, const std::string&, const Span&>(), + nb::arg("functions"), nb::arg("comm_groups"), nb::arg("name"), nb::arg("span"), + "Create a program from a list of functions and CommGroup metadata. " + "GlobalVar references are created automatically from function names."); program_class.def("get_function", &Program::GetFunction, nb::arg("name"), "Get a function by name, returns None if not found"); program_class.def("get_global_var", &Program::GetGlobalVar, nb::arg("name"), @@ -1436,6 +1468,33 @@ void BindIR(nb::module_& m) { "Map of GlobalVar references to their corresponding functions, sorted by GlobalVar name"); program_class.def_ro("name", &Program::name_, "Program name"); program_class.def_ro("span", &Program::span_, "Source location"); + program_class.def_ro("comm_groups", &Program::comm_groups_, + "List of CommGroups declared on the program. CommGroups participate " + "in structural equality / hashing through reflection."); + + // CommGroup / WindowBuffer — IRNode-typed host-side metadata attached to a Program. + auto window_buffer_class = nb::class_( + ir, "WindowBuffer", + "Per-rank allocation spec for a named CommGroup HCCL window buffer. " + "Maps 1:1 to simpler.task_interface.ChipBufferSpec at submit-time. 
" + "size_ is the per-rank element count; load_from_host_ / store_to_host_ " + "are bool flags marking host-staging participation (the actual host " + "tensor binding is recorded on the alloc op, not here)."); + window_buffer_class.def(nb::init(), nb::arg("name"), + nb::arg("size"), nb::arg("dtype"), nb::arg("load_from_host") = false, + nb::arg("store_to_host") = false, nb::arg("span") = Span::unknown(), + "Create a WindowBuffer."); + BindFields(window_buffer_class); + + auto comm_group_class = + nb::class_(ir, "CommGroup", + "A communication group inferred from pld.alloc_window_buffer ops: a " + "device-id list (empty = all devices) and a list of WindowBuffer slots " + "shared by every rank in the group."); + comm_group_class.def(nb::init, std::vector, Span>(), + nb::arg("devices"), nb::arg("slots"), nb::arg("span") = Span::unknown(), + "Create a CommGroup. ``devices`` empty = all devices."); + BindFields(comm_group_class); // Python-style printer function - unified API for IRNode ir.def( diff --git a/python/pypto/ir/comm_manifest.py b/python/pypto/ir/comm_manifest.py new file mode 100644 index 000000000..a5652fb85 --- /dev/null +++ b/python/pypto/ir/comm_manifest.py @@ -0,0 +1,162 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""Compile-time CommGroup manifest support (v2 schema). 
+ +Lifts ``program.comm_groups`` into a JSON-safe dict suitable for serialisation +under ``output_dir/orchestration/``. The runtime +re-enters the output directory and reads the file via +``pypto.runtime.distributed_runner._build_chip_bootstrap_configs_from_manifest`` +without needing the live ``Program`` object. + +Two halves of the AOT pipeline: + + compile-time: Program → ``lift_comm_manifest`` → dict (JSON-safe) + on disk: output_dir/orchestration/ + runtime: dict → _build_chip_bootstrap_configs_from_manifest(...) + → list[ChipBootstrapConfig] + +v2 schema (CommGroups are pass-inferred, not user-declared): + +* ``devices``: list of physical device ids covered by the group; **empty list + = all devices** (resolved by the driver against + ``DistributedConfig.device_ids``). +* ``slots``: list of allocation specs, each a 1:1 image of + ``simpler.task_interface.ChipBufferSpec``. ``load_from_host`` / + ``store_to_host`` are bool flags here — the specific host tensor binding + lives on the alloc op, not on this allocation spec. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pypto.pypto_core import DataType +from pypto.pypto_core.ir import ConstInt, Program + +# Manifest filename under output_dir/orchestration/. Keep in sync with the +# runtime loader in pypto.runtime.distributed_runner. +COMM_MANIFEST_FILENAME = "comm_manifest.json" +COMM_MANIFEST_VERSION = 2 + + +# Maps PyPTO DataType to the dtype-string convention simpler's ChipBufferSpec +# expects (numpy/torch style, e.g. "float32"). DataType.to_string() returns +# "fp32" / "bfloat16" which simpler does not understand; this table is the +# explicit single source of truth. 
+_SIMPLER_DTYPE_STR: dict[DataType, str] = { + DataType.FP32: "float32", + DataType.FP16: "float16", + DataType.BF16: "bfloat16", + DataType.INT8: "int8", + DataType.INT16: "int16", + DataType.INT32: "int32", + DataType.INT64: "int64", + DataType.UINT8: "uint8", +} + + +def _simpler_dtype_str(dtype: DataType) -> str: + try: + return _SIMPLER_DTYPE_STR[dtype] + except KeyError as exc: + raise RuntimeError( + f"Unsupported WindowBuffer dtype {dtype!r} for ChipBufferSpec; " + f"add an entry to _SIMPLER_DTYPE_STR. " + f"Known dtypes: {sorted(d.to_string() for d in _SIMPLER_DTYPE_STR)}" + ) from exc + + +def lift_comm_manifest(program: Program) -> dict[str, Any] | None: + """Lift ``program.comm_groups`` into a JSON-safe dict for AOT serialization. + + Returns ``None`` when the program declares no CommGroup — callers should + skip emitting / loading the manifest entirely (preserving the comm-less + path used by multi_chip_dispatch and parallel_reduce). + + Current v2 supports a single CommGroup with literal ``int`` per-slot + ``size``. Symbolic sizes raise ``RuntimeError`` at compile time — that's + a better failure point than runtime, since the program author can fix it + without redeploying. 
+ """ + comm_groups = list(program.comm_groups) + if not comm_groups: + return None + if len(comm_groups) > 1: + raise RuntimeError( + f"distributed_runner currently supports at most one CommGroup per program, got {len(comm_groups)}" + ) + + group = comm_groups[0] + + slots_data: list[dict[str, Any]] = [] + for slot in group.slots: + size_const = slot.size if isinstance(slot.size, ConstInt) else None + if size_const is None: + raise RuntimeError( + f"dynamic WindowBuffer size is not supported yet " + f"(slot {slot.name!r}); declare size as a literal int" + ) + slots_data.append( + { + "name": slot.name, + "dtype": _simpler_dtype_str(slot.dtype), + "size": int(size_const.value), + "bits_per_element": slot.dtype.get_bit(), + "load_from_host": bool(slot.load_from_host), + "store_to_host": bool(slot.store_to_host), + } + ) + + return { + "version": COMM_MANIFEST_VERSION, + "comm_groups": [ + { + # Empty list = all devices (resolved by the driver against + # DistributedConfig.device_ids). + "devices": [int(d) for d in group.devices], + "slots": slots_data, + } + ], + } + + +def emit_comm_manifest(program: Program, output_dir: Path | str) -> Path | None: + """Lift ``program.comm_groups`` and write the JSON manifest to disk. + + Writes ``output_dir/orchestration/``. Returns the + written path, or ``None`` when the program declares no CommGroup (no file + is created — comm-less programs are unaffected). + + This is the compile-time emission point. The runner re-enters + ``output_dir`` and reads the file via + ``pypto.runtime.distributed_runner._build_chip_bootstrap_configs_from_manifest``, + so a ``CompiledProgram`` instance can be reconstructed from the output + directory alone — no live ``Program`` object required. 
+ """ + manifest = lift_comm_manifest(program) + if manifest is None: + return None + orch_dir = Path(output_dir) / "orchestration" + orch_dir.mkdir(parents=True, exist_ok=True) + manifest_path = orch_dir / COMM_MANIFEST_FILENAME + with manifest_path.open("w", encoding="utf-8") as fh: + json.dump(manifest, fh, indent=2, sort_keys=True) + fh.write("\n") + return manifest_path + + +__all__ = [ + "COMM_MANIFEST_FILENAME", + "COMM_MANIFEST_VERSION", + "lift_comm_manifest", + "emit_comm_manifest", +] diff --git a/python/pypto/ir/compile.py b/python/pypto/ir/compile.py index 53d699708..e0b3c6736 100644 --- a/python/pypto/ir/compile.py +++ b/python/pypto/ir/compile.py @@ -174,6 +174,13 @@ def _stage(name: str) -> AbstractContextManager[Any]: _write_files(exc.files, output_dir) raise _write_files(files, output_dir) + + # Emit the comm-group manifest alongside generated host_orch.py so the + # runner can re-enter the output directory without holding the live + # Program. No-op when the program declares no CommGroup. + from .comm_manifest import emit_comm_manifest # noqa: PLC0415 + + emit_comm_manifest(transformed_program, output_dir) finally: if owns_profiler and prof is not None: prof.__exit__(None, None, None) diff --git a/python/pypto/language/distributed/__init__.py b/python/pypto/language/distributed/__init__.py new file mode 100644 index 000000000..173f1f400 --- /dev/null +++ b/python/pypto/language/distributed/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""PyPTO distributed DSL — namespace ``pypto.language.distributed`` (alias ``pld``). + +Provides cross-rank concepts that complement the single-device DSL in +``pypto.language``. Communication-domain metadata (``ir.CommGroup`` / +``ir.WindowBuffer``) is **inferred** by the ``CollectCommGroups`` pass from +``pld.alloc_window_buffer`` calls in the host orchestrator and the +``device=`` kwarg on dispatch sites; users do not declare ``CommGroup`` +manually. + +This module is currently a namespace placeholder; concrete entry points +(``pld.DistributedTensor``, ``pld.alloc_window_buffer``, ``pld.world_size``, +``pld.tile.*``, ``pld.system.*``) are added in subsequent milestones (N1.2+). +""" + +from .distributed_tensor import DistributedTensor + +__all__ = ["DistributedTensor"] diff --git a/python/pypto/language/distributed/distributed_tensor.py b/python/pypto/language/distributed/distributed_tensor.py new file mode 100644 index 000000000..4493c2820 --- /dev/null +++ b/python/pypto/language/distributed/distributed_tensor.py @@ -0,0 +1,60 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- + +"""``pld.DistributedTensor`` — DSL annotation for window-bound tensors. + +Function-signature type annotation for chip_orch / InCore parameters that +slice a CommGroup HCCL window buffer. Behaves identically to :class:`pl.Tensor` +at the DSL surface (same ``[shape, dtype, layout|memref|view]`` subscript +forms); the only difference is the IR-level ``ObjectKind`` +(:class:`ir.DistributedTensorType`), which lets cross-rank op verifiers +(``pld.tile.remote_load`` / ``pld.system.notify`` / ``pld.system.wait``, added +in later milestones) reject plain ``Tensor`` arguments. + +Use:: + + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + data: pld.DistributedTensor[[256], pl.FP32], + bias: pld.DistributedTensor[[256], pl.FP32, pl.NZ], + ): ... +""" + +from collections.abc import Sequence +from typing import Any + +from pypto.language.typing.tensor import Tensor, TensorMeta + + +class DistributedTensorMeta(TensorMeta): + """Metaclass enabling ``pld.DistributedTensor[...]`` syntax. + + Inherits :class:`TensorMeta`'s subscript dispatch unchanged — a + ``DistributedTensor`` is a plain Tensor with a different IR ObjectKind. + """ + + +class DistributedTensor(Tensor, metaclass=DistributedTensorMeta): + """Tensor backed by a per-rank slice of a CommGroup HCCL window buffer. + + Same DSL surface as :class:`pl.Tensor` — supports memref / layout / + tensor_view in the third (and fourth) subscript slot. The IR-level type + (:class:`ir.DistributedTensorType`) is a precise-``ObjectKind`` subclass of + ``TensorType``: cross-rank op verifiers dispatch on it to reject plain + ``Tensor`` arguments. ``pl.load`` / ``pl.store`` etc. operate + transparently on a DistributedTensor (local rank's slice). 
+ """ + + @classmethod + def __class_getitem__(cls, item: tuple[Sequence[Any], Any]) -> "DistributedTensor": + return type(cls).__getitem__(cls, item) + + +__all__ = ["DistributedTensor"] diff --git a/python/pypto/language/parser/type_resolver.py b/python/pypto/language/parser/type_resolver.py index 081c811c3..05a7de046 100644 --- a/python/pypto/language/parser/type_resolver.py +++ b/python/pypto/language/parser/type_resolver.py @@ -260,9 +260,12 @@ def _get_direction_wrapper(self, node: ast.expr) -> str | None: return None def _get_type_name(self, node: ast.expr) -> str | None: - """Extract the type name from an AST node referencing Tensor, Tile, or Scalar. + """Extract the type name from an AST node referencing Tensor, Tile, Scalar, + Tuple, or DistributedTensor. - Handles both ``pl.Tensor`` (ast.Attribute) and bare ``Tensor`` (ast.Name). + Handles both ``pl.Tensor`` (ast.Attribute) and bare ``Tensor`` (ast.Name); + ``pld.DistributedTensor`` is recognized either as ``pld.DistributedTensor`` + or bare ``DistributedTensor``. Args: node: AST expression to check @@ -270,9 +273,10 @@ def _get_type_name(self, node: ast.expr) -> str | None: Returns: Type name string if recognized, None otherwise """ - if isinstance(node, ast.Attribute) and node.attr in ("Tensor", "Tile", "Scalar", "Tuple"): + valid = ("Tensor", "Tile", "Scalar", "Tuple", "DistributedTensor") + if isinstance(node, ast.Attribute) and node.attr in valid: return node.attr - if isinstance(node, ast.Name) and node.id in ("Tensor", "Tile", "Scalar", "Tuple"): + if isinstance(node, ast.Name) and node.id in valid: return node.id return None @@ -356,18 +360,22 @@ def _resolve_subscript_type(self, subscript_node: ast.Subscript) -> ir.Type: # return self._resolve_tuple_subscript_type(subscript_node) # Tensor: [shape, dtype], [shape, dtype, layout_or_memref], [shape, dtype, layout, memref]. + # DistributedTensor: same forms as Tensor — only the IR ObjectKind differs. 
# Tile: [shape, dtype] plus any ordering of TileView/MemRef/MemorySpace, # with the constraint that MemRef requires explicit MemorySpace. - valid_counts = (2, 3, 4) if type_name == "Tensor" else (2, 3, 4, 5) + is_distributed = type_name == "DistributedTensor" + is_tensor_like = type_name == "Tensor" or is_distributed + tensor_ctor = ir.DistributedTensorType if is_distributed else ir.TensorType + valid_counts = (2, 3, 4) if is_tensor_like else (2, 3, 4, 5) if not isinstance(slice_value, ast.Tuple) or len(slice_value.elts) not in valid_counts: - if type_name == "Tensor": + if is_tensor_like: message = ( f"{type_name} subscript requires [shape, dtype], [shape, dtype, layout_or_memref], " f"or [shape, dtype, layout, memref], got: {ast.unparse(slice_value)}" ) hint = ( - "Use pl.Tensor[[shape], dtype], pl.Tensor[[shape], dtype, layout], " - "or pl.Tensor[[shape], dtype, pl.MemRef(...)] format" + f"Use {type_name}[[shape], dtype], {type_name}[[shape], dtype, layout], " + f"or {type_name}[[shape], dtype, pl.MemRef(...)] format" ) else: message = ( @@ -397,25 +405,25 @@ def _resolve_subscript_type(self, subscript_node: ast.Subscript) -> ir.Type: # if n_elts == 2: if type_name == "Tile": return ir.TileType(shape, dtype) - return ir.TensorType(shape, dtype, None, None) + return tensor_ctor(shape, dtype, None, None) if type_name == "Tile": return self._resolve_tile_annotation_args(shape, dtype, list(slice_value.elts[2:])) - # 3 args: [shape, dtype, layout_or_memref] for Tensor + # 3 args: [shape, dtype, layout_or_memref_or_tensorview] for Tensor / DistributedTensor if n_elts == 3: third = slice_value.elts[2] if self._is_memref_node(third): memref = self.resolve_memref(third) - return ir.TensorType(shape, dtype, memref, None) + return tensor_ctor(shape, dtype, memref, None) if self._is_tensorview_node(third): tensor_view = self._resolve_tensorview(third) - return ir.TensorType(shape, dtype, None, tensor_view) + return tensor_ctor(shape, dtype, None, tensor_view) layout = 
self.resolve_layout(third) tensor_view = ir.TensorView([], layout) - return ir.TensorType(shape, dtype, None, tensor_view) + return tensor_ctor(shape, dtype, None, tensor_view) - # Tensor 4 args: [shape, dtype, layout_or_tensorview, memref] + # Tensor / DistributedTensor 4 args: [shape, dtype, layout_or_tensorview, memref] third = slice_value.elts[2] if self._is_tensorview_node(third): tensor_view = self._resolve_tensorview(third) @@ -425,12 +433,12 @@ def _resolve_subscript_type(self, subscript_node: ast.Subscript) -> ir.Type: # memref_node = slice_value.elts[3] if not self._is_memref_node(memref_node): raise ParserTypeError( - "Tensor 4th argument must be pl.MemRef(...)", + f"{type_name} 4th argument must be pl.MemRef(...)", span=self._get_span(memref_node), - hint="Use pl.Tensor[[shape], dtype, layout, pl.MemRef(...)]", + hint=f"Use {type_name}[[shape], dtype, layout, pl.MemRef(...)]", ) memref = self.resolve_memref(memref_node) - return ir.TensorType(shape, dtype, memref, tensor_view) + return tensor_ctor(shape, dtype, memref, tensor_view) def _resolve_tile_annotation_args( self, shape: "list[int] | list[ir.Expr]", dtype: DataType, extra_nodes: list[ast.expr] diff --git a/python/pypto/pypto_core/ir.pyi b/python/pypto/pypto_core/ir.pyi index aca99c992..ddfb70eb1 100644 --- a/python/pypto/pypto_core/ir.pyi +++ b/python/pypto/pypto_core/ir.pyi @@ -532,6 +532,55 @@ class TensorType(ShapedType): tensor_view: Optional tensor view information """ +class DistributedTensorType(TensorType): + """Tensor backed by a per-rank slice of a CommGroup HCCL window buffer. + + Subclass of :class:`TensorType` distinguished only by ``ObjectKind`` so + that verifiers for cross-rank ops can reject plain :class:`TensorType` + arguments. Note ``As`` does NOT match + ``DistributedTensorType`` (precise ObjectKind match) — use + ``As`` to dispatch on the distributed variant. 
+ + Mirrors :class:`TensorType`'s constructor surface — memref / tensor_view + are equally valid on the distributed flavour. + """ + + @overload + def __init__(self, shape: Sequence[Expr], dtype: DataType) -> None: + """Create a distributed tensor type.""" + + @overload + def __init__(self, shape: Sequence[Expr], dtype: DataType, memref: MemRef | None) -> None: + """Create a distributed tensor type with optional memref.""" + + @overload + def __init__( + self, + shape: Sequence[Expr], + dtype: DataType, + memref: MemRef | None, + tensor_view: TensorView | None, + ) -> None: + """Create a distributed tensor type with optional memref and tensor_view.""" + + @overload + def __init__(self, shape: Sequence[int], dtype: DataType) -> None: + """Create a distributed tensor type with constant shape.""" + + @overload + def __init__(self, shape: Sequence[int], dtype: DataType, memref: MemRef | None) -> None: + """Create a distributed tensor type with constant shape and optional memref.""" + + @overload + def __init__( + self, + shape: Sequence[int], + dtype: DataType, + memref: MemRef | None, + tensor_view: TensorView | None, + ) -> None: + """Create a distributed tensor type with constant shape, optional memref and tensor_view.""" + class TileView: """Tile view: read-only representation of valid shape, stride, start offset, layouts, fractal, and pad. Construct with all values; fields cannot be mutated @@ -2265,6 +2314,11 @@ class Program(IRNode): functions: Final[dict[GlobalVar, Function]] """Map of GlobalVar references to their corresponding functions, sorted by GlobalVar name.""" + comm_groups: Final[list[CommGroup]] + """CommGroups declared on the program. 
Participates in structural + equality / hashing through reflection.""" + + @overload def __init__( self, functions: list[Function], @@ -2281,6 +2335,16 @@ class Program(IRNode): span: Source location """ + @overload + def __init__( + self, + functions: list[Function], + comm_groups: list[CommGroup], + name: str, + span: Span, + ) -> None: + """Create a program from a list of functions and CommGroup metadata.""" + def get_function(self, name: str) -> Function | None: """Get a function by name. @@ -2328,6 +2392,67 @@ class Program(IRNode): Program with type information """ +class WindowBuffer(IRNode): + """Per-rank allocation spec for a named CommGroup HCCL window buffer. + + Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. + Participates in structural equality / hashing via reflection. + """ + + name: Final[str] + """Buffer name (parser-extracted from the alloc-op LHS, globally unique).""" + + size: Final[Expr] + """Per-rank element count (ConstInt or symbolic Expr). Allocation-only — + does not carry a multi-dim shape.""" + + dtype: Final[DataType] + """Element data type.""" + + load_from_host: Final[bool] + """``True`` if this slot participates in pre-fork H2D staging. The specific + host tensor that supplies the staged data is recorded on the alloc op, not + on this allocation spec.""" + + store_to_host: Final[bool] + """``True`` if this slot participates in post-task D2H staging.""" + + def __init__( + self, + name: str, + size: Expr, + dtype: DataType, + load_from_host: bool = False, + store_to_host: bool = False, + span: Span = ..., + ) -> None: + """Create a WindowBuffer.""" + +class CommGroup(IRNode): + """A communication group inferred for a ``@pl.program``. + + The ``CollectCommGroups`` pass (added in N4) populates + ``program.comm_groups`` from ``pld.alloc_window_buffer`` ops and their + dispatch coverage. Participates in structural equality / hashing via + reflection. 
+ """ + + devices: Final[list[int]] + """Ascending-sorted physical device-id list. **Empty list means "all + devices"** (every entry of ``DistributedConfig.device_ids``, resolved by + the driver at submit-time).""" + + slots: Final[list[WindowBuffer]] + """Allocation slots shared by every rank in the group (alloc-order).""" + + def __init__( + self, + devices: list[int], + slots: list[WindowBuffer], + span: Span = ..., + ) -> None: + """Create a CommGroup. Pass an empty ``devices`` list for "all devices".""" + @overload def structural_hash(node: IRNode, enable_auto_mapping: bool = False) -> int: ... @overload diff --git a/python/pypto/runtime/distributed_runner.py b/python/pypto/runtime/distributed_runner.py index b9093d1da..547a11de9 100644 --- a/python/pypto/runtime/distributed_runner.py +++ b/python/pypto/runtime/distributed_runner.py @@ -13,15 +13,20 @@ import ctypes import importlib.util +import json +import os import sys +import tempfile from pathlib import Path from typing import TYPE_CHECKING, Any import numpy as np # pyright: ignore[reportMissingImports] import torch +from pypto.ir.comm_manifest import COMM_MANIFEST_FILENAME, COMM_MANIFEST_VERSION + if TYPE_CHECKING: - from pypto.ir.distributed_compiled_program import DistributedCompiledProgram + from pypto.ir.distributed_compiled_program import DistributedCompiledProgram, DistributedConfig # --------------------------------------------------------------------------- @@ -96,6 +101,107 @@ def _load_generated_module(path: Path) -> Any: return module +# --------------------------------------------------------------------------- +# Manifest → ChipBootstrapConfig translation (runtime side). +# +# The compile-time half (Program → manifest dict, written to +# ``output_dir/orchestration/``) lives in +# ``pypto.ir.comm_manifest``. The runtime only consumes the file. 
+# --------------------------------------------------------------------------- + + +def _build_chip_bootstrap_configs_from_manifest( + manifest: dict[str, Any] | None, + dc: DistributedConfig, + rootinfo_path: str, +) -> list[Any] | None: + """Materialise a manifest dict into a list of ``ChipBootstrapConfig``. + + Returns ``None`` for ``manifest is None`` (no CommGroup). + + The CommGroup's ``devices`` is a list of physical device ids — **an empty + list means "all devices"** (every entry of ``dc.device_ids``). Covered + ranks receive comm + buffers; the rest stay comm-less so simpler's + per-device cfg list stays 1:1 with ``device_ids``. + """ + if manifest is None: + return None + + from simpler.task_interface import ( # noqa: PLC0415 # pyright: ignore[reportMissingImports] + ChipBootstrapConfig, + ChipBufferSpec, + ChipCommBootstrapConfig, + ) + + version = manifest.get("version") + if version != COMM_MANIFEST_VERSION: + raise RuntimeError( + f"comm manifest version mismatch: file is {version!r}, runtime expects {COMM_MANIFEST_VERSION}" + ) + + comm_groups = manifest["comm_groups"] + if len(comm_groups) != 1: + raise RuntimeError( + f"comm manifest must declare exactly one CommGroup (got {len(comm_groups)}); " + "if you need multiple groups, file a follow-up — currently unsupported." + ) + group = comm_groups[0] + + n_devices = len(dc.device_ids) + devices_field = list(group.get("devices", [])) + if not devices_field: + # Empty list = all devices. + covered_ranks = list(range(n_devices)) + else: + covered_ranks = sorted(int(d) for d in devices_field) + if any(d < 0 or d >= n_devices for d in covered_ranks): + raise RuntimeError( + f"CommGroup devices {covered_ranks!r} contains entries outside " + f"DistributedConfig.device_ids range [0, {n_devices})" + ) + nranks_value = len(covered_ranks) + + specs: list[Any] = [] + for slot in group["slots"]: + count = int(slot["size"]) + # Round up to whole bytes — works uniformly for FP4/INT4 sub-byte dtypes. 
+ nbytes = (count * int(slot["bits_per_element"]) + 7) // 8 + specs.append( + ChipBufferSpec( + name=slot["name"], + dtype=slot["dtype"], + count=count, + nbytes=nbytes, + load_from_host=bool(slot.get("load_from_host", False)), + store_to_host=bool(slot.get("store_to_host", False)), + ) + ) + + window_size = sum(spec.nbytes for spec in specs) + + covered_set = set(covered_ranks) + cfgs: list[Any] = [] + rank_in_group = 0 + for r in range(n_devices): + if r in covered_set: + cfgs.append( + ChipBootstrapConfig( + comm=ChipCommBootstrapConfig( + rank=rank_in_group, + nranks=nranks_value, + rootinfo_path=rootinfo_path, + window_size=window_size, + ), + buffers=specs, + ) + ) + rank_in_group += 1 + else: + # Devices outside the group: skip HCCL bring-up entirely. + cfgs.append(ChipBootstrapConfig()) + return cfgs + + def execute_distributed( # noqa: PLR0912 compiled: DistributedCompiledProgram, coerced_args: list[torch.Tensor], @@ -186,7 +292,30 @@ def execute_distributed( # noqa: PLR0912 if fn is not None: sub_worker_fns[fn_name] = fn - # 5. Create and configure Worker + # 5. Build per-chip bootstrap configs from the AOT comm manifest emitted at + # compile time (output_dir/orchestration/comm_manifest.json). Comm-less + # programs (no CommGroup declared) skip the manifest entirely and the + # runner stays on the existing comm-less Worker path. + manifest_path = output_dir / "orchestration" / COMM_MANIFEST_FILENAME + manifest: dict[str, Any] | None = None + try: + with manifest_path.open("r", encoding="utf-8") as fh: + manifest = json.load(fh) + except FileNotFoundError: + pass + + chip_bootstrap_configs: list[Any] | None = None + rootinfo_path: str | None = None + if manifest is not None: + # ``mkstemp`` returns a unique, unpredictable path and atomically creates + # the file (mode 0o600) — safer than a PID-derived name, which is racy + # under PID recycling. 
The fd is closed immediately because HCCL only + # needs the path: comm bring-up overwrites the file on rank-0 and + # reads it on the other ranks. + rootinfo_fd, rootinfo_path = tempfile.mkstemp(prefix="pypto_distributed_rootinfo_", suffix=".bin") + os.close(rootinfo_fd) + chip_bootstrap_configs = _build_chip_bootstrap_configs_from_manifest(manifest, dc, rootinfo_path) + num_sub = max(dc.num_sub_workers, len(sub_worker_fns)) w = Worker( level=3, @@ -194,6 +323,7 @@ def execute_distributed( # noqa: PLR0912 num_sub_workers=num_sub, platform=compiled.platform, runtime=runtime_name, + chip_bootstrap_configs=chip_bootstrap_configs, ) # 6. Register SubWorker callables @@ -206,6 +336,9 @@ def execute_distributed( # noqa: PLR0912 # 7. Build the orchestration closure and execute _keep: list[Any] = [] + # Codegen always emits ``contexts`` in the entry signature; for comm-less + # programs ``w.chip_contexts`` is an empty list and the body simply ignores + # it, so the runner uses a single uniform call shape. 
def orch_fn(orch, _unused_args, _unused_cfg): entry_fn( orch, @@ -215,6 +348,7 @@ def orch_fn(orch, _unused_args, _unused_cfg): callables=chip_callables, sub_ids=sub_ids, _keep=_keep, + contexts=w.chip_contexts, ) call_config = CallConfig() @@ -225,3 +359,8 @@ def orch_fn(orch, _unused_args, _unused_cfg): w.run(orch_fn) finally: w.close() + if rootinfo_path is not None: + try: + os.unlink(rootinfo_path) + except FileNotFoundError: + pass diff --git a/src/codegen/distributed/distributed_codegen.cpp b/src/codegen/distributed/distributed_codegen.cpp index 63802fac3..f9275d235 100644 --- a/src/codegen/distributed/distributed_codegen.cpp +++ b/src/codegen/distributed/distributed_codegen.cpp @@ -172,15 +172,18 @@ void DistributedCodegen::EmitFunction(const ir::FunctionPtr& func) { is_worker_context_ = is_sub_worker; // Build function signature - // Orchestrators: def func(orch, _args, config, *, tensors, callables, sub_ids, _keep): + // Orchestrators: def func(orch, _args, config, *, tensors, callables, sub_ids, _keep, contexts): // SubWorkers are not emitted as Python functions (they run on device or as registered callables) if (is_sub_worker) { is_worker_context_ = false; return; } + // ``contexts`` is always present in the signature; for comm-less programs it + // is an empty list (and goes unused inside the body), which lets the runner + // dispatch with a single uniform call shape. 
std::ostringstream sig; - sig << "def " << func->name_ << "(orch, _args, config, *, tensors, callables, sub_ids, _keep):"; + sig << "def " << func->name_ << "(orch, _args, config, *, tensors, callables, sub_ids, _keep, contexts):"; emitter_.EmitLine(sig.str()); emitter_.IncreaseIndent(); @@ -207,7 +210,7 @@ void DistributedCodegen::EmitEntryFunction() { current_func_ = entry_func_; // Entry function signature - emitter_.EmitLine("def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep):"); + emitter_.EmitLine("def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep, contexts):"); emitter_.IncreaseIndent(); // Register parameter names @@ -396,9 +399,10 @@ void DistributedCodegen::VisitExpr_(const ir::CallPtr& op) { } if (callee->role_.has_value() && *callee->role_ == ir::Role::Orchestrator) { // Orchestrator-to-orchestrator calls: emit as direct function call - current_expr_value_ = callee->name_ + - "(orch, _args, config, " - "tensors=tensors, callables=callables, sub_ids=sub_ids, _keep=_keep)"; + current_expr_value_ = + callee->name_ + + "(orch, _args, config, " + "tensors=tensors, callables=callables, sub_ids=sub_ids, _keep=_keep, contexts=contexts)"; return; } // Chip-level function (Orchestration/InCore with no role) called from HOST orchestrator diff --git a/src/ir/program.cpp b/src/ir/program.cpp index 0bcd443d0..2438f1495 100644 --- a/src/ir/program.cpp +++ b/src/ir/program.cpp @@ -44,6 +44,13 @@ Program::Program(const std::vector& functions, std::string name, Sp } } +// Vector-based constructor with CommGroup metadata. 
+Program::Program(const std::vector& functions, std::vector comm_groups, + std::string name, Span span) + : Program(functions, std::move(name), std::move(span)) { + comm_groups_ = std::move(comm_groups); +} + FunctionPtr Program::GetFunction(const std::string& name) const { auto it = functions_.find(std::make_shared(name)); if (it != functions_.end()) { diff --git a/src/ir/serialization/deserializer.cpp b/src/ir/serialization/deserializer.cpp index 5c8617ac4..ce31dea46 100644 --- a/src/ir/serialization/deserializer.cpp +++ b/src/ir/serialization/deserializer.cpp @@ -380,6 +380,10 @@ class IRDeserializer::Impl : public detail::DeserializerContext { } return std::make_shared(shape, DataType(dtype_code), memref, tensor_view); + } else if (type_kind == "DistributedTensorType") { + // DistributedTensorType currently carries only shape + dtype (no memref / + // tensor_view at the surface; alloc-side metadata lives on the alloc op). + return std::make_shared(shape, DataType(dtype_code)); } else if (type_kind == "TileType") { std::optional memref; std::optional tile_view; @@ -448,6 +452,22 @@ class IRDeserializer::Impl : public detail::DeserializerContext { throw RuntimeError("Missing required field: " + field_name); } + bool HasField(const msgpack::object& fields_obj, const std::string& field_name) override { + if (fields_obj.type != msgpack::type::MAP) { + return false; + } + msgpack::object_kv* p = fields_obj.via.map.ptr; + msgpack::object_kv* const pend = fields_obj.via.map.ptr + fields_obj.via.map.size; + for (; p < pend; ++p) { + std::string key; + p->key.convert(key); + if (key == field_name) { + return true; + } + } + return false; + } + private: std::unordered_map id_to_ptr_; }; diff --git a/src/ir/serialization/serializer.cpp b/src/ir/serialization/serializer.cpp index c89c50e7d..34ef37f60 100644 --- a/src/ir/serialization/serializer.cpp +++ b/src/ir/serialization/serializer.cpp @@ -112,6 +112,7 @@ class FieldSerializerVisitor { result_type VisitLeafField(const 
Span& field); result_type VisitLeafField(const std::vector& field); result_type VisitLeafField(const std::vector>& field); + result_type VisitLeafField(const std::vector& field); // Field kind hooks template @@ -234,6 +235,8 @@ class IRSerializer::Impl { SERIALIZE_FIELDS(InlineStmt); SERIALIZE_FIELDS(Function); SERIALIZE_FIELDS(Program); + SERIALIZE_FIELDS(WindowBuffer); + SERIALIZE_FIELDS(CommGroup); #undef SERIALIZE_FIELDS #undef SERIALIZE_FIELDS_BASE @@ -388,7 +391,12 @@ class IRSerializer::Impl { if (auto scalar_type = As(type)) { type_map["dtype"] = msgpack::object(scalar_type->dtype_.Code(), zone); - } else if (auto tensor_type = As(type)) { + } else if (type->GetKind() == ObjectKind::TensorType || + type->GetKind() == ObjectKind::DistributedTensorType) { + // DistributedTensorType has identical fields to TensorType — share the + // serialization code via static_cast. The "type_kind" key already + // distinguishes the two for the deserializer. + auto tensor_type = std::static_pointer_cast(type); type_map["dtype"] = msgpack::object(tensor_type->dtype_.Code(), zone); std::vector shape_vec; @@ -622,6 +630,15 @@ msgpack::object FieldSerializerVisitor::VisitLeafField(const SplitMode& field) { return msgpack::object(static_cast(field), zone_); } +msgpack::object FieldSerializerVisitor::VisitLeafField(const std::vector& field) { + std::vector vec; + vec.reserve(field.size()); + for (auto v : field) { + vec.emplace_back(v, zone_); + } + return msgpack::object(vec, zone_); +} + msgpack::object FieldSerializerVisitor::VisitLeafField(const std::optional& field) { if (field.has_value()) { return msgpack::object(static_cast(*field), zone_); diff --git a/src/ir/serialization/type_deserializers.cpp b/src/ir/serialization/type_deserializers.cpp index 51f8fe9c4..a9d393770 100644 --- a/src/ir/serialization/type_deserializers.cpp +++ b/src/ir/serialization/type_deserializers.cpp @@ -896,7 +896,62 @@ static IRNodePtr DeserializeProgram(const msgpack::object& fields_obj, 
msgpack:: } } - return std::make_shared(functions, name, span); + std::vector comm_groups; + // ``comm_groups`` is optional — older serialized programs do not contain it. + if (ctx.HasField(fields_obj, "comm_groups")) { + auto comm_groups_obj = GET_FIELD_OBJ("comm_groups"); + if (comm_groups_obj.type == msgpack::type::ARRAY) { + for (uint32_t i = 0; i < comm_groups_obj.via.array.size; ++i) { + comm_groups.push_back(std::static_pointer_cast( + ctx.DeserializeNode(comm_groups_obj.via.array.ptr[i], zone))); + } + } + } + + if (comm_groups.empty()) { + return std::make_shared(std::move(functions), name, span); + } + return std::make_shared(std::move(functions), std::move(comm_groups), name, span); +} + +// Deserialize WindowBuffer +static IRNodePtr DeserializeWindowBuffer(const msgpack::object& fields_obj, msgpack::zone& zone, + DeserializerContext& ctx) { + auto span = ctx.DeserializeSpan(GET_FIELD_OBJ("span")); + std::string name = GET_FIELD(std::string, "name"); + auto size = std::static_pointer_cast(ctx.DeserializeNode(GET_FIELD_OBJ("size"), zone)); + uint8_t dtype_code = GET_FIELD(uint8_t, "dtype"); + bool load_from_host = GET_FIELD(bool, "load_from_host"); + bool store_to_host = GET_FIELD(bool, "store_to_host"); + return std::make_shared(std::move(name), size, DataType(dtype_code), load_from_host, + store_to_host, span); +} + +// Deserialize CommGroup +static IRNodePtr DeserializeCommGroup(const msgpack::object& fields_obj, msgpack::zone& zone, + DeserializerContext& ctx) { + auto span = ctx.DeserializeSpan(GET_FIELD_OBJ("span")); + + std::vector devices; + auto devices_obj = GET_FIELD_OBJ("devices"); + if (devices_obj.type == msgpack::type::ARRAY) { + devices.reserve(devices_obj.via.array.size); + for (uint32_t i = 0; i < devices_obj.via.array.size; ++i) { + int64_t v = 0; + devices_obj.via.array.ptr[i].convert(v); + devices.push_back(v); + } + } + + std::vector slots; + auto slots_obj = GET_FIELD_OBJ("slots"); + if (slots_obj.type == msgpack::type::ARRAY) { + 
for (uint32_t i = 0; i < slots_obj.via.array.size; ++i) { + slots.push_back(std::static_pointer_cast( + ctx.DeserializeNode(slots_obj.via.array.ptr[i], zone))); + } + } + return std::make_shared(std::move(devices), std::move(slots), span); } // Deserialize MakeTuple @@ -981,6 +1036,8 @@ static TypeRegistrar _inline_stmt_registrar("InlineStmt", DeserializeInlineStmt) static TypeRegistrar _function_registrar("Function", DeserializeFunction); static TypeRegistrar _program_registrar("Program", DeserializeProgram); +static TypeRegistrar _window_buffer_registrar("WindowBuffer", DeserializeWindowBuffer); +static TypeRegistrar _comm_group_registrar("CommGroup", DeserializeCommGroup); static TypeRegistrar _make_tuple_registrar("MakeTuple", DeserializeMakeTuple); static TypeRegistrar _tuple_get_item_expr_registrar("TupleGetItemExpr", DeserializeTupleGetItemExpr); diff --git a/src/ir/transforms/python_printer.cpp b/src/ir/transforms/python_printer.cpp index 0177ee492..d84306a3f 100644 --- a/src/ir/transforms/python_printer.cpp +++ b/src/ir/transforms/python_printer.cpp @@ -375,10 +375,23 @@ std::string IRPythonPrinter::Print(const TypePtr& type) { return prefix_ + ".Scalar[" + prefix_ + "." + DataTypeToString(scalar_type->dtype_) + "]"; } - if (auto tensor_type = As(type)) { + // Tensor / DistributedTensor share the same rendering surface — only the + // subscript head differs (``pl.Tensor`` vs ``pld.DistributedTensor``). Note + // ``As`` is precise-match and would not fire for the subclass, + // so dispatch on DistributedTensorType first and pass it through the + // TensorType base for shared field access. 
+ TensorTypePtr tensor_type; + std::string tensor_head; + if (auto dt_tensor = As(type)) { + tensor_type = dt_tensor; + tensor_head = "pld.DistributedTensor"; + } else if (auto plain_tensor = As(type)) { + tensor_type = plain_tensor; + tensor_head = prefix_ + ".Tensor"; + } + if (tensor_type) { std::ostringstream oss; - // Subscript-style: pl.Tensor[[shape], dtype] - oss << prefix_ << ".Tensor[["; + oss << tensor_head << "[["; PrintShapeDims(oss, tensor_type->shape_); oss << "], " << prefix_ << "." << DataTypeToString(tensor_type->dtype_); diff --git a/src/ir/transforms/structural_equal.cpp b/src/ir/transforms/structural_equal.cpp index 3a2708d9a..291579330 100644 --- a/src/ir/transforms/structural_equal.cpp +++ b/src/ir/transforms/structural_equal.cpp @@ -216,6 +216,18 @@ class StructuralEqualImpl { return true; } + result_type VisitLeafField(const std::vector& lhs, const std::vector& rhs) { + if (lhs != rhs) { + if constexpr (AssertMode) { + std::ostringstream msg; + msg << "vector mismatch (size " << lhs.size() << " vs " << rhs.size() << ")"; + ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", ""); + } + return false; + } + return true; + } + // Leaf field comparisons (dual-node version) result_type VisitLeafField(const int& lhs, const int& rhs) { if (lhs != rhs) { @@ -229,6 +241,19 @@ class StructuralEqualImpl { return true; } + result_type VisitLeafField(const bool& lhs, const bool& rhs) { + if (lhs != rhs) { + if constexpr (AssertMode) { + std::ostringstream msg; + msg << "Bool value mismatch (" << (lhs ? "true" : "false") << " != " << (rhs ? 
"true" : "false") + << ")"; + ThrowMismatch(msg.str(), IRNodePtr(), IRNodePtr(), "", ""); + } + return false; + } + return true; + } + result_type VisitLeafField(const int64_t& lhs, const int64_t& rhs) { if (lhs != rhs) { if constexpr (AssertMode) { @@ -930,6 +955,8 @@ bool StructuralEqualImpl::Equal(const IRNodePtr& lhs, const IRNodePt EQUAL_DISPATCH(InlineStmt) EQUAL_DISPATCH(Function) EQUAL_DISPATCH_TRANSPARENT(Program) + EQUAL_DISPATCH(WindowBuffer) + EQUAL_DISPATCH(CommGroup) throw pypto::TypeError("Unknown IR node type in StructuralEqualImpl::Equal: " + lhs->TypeName()); } @@ -970,8 +997,13 @@ bool StructuralEqualImpl::EqualType(const TypePtr& lhs, const TypePt return false; } return true; - } else if (auto lhs_tensor = As(lhs)) { - auto rhs_tensor = As(rhs); + } else if (lhs->GetKind() == ObjectKind::TensorType || + lhs->GetKind() == ObjectKind::DistributedTensorType) { + // DistributedTensorType has identical fields to TensorType — share the + // comparison code via static_cast. The TypeName check above already + // guarantees both sides have matching kinds. 
+ auto lhs_tensor = std::static_pointer_cast(lhs); + auto rhs_tensor = std::static_pointer_cast(rhs); if (!rhs_tensor) { if constexpr (AssertMode) { ThrowMismatch("Type cast failed for TensorType", IRNodePtr(), IRNodePtr(), "", ""); diff --git a/src/ir/transforms/structural_hash.cpp b/src/ir/transforms/structural_hash.cpp index 6e2f5f286..0f742592c 100644 --- a/src/ir/transforms/structural_hash.cpp +++ b/src/ir/transforms/structural_hash.cpp @@ -140,6 +140,14 @@ class StructuralHasher { } } + result_type VisitLeafField(const std::vector& field) { + result_type h = 0; + for (auto v : field) { + h = hash_combine(h, static_cast(std::hash{}(v))); + } + return h; + } + result_type VisitLeafField(const int& field) { return static_cast(std::hash{}(field)); } result_type VisitLeafField(const int64_t& field) { @@ -232,6 +240,8 @@ class StructuralHasher { return static_cast(0); } + result_type VisitLeafField(const bool& field) { return static_cast(std::hash{}(field)); } + result_type VisitLeafField(const std::optional& field) { if (field.has_value()) { return hash_combine(1, static_cast(std::hash{}(*field))); @@ -567,6 +577,8 @@ StructuralHasher::result_type StructuralHasher::HashNode(const IRNodePtr& node) HASH_DISPATCH(InlineStmt) HASH_DISPATCH(Function) HASH_DISPATCH(Program) + HASH_DISPATCH(WindowBuffer) + HASH_DISPATCH(CommGroup) // Free Var types (including MemRef and IterArg) that may be mapped to other free vars. // These have already been dispatched above for field hashing; diff --git a/tests/ut/ir/core/test_comm_group_schema.py b/tests/ut/ir/core/test_comm_group_schema.py new file mode 100644 index 000000000..8f503a4be --- /dev/null +++ b/tests/ut/ir/core/test_comm_group_schema.py @@ -0,0 +1,95 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""IR-level tests for the v2 ``WindowBuffer`` / ``CommGroup`` schema (N1.3).""" + +import pytest +from pypto.pypto_core import DataType +from pypto.pypto_core.ir import ( + CommGroup, + ConstInt, + Span, + WindowBuffer, + structural_equal, +) + + +def _const(value: int) -> ConstInt: + return ConstInt(value, DataType.INT64, Span.unknown()) + + +# --------------------------------------------------------------------------- +# WindowBuffer +# --------------------------------------------------------------------------- + + +def test_window_buffer_minimal_construction(): + wb = WindowBuffer("data", _const(256), DataType.FP32) + assert wb.name == "data" + assert wb.dtype == DataType.FP32 + assert wb.load_from_host is False + assert wb.store_to_host is False + + +def test_window_buffer_load_store_to_host_flags(): + """v2 schema: load/store_to_host are bool flags; the actual host tensor + binding (if any) is recorded on the alloc op, not here.""" + wb = WindowBuffer("lut", _const(64), DataType.FP32, load_from_host=True, store_to_host=True) + assert wb.load_from_host is True + assert wb.store_to_host is True + + +# --------------------------------------------------------------------------- +# CommGroup structural equality +# --------------------------------------------------------------------------- + + +def test_comm_group_empty_devices_means_all(): + """``devices == []`` is the convention for "covers all DistributedConfig.device_ids".""" + g = CommGroup([], [WindowBuffer("data", _const(64), 
DataType.FP32)]) + assert list(g.devices) == [] + + +def test_comm_group_explicit_device_subset(): + g = CommGroup([0, 1, 2], [WindowBuffer("data", _const(64), DataType.FP32)]) + assert list(g.devices) == [0, 1, 2] + + +def test_comm_group_structural_equal(): + g1 = CommGroup([], [WindowBuffer("data", _const(256), DataType.FP32)]) + g2 = CommGroup([], [WindowBuffer("data", _const(256), DataType.FP32)]) + assert structural_equal(g1, g2) + + +def test_comm_group_structural_not_equal_when_devices_differ(): + g_all = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)]) + g_subset = CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)]) + assert not structural_equal(g_all, g_subset) + + +def test_comm_group_structural_not_equal_when_subsets_differ(): + a = CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)]) + b = CommGroup([2, 3], [WindowBuffer("data", _const(64), DataType.FP32)]) + assert not structural_equal(a, b) + + +def test_comm_group_structural_not_equal_when_slots_differ(): + a = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)]) + b = CommGroup([], [WindowBuffer("data", _const(128), DataType.FP32)]) + assert not structural_equal(a, b) + + +def test_comm_group_structural_not_equal_when_load_flag_differs(): + a = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32, load_from_host=False)]) + b = CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32, load_from_host=True)]) + assert not structural_equal(a, b) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/core/test_distributed_tensor_type.py b/tests/ut/ir/core/test_distributed_tensor_type.py new file mode 100644 index 000000000..06b3dcccc --- /dev/null +++ b/tests/ut/ir/core/test_distributed_tensor_type.py @@ -0,0 +1,114 @@ +# Copyright (c) PyPTO Contributors. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""Tests for ``ir.DistributedTensorType`` (N1.2). + +The distributed tensor subclass is distinguished from plain :class:`TensorType` +*only* by ``ObjectKind``; structurally identical fields. It exists so cross-rank +op verifiers (added in N6) can reject plain tensors via +``As``. +""" + +import pytest +from pypto.pypto_core import DataType +from pypto.pypto_core.ir import ( + ConstInt, + DistributedTensorType, + Span, + TensorType, + assert_structural_equal, + deserialize, + serialize, + structural_equal, +) + + +def _shape(*dims: int) -> list[ConstInt]: + return [ConstInt(d, DataType.INT64, Span.unknown()) for d in dims] + + +def test_construct_with_constant_shape(): + """``DistributedTensorType([64], FP32)`` constructs and exposes shape/dtype.""" + dt = DistributedTensorType([64], DataType.FP32) + assert dt.dtype == DataType.FP32 + assert len(dt.shape) == 1 + + +def test_construct_with_expr_shape(): + """Expr-shape ctor mirrors TensorType's signature.""" + dt = DistributedTensorType(_shape(64, 128), DataType.FP32) + assert len(dt.shape) == 2 + + +def test_type_name_distinct_from_tensor_type(): + """Precise ObjectKind keeps DistributedTensorType separate from TensorType.""" + dt = DistributedTensorType([64], DataType.FP32) + plain = TensorType([64], DataType.FP32) + assert 
type(dt).__name__ == "DistributedTensorType" + assert type(plain).__name__ == "TensorType" + + +def test_inherits_from_tensor_type(): + """C++ inheritance is preserved across the binding so DSL helpers that + accept ``TensorType`` still work on the distributed subclass.""" + dt = DistributedTensorType([64], DataType.FP32) + assert isinstance(dt, TensorType) + + +def test_structural_equal_same_shape_dtype(): + a = DistributedTensorType([64], DataType.FP32) + b = DistributedTensorType([64], DataType.FP32) + assert structural_equal(a, b) + + +def test_structural_not_equal_to_plain_tensor_type(): + """Plain TensorType and DistributedTensorType are distinct types — no + cross-class structural equality, even with identical shape/dtype.""" + dt = DistributedTensorType([64], DataType.FP32) + plain = TensorType([64], DataType.FP32) + assert not structural_equal(dt, plain) + + +def test_structural_not_equal_different_dtype(): + a = DistributedTensorType([64], DataType.FP32) + b = DistributedTensorType([64], DataType.FP16) + assert not structural_equal(a, b) + + +def test_assert_structural_equal_passes(): + """``assert_structural_equal`` accepts equivalent distributed types.""" + a = DistributedTensorType([64], DataType.FP32) + b = DistributedTensorType([64], DataType.FP32) + assert_structural_equal(a, b) + + +def test_assert_structural_equal_diagnoses_class_mismatch(): + dt = DistributedTensorType([64], DataType.FP32) + plain = TensorType([64], DataType.FP32) + with pytest.raises(Exception, match="Type name mismatch"): + assert_structural_equal(dt, plain) + + +def test_serialization_roundtrip(): + """The deserializer reconstructs the precise subclass.""" + from pypto.pypto_core.ir import Var # noqa: PLC0415 + + dt = DistributedTensorType([64, 128], DataType.FP32) + # Wrap the type in a Var (the only way to feed a Type through the IRNode + # serializer); the deserializer must restore the precise subclass. 
+ var = Var("v", dt, Span.unknown()) + blob = serialize(var) + restored = deserialize(blob) + assert isinstance(restored, Var) + assert isinstance(restored.type, DistributedTensorType) + assert structural_equal(dt, restored.type) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/parser/test_distributed_tensor_annotation.py b/tests/ut/ir/parser/test_distributed_tensor_annotation.py new file mode 100644 index 000000000..2b72bd797 --- /dev/null +++ b/tests/ut/ir/parser/test_distributed_tensor_annotation.py @@ -0,0 +1,72 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +# ruff: noqa: F722, F821 + +"""Parser tests for ``pld.DistributedTensor[[shape], dtype]`` annotations (N1.2).""" + +import pypto.language as pl +import pypto.language.distributed as pld +import pytest +from pypto.pypto_core.ir import DistributedTensorType, TensorType + + +def test_distributed_tensor_param_resolves_to_distributed_type(): + @pl.program + class P: + @pl.function + def f(self, x: pld.DistributedTensor[[256], pl.FP32]) -> pld.DistributedTensor[[256], pl.FP32]: + return x + + gvar = P.get_global_var("f") + assert gvar is not None + func = P.functions[gvar] + param_type = func.params[0].type + assert isinstance(param_type, DistributedTensorType) + assert isinstance(param_type, TensorType) # subclass relationship preserved + assert len(func.return_types) == 1 + assert isinstance(func.return_types[0], DistributedTensorType) + + +def test_plain_tensor_param_is_not_distributed_type(): + """``pl.Tensor[...]`` must not promote to ``DistributedTensorType``.""" + + @pl.program + class P: + @pl.function + def f(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + return x + + gvar = P.get_global_var("f") + assert gvar is not None + func = P.functions[gvar] + param_type = func.params[0].type + assert isinstance(param_type, TensorType) + assert not isinstance(param_type, DistributedTensorType) + + +def test_distributed_tensor_with_layout(): + """DistributedTensor mirrors Tensor's third-slot layout dispatch.""" + + @pl.program + class P: + @pl.function + def f(self, x: pld.DistributedTensor[[64], pl.FP32, pl.NZ]) -> pl.Tensor[[64], pl.FP32]: + return x + + gvar = P.get_global_var("f") + assert gvar is not None + func = P.functions[gvar] + param_type = func.params[0].type + assert isinstance(param_type, DistributedTensorType) + # Layout flows through TensorView, mirroring pl.Tensor[[shape], dtype, layout]. 
+ assert param_type.tensor_view is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/runtime/test_chip_bootstrap_configs.py b/tests/ut/runtime/test_chip_bootstrap_configs.py new file mode 100644 index 000000000..0aba6bfb6 --- /dev/null +++ b/tests/ut/runtime/test_chip_bootstrap_configs.py @@ -0,0 +1,338 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +# ruff: noqa: F722, F821 +"""Tests for the AOT comm-group manifest pipeline (v2 schema, N1.4). + +Compile-time: Program → ``lift_comm_manifest`` → JSON-safe dict +On disk: ``output_dir/orchestration/comm_manifest.json`` +Runtime: dict → ``_build_chip_bootstrap_configs_from_manifest`` → list[ChipBootstrapConfig] + +The CollectCommGroups pass that *infers* CommGroups from +``pld.alloc_window_buffer`` ops is added in N4. These tests pre-stage the +CommGroups directly on a hand-built ``ir.Program`` so the manifest pipeline +can be exercised independently of the inference logic. 
+""" + +import json + +import pypto.language as pl +import pypto.language.distributed as pld +import pytest +from pypto.pypto_core import DataType +from pypto.pypto_core.ir import ( + CommGroup, + ConstInt, + Program, + Span, + Var, + WindowBuffer, +) + + +def _make_dc(device_ids): + from pypto.ir.distributed_compiled_program import DistributedConfig # noqa: PLC0415 + + return DistributedConfig(device_ids=list(device_ids)) + + +def _lift(program): + from pypto.ir.comm_manifest import lift_comm_manifest # noqa: PLC0415 + + return lift_comm_manifest(program) + + +def _build(manifest, device_ids, rootinfo_path="/tmp/_test_rootinfo.bin"): + # Skip in environments without simpler installed (e.g. unit-tests CI). + pytest.importorskip("simpler.task_interface") + from pypto.runtime.distributed_runner import ( # noqa: PLC0415 + _build_chip_bootstrap_configs_from_manifest, + ) + + return _build_chip_bootstrap_configs_from_manifest(manifest, _make_dc(device_ids), rootinfo_path) + + +def _const(value: int) -> ConstInt: + return ConstInt(value, DataType.INT64, Span.unknown()) + + +def _trivial_program(groups: list[CommGroup] | None = None) -> Program: + """Build a minimal Program (optionally with CommGroups) via @pl.program + + immediate Program(...) reconstruction. Until the CollectCommGroups pass + lands (N4), tests pre-stage the groups directly through the constructor. 
+ """ + + @pl.program + class P: + @pl.function + def f(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + return x + + if not groups: + return P + funcs = list(P.functions.values()) + return Program(funcs, list(groups), P.name, P.span) + + +# --------------------------------------------------------------------------- +# Compile-time: lift_comm_manifest +# --------------------------------------------------------------------------- + + +def test_lift_no_comm_group_returns_none(): + """A program without any CommGroup must skip the manifest entirely.""" + p = _trivial_program() + assert _lift(p) is None + + +def test_lift_const_size_emits_json_safe_manifest(): + """All-devices group with literal sizes lifts to a JSON-safe v2 manifest.""" + slots = [ + WindowBuffer("data", _const(256), DataType.FP32), + WindowBuffer("signal", _const(2), DataType.INT32), + ] + p = _trivial_program([CommGroup([], slots)]) # empty devices = all + + manifest = _lift(p) + assert manifest is not None + # Manifest must round-trip through JSON without losing fidelity. 
+ assert json.loads(json.dumps(manifest)) == manifest + + assert manifest["version"] == 2 + assert len(manifest["comm_groups"]) == 1 + g = manifest["comm_groups"][0] + assert g["devices"] == [] # empty = all devices + assert g["slots"] == [ + { + "name": "data", + "dtype": "float32", + "size": 256, + "bits_per_element": 32, + "load_from_host": False, + "store_to_host": False, + }, + { + "name": "signal", + "dtype": "int32", + "size": 2, + "bits_per_element": 32, + "load_from_host": False, + "store_to_host": False, + }, + ] + + +def test_lift_explicit_device_list(): + """A group with explicit devices serializes the literal list.""" + p = _trivial_program([CommGroup([0, 1], [WindowBuffer("data", _const(64), DataType.FP32)])]) + + manifest = _lift(p) + assert manifest is not None + assert manifest["comm_groups"][0]["devices"] == [0, 1] + + +def test_lift_dynamic_size_unsupported_raises(): + """Symbolic ``size`` is rejected at compile time so authors get a clear error.""" + sym = Var("N", _const(0).type, Span.unknown()) + p = _trivial_program([CommGroup([], [WindowBuffer("signal", sym, DataType.INT32)])]) + with pytest.raises(RuntimeError, match="dynamic WindowBuffer size is not supported"): + _lift(p) + + +def test_lift_two_comm_groups_raises(): + """Multi-group programs are not yet supported by the runner.""" + p = _trivial_program( + [ + CommGroup([], [WindowBuffer("a", _const(64), DataType.FP32)]), + CommGroup([], [WindowBuffer("b", _const(64), DataType.FP32)]), + ] + ) + with pytest.raises(RuntimeError, match="at most one CommGroup"): + _lift(p) + + +def test_lift_load_store_to_host_flags_propagate(): + """``load_from_host`` / ``store_to_host`` bool flags pass through 1:1.""" + slots = [ + WindowBuffer("lut", _const(64), DataType.FP32, load_from_host=True), + WindowBuffer("met", _const(64), DataType.INT32, store_to_host=True), + ] + p = _trivial_program([CommGroup([], slots)]) + + manifest = _lift(p) + assert manifest is not None + slot_data = 
manifest["comm_groups"][0]["slots"] + assert slot_data[0]["load_from_host"] is True + assert slot_data[0]["store_to_host"] is False + assert slot_data[1]["load_from_host"] is False + assert slot_data[1]["store_to_host"] is True + + +# --------------------------------------------------------------------------- +# Runtime: _build_chip_bootstrap_configs_from_manifest +# --------------------------------------------------------------------------- + + +def test_build_none_manifest_returns_none(): + pytest.importorskip("simpler.task_interface") + assert _build(None, [0, 1]) is None + + +def _make_manifest(devices: list[int], slots: list[dict]) -> dict: + return {"version": 2, "comm_groups": [{"devices": devices, "slots": slots}]} + + +def _slot( + name: str, + size: int, + dtype: str = "float32", + bits: int = 32, + *, + load_from_host: bool = False, + store_to_host: bool = False, +) -> dict: + return { + "name": name, + "dtype": dtype, + "size": size, + "bits_per_element": bits, + "load_from_host": load_from_host, + "store_to_host": store_to_host, + } + + +def test_build_full_coverage_via_empty_devices(): + """Empty ``devices`` list ⇒ every device in the dc gets a comm config.""" + pytest.importorskip("simpler.task_interface") + manifest = _make_manifest([], [_slot("data", 256), _slot("signal", 2, "int32")]) + cfgs = _build(manifest, [0, 1]) + assert cfgs is not None + assert len(cfgs) == 2 + assert all(c.comm is not None for c in cfgs) + assert cfgs[0].comm.rank == 0 and cfgs[1].comm.rank == 1 + assert all(c.comm.nranks == 2 for c in cfgs) + assert all(c.comm.window_size == 256 * 4 + 2 * 4 for c in cfgs) + assert [b.name for b in cfgs[0].buffers] == ["data", "signal"] + + +def test_build_explicit_subset_keeps_extras_commless(): + """Devices not in the list get bare ``ChipBootstrapConfig()`` (comm=None).""" + pytest.importorskip("simpler.task_interface") + manifest = _make_manifest([0, 1], [_slot("data", 64)]) + cfgs = _build(manifest, [0, 1, 2, 3]) + assert cfgs is not 
None + assert len(cfgs) == 4 + assert cfgs[0].comm is not None and cfgs[0].comm.rank == 0 + assert cfgs[1].comm is not None and cfgs[1].comm.rank == 1 + assert cfgs[2].comm is None and cfgs[3].comm is None + + +def test_build_subset_out_of_range_raises(): + pytest.importorskip("simpler.task_interface") + manifest = _make_manifest([3, 4], [_slot("data", 1)]) + with pytest.raises(RuntimeError, match="outside DistributedConfig.device_ids range"): + _build(manifest, [0, 1]) + + +def test_build_unknown_version_raises(): + pytest.importorskip("simpler.task_interface") + manifest = {"version": 999, "comm_groups": [{"devices": [], "slots": []}]} + with pytest.raises(RuntimeError, match="version mismatch"): + _build(manifest, [0, 1]) + + +def test_build_rejects_two_groups(): + pytest.importorskip("simpler.task_interface") + manifest = { + "version": 2, + "comm_groups": [ + {"devices": [], "slots": []}, + {"devices": [], "slots": []}, + ], + } + with pytest.raises(RuntimeError, match="exactly one CommGroup"): + _build(manifest, [0, 1]) + + +def test_build_subbyte_dtype_byte_calculation(): + """nbytes rounds up for sub-byte dtypes (e.g. INT4).""" + pytest.importorskip("simpler.task_interface") + # INT4: 4 bits per element, 7 elements → ceil(7*4/8) = 4 bytes. 
+ manifest = _make_manifest([], [_slot("x", 7, "int4", bits=4)]) + cfgs = _build(manifest, [0, 1]) + assert cfgs is not None + assert cfgs[0].buffers[0].nbytes == 4 + assert cfgs[0].comm.window_size == 4 + + +def test_build_load_store_host_flags(): + """Slot bool flags propagate to ChipBufferSpec without modification.""" + pytest.importorskip("simpler.task_interface") + manifest = _make_manifest( + [], + [ + _slot("lut", 4, load_from_host=True), + _slot("out", 4, store_to_host=True), + ], + ) + cfgs = _build(manifest, [0, 1]) + assert cfgs is not None + assert cfgs[0].buffers[0].load_from_host is True + assert cfgs[0].buffers[0].store_to_host is False + assert cfgs[0].buffers[1].load_from_host is False + assert cfgs[0].buffers[1].store_to_host is True + + +# --------------------------------------------------------------------------- +# AOT roundtrip: writing manifest to disk and re-reading it +# --------------------------------------------------------------------------- + + +def test_aot_roundtrip_writes_and_loads_manifest(tmp_path): + """``emit_comm_manifest`` writes the file and ``json.load`` recovers an + identical dict to the in-memory ``lift_comm_manifest`` output. 
+ """ + from pypto.ir.comm_manifest import ( # noqa: PLC0415 + COMM_MANIFEST_FILENAME, + emit_comm_manifest, + ) + + p = _trivial_program([CommGroup([], [WindowBuffer("data", _const(64), DataType.FP32)])]) + + expected = _lift(p) + out_path = emit_comm_manifest(p, tmp_path) + assert out_path is not None + assert out_path == tmp_path / "orchestration" / COMM_MANIFEST_FILENAME + + with out_path.open("r", encoding="utf-8") as fh: + actual = json.load(fh) + assert actual == expected + + +def test_aot_roundtrip_no_group_writes_no_file(tmp_path): + from pypto.ir.comm_manifest import emit_comm_manifest # noqa: PLC0415 + + p = _trivial_program() + assert emit_comm_manifest(p, tmp_path) is None + assert not (tmp_path / "orchestration").exists() + + +# --------------------------------------------------------------------------- +# Sanity: ensure the new pld.* surface area exposes only DistributedTensor +# --------------------------------------------------------------------------- + + +def test_pld_does_not_export_legacy_dataclasses(): + """N1 removes the user-declared ``pld.CommGroup`` / ``pld.WindowBuffer``.""" + assert not hasattr(pld, "CommGroup") + assert not hasattr(pld, "WindowBuffer") + assert hasattr(pld, "DistributedTensor") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])