1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -109,6 +109,7 @@ set(PYPTO_SOURCES
src/ir/op/sync_ops/sync.cpp
src/ir/op/sync_ops/cross_core.cpp
src/ir/op/sync_ops/task.cpp
src/ir/op/distributed/memory.cpp
src/ir/op/tensor_ops/broadcast.cpp
src/ir/op/tensor_ops/elementwise.cpp
src/ir/op/tensor_ops/matmul.cpp
10 changes: 7 additions & 3 deletions docs/en/dev/ir/02-types.md
@@ -69,9 +69,13 @@ assert isinstance(t, ir.TensorType) # C++ inheritance preserved
# As<TensorType>(t) → null; As<DistributedTensorType>(t) → cast succeeds.
```

Allocation-side metadata (the buffer name, host-staging flags) lives on the
allocation op (`pld.alloc_window_buffer`, added in a later milestone) and on
the `ir.WindowBuffer` slot in `program.comm_groups`, not on the type itself.
Allocation-side metadata (per-rank size, host-staging flags) lives on the
`ir.WindowBuffer` `Var` subclass that the `pld.alloc_window_buffer` op binds.
Slices materialised through `pld.window(buf, [shape], dtype=...)` carry an
optional back-reference (`DistributedTensorType.window_buffer`) to the source
`WindowBuffer`, so two same-shape / same-dtype slices of different
allocations stay structurally distinct. User-declared parameter annotations
like `pld.DistributedTensor[[shape], dtype]` leave this field as `None`.
Tile types do not have a distributed variant; cross-rank ops always operate
on `DistributedTensor`.
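
To make the back-reference concrete, here is a minimal sketch at the `pld`
level. `pld.window` and `pld.alloc_window_buffer` are taken from the text
above; the alloc op's exact argument list, the `fp16` dtype spelling, and the
`ir.structural_equal` entry point are illustrative assumptions:

```python
# Sketch only: the alloc signature, dtype spelling, and structural_equal
# entry point are assumptions, not the confirmed API.
buf_a = pld.alloc_window_buffer(1 << 20)  # binds one ir.WindowBuffer
buf_b = pld.alloc_window_buffer(1 << 20)  # a second, distinct allocation

x = pld.window(buf_a, [128, 256], dtype=pld.fp16)  # window_buffer -> buf_a
y = pld.window(buf_b, [128, 256], dtype=pld.fp16)  # window_buffer -> buf_b

# Same shape and dtype, but different window_buffer back-references, so the
# two DistributedTensorTypes stay structurally distinct.
assert not ir.structural_equal(x.type, y.type)
```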

12 changes: 8 additions & 4 deletions docs/zh-cn/dev/ir/02-types.md
@@ -68,10 +68,14 @@ assert isinstance(t, ir.TensorType) # C++ inheritance preserved
# As<TensorType>(t) → null; As<DistributedTensorType>(t) → cast succeeds
```

Allocation-side metadata (the buffer name, host-staging flags) lives on the
alloc op (`pld.alloc_window_buffer`, introduced in a later milestone) and on
the `ir.WindowBuffer` slot in `program.comm_groups`, **not** on the type
itself. Tile types have no distributed variant; cross-rank ops always operate
on `DistributedTensor`.
Allocation-side metadata (per-rank size, host-staging flags) lives on the
`ir.WindowBuffer` (a `Var` subclass) that the `pld.alloc_window_buffer` op
binds. Slices materialised through `pld.window(buf, [shape], dtype=...)` keep
an optional back-reference to the source `WindowBuffer` on
`DistributedTensorType.window_buffer`, so two slices with identical shape and
dtype but different allocation origins stay structurally distinct.
Annotations written in the function signature, like
`pld.DistributedTensor[[shape], dtype]`, leave this field unset (`None`).
Tile types have no distributed variant; cross-rank ops always operate on
`DistributedTensor`.

### TensorType with TensorView

1 change: 1 addition & 0 deletions include/pypto/ir/core.h
@@ -113,6 +113,7 @@ enum class ObjectKind {
TileType,
ArrayType,
TupleType,
WindowBufferType,

// Other IR node kinds
Function,
9 changes: 7 additions & 2 deletions include/pypto/ir/kind_traits.h
@@ -114,6 +114,7 @@ DEFINE_KIND_TRAIT(ArrayType, ObjectKind::ArrayType)
DEFINE_KIND_TRAIT(TupleType, ObjectKind::TupleType)
DEFINE_KIND_TRAIT(MemRefType, ObjectKind::MemRefType)
DEFINE_KIND_TRAIT(PtrType, ObjectKind::PtrType)
DEFINE_KIND_TRAIT(WindowBufferType, ObjectKind::WindowBufferType)

// Other IR node types
DEFINE_KIND_TRAIT(Function, ObjectKind::Function)
@@ -203,7 +204,8 @@ struct KindTrait<Type> {
ObjectKind::DistributedTensorType,
ObjectKind::TileType,
ObjectKind::ArrayType,
ObjectKind::TupleType};
ObjectKind::TupleType,
ObjectKind::WindowBufferType};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};

@@ -278,7 +280,10 @@ std::shared_ptr<const T> As(const std::shared_ptr<const Base>& base) {
*
* As<Var>() uses exact ObjectKind matching and won't match IterArg.
* This utility matches both Var and IterArg (which inherits from Var).
* MemRef is intentionally excluded — use As<MemRef>() for that.
* MemRef and WindowBuffer are intentionally excluded — they are Var
* subclasses that carry allocation-source / window-slot semantics rather
* than the plain bound-name model AsVarLike's callers assume. Use
* As<MemRef>() / As<WindowBuffer>() when you specifically want them.
*/
inline VarPtr AsVarLike(const ExprPtr& expr) {
if (!expr) return nullptr;
60 changes: 28 additions & 32 deletions include/pypto/ir/program.h
@@ -19,7 +19,6 @@
#include <utility>
#include <vector>

#include "pypto/core/dtype.h"
#include "pypto/ir/core.h"
#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
@@ -30,39 +29,34 @@ namespace pypto {
namespace ir {

/**
* @brief Per-rank allocation spec for one named CommGroup HCCL window buffer.
* @brief Per-rank CommGroup HCCL window-buffer allocation, modelled as a Var.
*
* Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. Pure
* allocation metadata: does NOT describe how the buffer is used in code.
* Code-level use is expressed in the function signature via
* ``pld.DistributedTensor[[shape], dtype]``; the alloc op
* (``pld.alloc_window_buffer``, added in N2) materialises one of these slots
* into the program's :class:`CommGroup`.
* A specialised :class:`Var` subclass whose SSA-edge type is the singleton
* :class:`WindowBufferType`. The buffer's runtime-unique identifier flows
* through the inherited ``Var::name_hint_``; there is no separate ``name_``
* field so structural equality does not depend on the chosen variable name.
*
* ``size_`` is the **element count** of one rank's slice (a single scalar; this
* struct is allocation-only and intentionally does not carry a multi-dim
* shape). ``size_`` may be a ``ConstInt`` (compile-time known) or a symbolic
* expression referring to the world size.
*
* ``load_from_host_`` / ``store_to_host_`` are simple boolean flags marking
* whether the slot participates in pre-fork H2D / post-task D2H staging. The
* specific host tensor that supplies / receives the staged data is recorded
* on the alloc op, not on this allocation spec.
* Fields:
* * ``base_`` — :class:`Var` holding the underlying ``Ptr`` allocation
* identity. Multiple ``WindowBuffer`` instances built from the same alloc
* Var share allocation identity through this field.
* * ``size_`` — per-rank allocation size in **bytes**; ``ConstInt`` or
* symbolic :class:`ExprPtr`.
* * ``load_from_host_`` / ``store_to_host_`` — pre-fork H2D / post-task
* D2H staging flags.
*/
class WindowBuffer : public IRNode {
class WindowBuffer : public Var {
public:
std::string name_; ///< Buffer name (parser-extracted from alloc-op LHS)
ExprPtr size_; ///< Per-rank element count (ConstInt or symbolic Expr)
DataType dtype_; ///< Element data type
bool load_from_host_ = false; ///< Pre-fork H2D copy from a host staging tensor
bool store_to_host_ = false; ///< Post-task D2H copy back into a host staging tensor

WindowBuffer(std::string name, ExprPtr size, DataType dtype, bool load_from_host = false,
bool store_to_host = false, Span span = Span::unknown())
: IRNode(std::move(span)),
name_(std::move(name)),
VarPtr base_; ///< Ptr Var from the alloc op (allocation identity)
ExprPtr size_; ///< Per-rank allocation size in bytes
bool load_from_host_ = false; ///< Pre-fork H2D staging flag
bool store_to_host_ = false; ///< Post-task D2H staging flag

WindowBuffer(VarPtr base, ExprPtr size, bool load_from_host = false, bool store_to_host = false,
Span span = Span::unknown())
: Var(base->name_hint_, GetWindowBufferType(), std::move(span)),
base_(std::move(base)),
size_(std::move(size)),
dtype_(dtype),
load_from_host_(load_from_host),
store_to_host_(store_to_host) {}

@@ -71,15 +65,17 @@ class WindowBuffer : public IRNode {

static constexpr auto GetFieldDescriptors() {
return std::tuple_cat(
IRNode::GetFieldDescriptors(),
std::make_tuple(reflection::UsualField(&WindowBuffer::name_, "name"),
Var::GetFieldDescriptors(),
std::make_tuple(reflection::UsualField(&WindowBuffer::base_, "base"),
reflection::UsualField(&WindowBuffer::size_, "size"),
reflection::UsualField(&WindowBuffer::dtype_, "dtype"),
reflection::UsualField(&WindowBuffer::load_from_host_, "load_from_host"),
reflection::UsualField(&WindowBuffer::store_to_host_, "store_to_host")));
}
};

// WindowBufferPtr is forward-declared in include/pypto/ir/type.h so that
// DistributedTensorType::window_buffer_ can hold it without a circular
// include.
using WindowBufferPtr = std::shared_ptr<const WindowBuffer>;
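
For orientation, a hedged sketch of how the refactored class surfaces through
the Python bindings added later in this PR. The `WindowBuffer(base, size, ...)`
constructor and the `WindowBufferType` singleton are confirmed by the diff;
the `ir.Var` / `ir.ConstInt` constructor spellings and the `name_hint` /
`type` attribute names are assumptions:

```python
from pypto import ir  # assumed import path for the bound module

# The alloc op's LHS at parse time: a plain Ptr-typed Var (spelling assumed).
base = ir.Var("win0", ir.PtrType.get())
size = ir.ConstInt(4 * 1024 * 1024)  # per-rank allocation size in bytes

buf = ir.WindowBuffer(base, size, load_from_host=True)

# The runtime-unique identifier flows from base's name hint, and the SSA-edge
# type is the shared WindowBufferType marker.
assert buf.name_hint == "win0"
assert isinstance(buf.type, ir.WindowBufferType)
```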

/**
Expand Down
66 changes: 58 additions & 8 deletions include/pypto/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ using ExprPtr = std::shared_ptr<const Expr>;
class MemRef;
using MemRefPtr = std::shared_ptr<const MemRef>;

class WindowBuffer;
using WindowBufferPtr = std::shared_ptr<const WindowBuffer>;

/**
* @brief Base class for type representations in the IR
*
@@ -500,32 +503,53 @@
*/
class DistributedTensorType : public TensorType {
public:
DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype) : TensorType(std::move(shape), dtype) {}
/// Optional back-reference to the :class:`WindowBuffer` whose allocation this
/// tensor is a view of. Populated by ``pld.window``'s type deducer;
/// ``std::nullopt`` for user-declared parameter annotations like
/// ``pld.DistributedTensor[[shape], dtype]``. Two DistributedTensorTypes with
/// the same shape / dtype but different ``window_buffer_`` values are
/// structurally distinct, so passes can tell apart slices of different
/// CommGroup window buffers.
std::optional<WindowBufferPtr> window_buffer_;

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype)
: TensorType(std::move(shape), dtype), window_buffer_(std::nullopt) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, MemRefPtr memref)
: TensorType(std::move(shape), dtype, std::move(memref)) {}
: TensorType(std::move(shape), dtype, std::move(memref)), window_buffer_(std::nullopt) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref)
: TensorType(std::move(shape), dtype, std::move(memref)) {}
: TensorType(std::move(shape), dtype, std::move(memref)), window_buffer_(std::nullopt) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref,
std::optional<TensorView> tensor_view)
: TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)) {}
: TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)),
window_buffer_(std::nullopt) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype)
: TensorType(shape, dtype, std::nullopt) {}
: TensorType(shape, dtype, std::nullopt), window_buffer_(std::nullopt) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref)
: TensorType(shape, dtype, std::move(memref)) {}
: TensorType(shape, dtype, std::move(memref)), window_buffer_(std::nullopt) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref,
std::optional<TensorView> tensor_view)
: TensorType(shape, dtype, std::move(memref), std::move(tensor_view)) {}
: TensorType(shape, dtype, std::move(memref), std::move(tensor_view)), window_buffer_(std::nullopt) {}

/// Construct a DistributedTensorType produced by ``pld.window``: the result
/// is paired with the originating :class:`WindowBuffer` so passes can recover
/// the comm-group / slot identity later.
DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, WindowBufferPtr window_buffer)
: TensorType(std::move(shape), dtype), window_buffer_(std::move(window_buffer)) {}

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::DistributedTensorType; }
[[nodiscard]] std::string TypeName() const override { return "DistributedTensorType"; }

static constexpr auto GetFieldDescriptors() { return TensorType::GetFieldDescriptors(); }
static constexpr auto GetFieldDescriptors() {
return std::tuple_cat(
TensorType::GetFieldDescriptors(),
std::make_tuple(reflection::UsualField(&DistributedTensorType::window_buffer_, "window_buffer")));
}
};

using DistributedTensorTypePtr = std::shared_ptr<const DistributedTensorType>;
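
A short IR-level counterpart to the doc example above: because `window_buffer`
now participates in reflection, two types that agree on shape and dtype but
point at different `WindowBuffer`s no longer compare equal. The `make_buf`
helper, the `ir.Var` / `ir.ConstInt` / dtype spellings, and the
`ir.structural_equal` entry point are assumptions for illustration:

```python
from pypto import ir  # assumed import path

def make_buf(name):
    # Hypothetical helper: one WindowBuffer per distinct Ptr allocation.
    base = ir.Var(name, ir.PtrType.get())
    return ir.WindowBuffer(base, ir.ConstInt(1 << 20))

shape = [ir.ConstInt(128), ir.ConstInt(256)]
t_a = ir.DistributedTensorType(shape, ir.DataType.fp16, window_buffer=make_buf("win_a"))
t_b = ir.DistributedTensorType(shape, ir.DataType.fp16, window_buffer=make_buf("win_b"))

# Equal shape and dtype, different window_buffer back-references: the new
# reflection field keeps these structurally distinct.
assert not ir.structural_equal(t_a, t_b)
```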
@@ -711,6 +735,32 @@ inline PtrTypePtr GetPtrType() {
return ptr_type;
}

/**
* @brief Singleton marker type for ``pld.alloc_window_buffer`` results.
*
* Carries no per-instance fields; all allocation metadata (size, host-staging
* flags, etc.) lives on the :class:`WindowBuffer` Var subclass that the alloc
* op binds. Cross-rank op verifiers dispatch on this marker
* (``As<WindowBufferType>``) to reject non-window arguments.
*/
class WindowBufferType : public Type {
public:
WindowBufferType() = default;

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::WindowBufferType; }
[[nodiscard]] std::string TypeName() const override { return "WindowBufferType"; }

static constexpr auto GetFieldDescriptors() { return Type::GetFieldDescriptors(); }
};

using WindowBufferTypePtr = std::shared_ptr<const WindowBufferType>;

/// Get the shared singleton WindowBufferType instance.
inline WindowBufferTypePtr GetWindowBufferType() {
static const auto window_buffer_type = std::make_shared<WindowBufferType>();
return window_buffer_type;
}

} // namespace ir
} // namespace pypto
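
As a usage note, the verifier dispatch this comment describes might look like
the following from the Python side (a sketch; the `arg.type` attribute name is
an assumption):

```python
from pypto import ir  # assumed import path

def check_window_arg(arg):
    # Mirror of the C++ As<WindowBufferType> check: every WindowBuffer shares
    # the singleton marker type, so a type check is enough to reject
    # non-window arguments.
    if not isinstance(arg.type, ir.WindowBufferType):
        raise TypeError("expected a CommGroup window-buffer argument")
```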

44 changes: 33 additions & 11 deletions python/bindings/modules/ir.cpp
@@ -454,6 +454,11 @@ void BindIR(nb::module_& m) {
nb::init<const std::vector<int64_t>&, DataType, std::optional<MemRefPtr>, std::optional<TensorView>>(),
nb::arg("shape"), nb::arg("dtype"), nb::arg("memref") = nb::none(), nb::arg("tensor_view") = nb::none(),
"Create a distributed tensor type with constant shape, optional memref and tensor_view");
dist_tensor_type_class.def(
nb::init<std::vector<ExprPtr>, DataType, WindowBufferPtr>(), nb::arg("shape"), nb::arg("dtype"),
nb::arg("window_buffer"),
"Create a distributed tensor type produced by pld.window; window_buffer is the back-"
"reference to the source WindowBuffer allocation.");
BindFields<DistributedTensorType>(dist_tensor_type_class);

// TileType - const shared_ptr
@@ -507,6 +512,18 @@ void BindIR(nb::module_& m) {
ptr_type_class.def_static("get", &GetPtrType, "Get the singleton PtrType instance");
BindFields<PtrType>(ptr_type_class);

// WindowBufferType - singleton marker type for pld.alloc_window_buffer outputs.
// Mirrors MemRefType: no per-instance fields. The WindowBuffer Var subclass
// (bindings below) carries the allocation metadata.
auto window_buffer_type_class = nb::class_<WindowBufferType, Type>(
ir, "WindowBufferType",
"Singleton marker type for pld.alloc_window_buffer outputs. The companion WindowBuffer "
"Var subclass carries the allocation metadata (base, size, host-staging flags).");
window_buffer_type_class.def(nb::init<>(),
"Construct a WindowBufferType (prefer WindowBufferType.get() for the shared singleton).");
window_buffer_type_class.def_static("get", &GetWindowBufferType,
"Get the shared singleton WindowBufferType instance.");
BindFields<WindowBufferType>(window_buffer_type_class);

// MemorySpace enum
nb::enum_<MemorySpace>(ir, "MemorySpace", "Memory space enumeration")
.value("DDR", MemorySpace::DDR, "DDR memory (off-chip)")
@@ -1487,18 +1504,23 @@ void BindIR(nb::module_& m) {
"List of CommGroups declared on the program. CommGroups participate "
"in structural equality / hashing through reflection.");

// CommGroup / WindowBuffer — IRNode-typed host-side metadata attached to a Program.
auto window_buffer_class = nb::class_<WindowBuffer, IRNode>(
// WindowBuffer — a specialised Var subclass that carries CommGroup window-buffer
// allocation metadata. Mirrors MemRef's Var-subclass shape; the inherited
// ``name_hint`` is taken from the wrapped ``base`` Var's ``name_hint_``.
auto window_buffer_class = nb::class_<WindowBuffer, Var>(
ir, "WindowBuffer",
"Per-rank allocation spec for a named CommGroup HCCL window buffer. "
"Maps 1:1 to simpler.task_interface.ChipBufferSpec at submit-time. "
"size_ is the per-rank element count; load_from_host_ / store_to_host_ "
"are bool flags marking host-staging participation (the actual host "
"tensor binding is recorded on the alloc op, not here).");
window_buffer_class.def(nb::init<std::string, ExprPtr, DataType, bool, bool, Span>(), nb::arg("name"),
nb::arg("size"), nb::arg("dtype"), nb::arg("load_from_host") = false,
nb::arg("store_to_host") = false, nb::arg("span") = Span::unknown(),
"Create a WindowBuffer.");
"Per-rank CommGroup window-buffer allocation, modelled as a specialised Var. "
"Its SSA-edge type is the singleton WindowBufferType; the allocation metadata "
"(name, size, dtype, host-staging flags) lives on the Var subclass directly — "
"the exact mirror of how MemRef carries (base, byte_offset, size) under MemRefType. "
"Constructed by the comm-collection pass; the alloc op's LHS at parse time is a "
"plain Var(PtrType).");
window_buffer_class.def(nb::init<VarPtr, ExprPtr, bool, bool, Span>(), nb::arg("base"), nb::arg("size"),
nb::arg("load_from_host") = false, nb::arg("store_to_host") = false,
nb::arg("span") = Span::unknown(),
"Create a WindowBuffer wrapping the given Ptr Var. The buffer's "
"runtime-unique identifier flows through the inherited "
"Var.name_hint (taken from base.name_hint).");
BindFields<WindowBuffer>(window_buffer_class);

auto comm_group_class =