Merged
35 changes: 35 additions & 0 deletions docs/en/dev/ir/02-types.md
Original file line number Diff line number Diff line change
@@ -38,6 +38,41 @@ tensor_with_memref = ir.TensorType(shape, DataType.FP32, memref)
`TensorType.memory_space` is always `ir.Mem.DDR`. `MemRef` carries address,
size, and id; memory space is not stored on `MemRef` itself.

### DistributedTensorType

`DistributedTensorType` is a subclass of `TensorType` with its own precise
`ObjectKind`. It serves as the function-signature type for chip orchestrator /
InCore parameters that slice a CommGroup HCCL window buffer, and it exists so
that cross-rank op verifiers (introduced in later milestones) can reject plain
`Tensor` arguments: `As<TensorType>` does NOT match a `DistributedTensorType`
(precise `ObjectKind` semantics; see
[ir-kind-traits.md](../../../../.claude/rules/ir-kind-traits.md)). Use
`As<DistributedTensorType>` to dispatch on the distributed variant.

The DSL surface is `pld.DistributedTensor[[shape], dtype]`:

```python
import pypto.language.distributed as pld
import pypto.language as pl

@pl.function(type=pl.FunctionType.InCore)
def kernel(self, data: pld.DistributedTensor[[256], pl.FP32]): ...
```

At the IR level:

```python
t = ir.DistributedTensorType([64], DataType.FP32)
assert isinstance(t, ir.TensorType) # C++ inheritance preserved
# As<TensorType>(t) → null; As<DistributedTensorType>(t) → cast.
```

Allocation-side metadata (the buffer name, host-staging flags) lives on the
allocation op (`pld.alloc_window_buffer`, added in a later milestone) and on
the `ir.WindowBuffer` slot in `program.comm_groups`, not on the type itself.
Tile types do not have a distributed variant; cross-rank ops always operate
on `DistributedTensor`.
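The precise-match rule can be sketched in plain Python. This is an
illustrative model with stand-in classes and a toy `As` helper, not the real
`pypto` bindings or the C++ `KindTrait` machinery:

```python
from typing import Optional, Type, TypeVar

# Stand-in classes modeling the C++ hierarchy (not the real pypto API).
class TensorType:
    kind = "TensorType"

class DistributedTensorType(TensorType):
    kind = "DistributedTensorType"

T = TypeVar("T")

def As(cls: Type[T], obj: object) -> Optional[T]:
    """Precise-ObjectKind cast: succeeds only when the object's kind is
    exactly the kind registered for `cls`; a subclass with its own kind
    does NOT match the base trait, even though isinstance() would."""
    if isinstance(obj, cls) and obj.kind == cls.kind:
        return obj
    return None

t = DistributedTensorType()
assert As(TensorType, t) is None          # precise match: subclass rejected
assert As(DistributedTensorType, t) is t  # exact kind: cast succeeds
assert isinstance(t, TensorType)          # the inheritance relation still holds
```

This captures the distinction the verifiers rely on: `isinstance` (like C++
inheritance) still sees a `DistributedTensorType` as a `TensorType`, but the
kind-based cast does not.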

### TensorType with TensorView

Tensor with layout and stride information for optimized memory access.
33 changes: 33 additions & 0 deletions docs/zh-cn/dev/ir/02-types.md
@@ -38,6 +38,39 @@ tensor_with_memref = ir.TensorType(shape, DataType.FP32, memref)
`TensorType.memory_space` is always `ir.Mem.DDR`. `MemRef` only stores the
address, size, and id; the memory space is no longer stored on `MemRef` itself.

### DistributedTensorType

`DistributedTensorType` is a subclass of `TensorType` with its own precise
`ObjectKind`, used as the type annotation for chip orchestrator / InCore
parameters that slice a CommGroup HCCL window buffer. It exists so that
cross-rank op verifiers (introduced in later milestones) can statically reject
plain `Tensor` arguments: `As<TensorType>` does **not** match a
`DistributedTensorType` (precise `ObjectKind` matching semantics; see
[ir-kind-traits.md](../../../../.claude/rules/ir-kind-traits.md)); cross-rank
ops dispatch with `As<DistributedTensorType>`.

The DSL surface is `pld.DistributedTensor[[shape], dtype]`:

```python
import pypto.language.distributed as pld
import pypto.language as pl

@pl.function(type=pl.FunctionType.InCore)
def kernel(self, data: pld.DistributedTensor[[256], pl.FP32]): ...
```

At the IR level:

```python
t = ir.DistributedTensorType([64], DataType.FP32)
assert isinstance(t, ir.TensorType)  # C++ inheritance preserved
# As<TensorType>(t) → null; As<DistributedTensorType>(t) → successful cast
```

Allocation-side metadata (the buffer name, host-staging flags) lives on the
alloc op (`pld.alloc_window_buffer`, introduced in a later milestone) and on
the `ir.WindowBuffer` slot in `program.comm_groups`, **not** on the type
itself. Tile types have no distributed variant; cross-rank ops always operate
on `DistributedTensor`.

### TensorType with TensorView

A tensor with layout and stride information for optimized memory access.
3 changes: 3 additions & 0 deletions include/pypto/ir/core.h
@@ -109,12 +109,15 @@ enum class ObjectKind {
ScalarType,
ShapedType,
TensorType,
DistributedTensorType,
TileType,
TupleType,

// Other IR node kinds
Function,
Program,
WindowBuffer,
CommGroup,

// Op kinds
Op,
21 changes: 15 additions & 6 deletions include/pypto/ir/kind_traits.h
@@ -104,7 +104,11 @@ DEFINE_KIND_TRAIT(InlineStmt, ObjectKind::InlineStmt)
DEFINE_KIND_TRAIT(UnknownType, ObjectKind::UnknownType)
DEFINE_KIND_TRAIT(ScalarType, ObjectKind::ScalarType)
// ShapedType is both a concrete type and a base class - handled separately below
// TensorType: precise-match (DistributedTensorType is a subclass with its own
// ObjectKind, so As<TensorType>(dt) returns nullptr by design — see
// .claude/rules/ir-kind-traits.md and the L3 distributed plan).
DEFINE_KIND_TRAIT(TensorType, ObjectKind::TensorType)
DEFINE_KIND_TRAIT(DistributedTensorType, ObjectKind::DistributedTensorType)
DEFINE_KIND_TRAIT(TileType, ObjectKind::TileType)
DEFINE_KIND_TRAIT(TupleType, ObjectKind::TupleType)
DEFINE_KIND_TRAIT(MemRefType, ObjectKind::MemRefType)
@@ -113,6 +117,8 @@ DEFINE_KIND_TRAIT(PtrType, ObjectKind::PtrType)
// Other IR node types
DEFINE_KIND_TRAIT(Function, ObjectKind::Function)
DEFINE_KIND_TRAIT(Program, ObjectKind::Program)
DEFINE_KIND_TRAIT(WindowBuffer, ObjectKind::WindowBuffer)
DEFINE_KIND_TRAIT(CommGroup, ObjectKind::CommGroup)

// Op kinds
DEFINE_KIND_TRAIT(Op, ObjectKind::Op)
@@ -189,19 +195,22 @@ struct KindTrait<UnaryExpr> {
// Type base class - matches any type kind
template <>
struct KindTrait<Type> {
static constexpr ObjectKind kinds[] = {ObjectKind::UnknownType, ObjectKind::ScalarType,
ObjectKind::ShapedType, ObjectKind::TensorType,
ObjectKind::TileType, ObjectKind::TupleType};
static constexpr ObjectKind kinds[] = {ObjectKind::UnknownType,
ObjectKind::ScalarType,
ObjectKind::ShapedType,
ObjectKind::TensorType,
ObjectKind::DistributedTensorType,
ObjectKind::TileType,
ObjectKind::TupleType};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};

// ShapedType can be used as both a concrete type and a base class
// It matches itself, TensorType, and TileType
// It matches itself, TensorType, DistributedTensorType, and TileType
template <>
struct KindTrait<ShapedType> {
// For base class matching: includes ShapedType, TensorType, TileType
static constexpr ObjectKind kinds[] = {ObjectKind::ShapedType, ObjectKind::TensorType,
ObjectKind::TileType};
ObjectKind::DistributedTensorType, ObjectKind::TileType};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};

117 changes: 114 additions & 3 deletions include/pypto/ir/program.h
@@ -19,6 +19,7 @@
#include <utility>
#include <vector>

#include "pypto/core/dtype.h"
#include "pypto/ir/core.h"
#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
@@ -28,6 +29,91 @@
namespace pypto {
namespace ir {

/**
* @brief Per-rank allocation spec for one named CommGroup HCCL window buffer.
*
* Maps 1:1 to ``simpler.task_interface.ChipBufferSpec`` at submit-time. Pure
* allocation metadata: does NOT describe how the buffer is used in code.
* Code-level use is expressed in the function signature via
* ``pld.DistributedTensor[[shape], dtype]``; the alloc op
* (``pld.alloc_window_buffer``, added in N2) materialises one of these slots
* into the program's :class:`CommGroup`.
*
* ``size_`` is the **element count** of one rank's slice (a single scalar; this
* struct is allocation-only and intentionally does not carry a multi-dim
* shape). ``size_`` may be a ``ConstInt`` (compile-time known) or a symbolic
* expression referring to the world size.
*
* ``load_from_host_`` / ``store_to_host_`` are simple boolean flags marking
* whether the slot participates in pre-fork H2D / post-task D2H staging. The
* specific host tensor that supplies / receives the staged data is recorded
* on the alloc op, not on this allocation spec.
*/
class WindowBuffer : public IRNode {
public:
std::string name_; ///< Buffer name (parser-extracted from alloc-op LHS)
ExprPtr size_; ///< Per-rank element count (ConstInt or symbolic Expr)
DataType dtype_; ///< Element data type
bool load_from_host_ = false; ///< Pre-fork H2D copy from a host staging tensor
bool store_to_host_ = false; ///< Post-task D2H copy back into a host staging tensor

WindowBuffer(std::string name, ExprPtr size, DataType dtype, bool load_from_host = false,
bool store_to_host = false, Span span = Span::unknown())
: IRNode(std::move(span)),
name_(std::move(name)),
size_(std::move(size)),
dtype_(dtype),
load_from_host_(load_from_host),
store_to_host_(store_to_host) {}

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::WindowBuffer; }
[[nodiscard]] std::string TypeName() const override { return "WindowBuffer"; }

static constexpr auto GetFieldDescriptors() {
return std::tuple_cat(
IRNode::GetFieldDescriptors(),
std::make_tuple(reflection::UsualField(&WindowBuffer::name_, "name"),
reflection::UsualField(&WindowBuffer::size_, "size"),
reflection::UsualField(&WindowBuffer::dtype_, "dtype"),
reflection::UsualField(&WindowBuffer::load_from_host_, "load_from_host"),
reflection::UsualField(&WindowBuffer::store_to_host_, "store_to_host")));
}
};

using WindowBufferPtr = std::shared_ptr<const WindowBuffer>;

/**
* @brief A communication group inferred for a ``@pl.program``.
*
* The ``CollectCommGroups`` pass (N4) builds these from
* ``pld.alloc_window_buffer`` ops and their dispatch coverage. The runtime
* (``distributed_runner``) uses this to compose ``ChipBootstrapConfig`` before
* bringing the workers up.
*
* ``devices_`` is the ascending-sorted set of physical device ids covered by
* the group. **An empty vector means "all devices"** (every entry of
* ``DistributedConfig.device_ids``, resolved by the driver at submit-time).
*/
class CommGroup : public IRNode {
public:
std::vector<int64_t> devices_; ///< Covered device ids (ascending); empty = all devices
std::vector<WindowBufferPtr> slots_; ///< Allocation slots in this group (alloc-order)

CommGroup(std::vector<int64_t> devices, std::vector<WindowBufferPtr> slots, Span span = Span::unknown())
: IRNode(std::move(span)), devices_(std::move(devices)), slots_(std::move(slots)) {}

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::CommGroup; }
[[nodiscard]] std::string TypeName() const override { return "CommGroup"; }

static constexpr auto GetFieldDescriptors() {
return std::tuple_cat(IRNode::GetFieldDescriptors(),
std::make_tuple(reflection::UsualField(&CommGroup::devices_, "devices"),
reflection::UsualField(&CommGroup::slots_, "slots")));
}
};

using CommGroupPtr = std::shared_ptr<const CommGroup>;

/**
* @brief Program definition
*
@@ -52,6 +138,16 @@ class Program : public IRNode {
Program(std::map<GlobalVarPtr, FunctionPtr, GlobalVarPtrLess> functions, std::string name, Span span)
: IRNode(std::move(span)), functions_(std::move(functions)), name_(std::move(name)) {}

/**
* @brief Map-based ctor with CommGroup metadata (used by the deserializer).
*/
Program(std::map<GlobalVarPtr, FunctionPtr, GlobalVarPtrLess> functions,
std::vector<CommGroupPtr> comm_groups, std::string name, Span span)
: IRNode(std::move(span)),
functions_(std::move(functions)),
name_(std::move(name)),
comm_groups_(std::move(comm_groups)) {}

/**
* @brief Create a program from a list of functions
*
@@ -64,6 +160,17 @@
*/
Program(const std::vector<FunctionPtr>& functions, std::string name, Span span);

/**
* @brief Create a program from a list of functions and CommGroup metadata.
*
* @param functions List of functions
* @param comm_groups List of CommGroups declared on the program
* @param name Program name (optional)
* @param span Source location
*/
Program(const std::vector<FunctionPtr>& functions, std::vector<CommGroupPtr> comm_groups, std::string name,
Span span);

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Program; }
[[nodiscard]] std::string TypeName() const override { return "Program"; }

@@ -84,19 +191,23 @@
[[nodiscard]] GlobalVarPtr GetGlobalVar(const std::string& name) const;

/**
* @brief Get field descriptors for reflection-based visitation
* @brief Get field descriptors for reflection-based visitation.
*
* @return Tuple of field descriptors (name as IGNORE field, functions as USUAL field)
* ``comm_groups_`` participates in structural equality / hashing via
* ``UsualField``: two programs declaring the same CommGroups (same names,
* sizes, dtypes, host-staging flags) are structurally equivalent.
*/
static constexpr auto GetFieldDescriptors() {
return std::tuple_cat(IRNode::GetFieldDescriptors(),
std::make_tuple(reflection::IgnoreField(&Program::name_, "name"),
reflection::UsualField(&Program::functions_, "functions")));
reflection::UsualField(&Program::functions_, "functions"),
reflection::UsualField(&Program::comm_groups_, "comm_groups")));
}

public:
std::string name_; // Program name
std::map<GlobalVarPtr, FunctionPtr, GlobalVarPtrLess> functions_; // Map of GlobalVars to Functions
std::vector<CommGroupPtr> comm_groups_; // CommGroups (host-side metadata)
};

using ProgramPtr = std::shared_ptr<const Program>;
8 changes: 8 additions & 0 deletions include/pypto/ir/serialization/type_registry.h
@@ -45,6 +45,14 @@ class DeserializerContext {
virtual IRNodePtr DeserializeNode(const msgpack::object& obj, msgpack::zone& zone) = 0;
virtual msgpack::object GetFieldObj(const msgpack::object& fields_obj, const std::string& field_name) = 0;

/**
* @brief Whether ``field_name`` is present in ``fields_obj``.
*
* Lets deserializers handle optional / forward-compatible fields without
* tripping ``GetFieldObj``'s missing-field exception.
*/
virtual bool HasField(const msgpack::object& fields_obj, const std::string& field_name) = 0;

template <typename T>
T GetField(const msgpack::object& fields_obj, const std::string& field_name) {
msgpack::object field_obj = GetFieldObj(fields_obj, field_name);
44 changes: 44 additions & 0 deletions include/pypto/ir/type.h
@@ -486,6 +486,50 @@ class TensorType : public ShapedType {

using TensorTypePtr = std::shared_ptr<const TensorType>;

/**
* @brief Distributed tensor type — a per-rank slice of a CommGroup HCCL window buffer.
*
* Subclass of :class:`TensorType` distinguished only by ``ObjectKind`` so that
* verifiers can reject plain ``TensorType`` arguments to cross-rank ops
* (``pld.tile.remote_load`` / ``pld.system.notify`` / ``pld.system.wait``).
*
* Note ``As<TensorType>`` does NOT match ``DistributedTensorType`` (precise
* ObjectKind match). This is intentional — the cross-rank ops use
* ``As<DistributedTensorType>`` to enforce that only window-bound tensors flow
* through them.
*/
class DistributedTensorType : public TensorType {
public:
DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype) : TensorType(std::move(shape), dtype) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, MemRefPtr memref)
: TensorType(std::move(shape), dtype, std::move(memref)) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref)
: TensorType(std::move(shape), dtype, std::move(memref)) {}

DistributedTensorType(std::vector<ExprPtr> shape, DataType dtype, std::optional<MemRefPtr> memref,
std::optional<TensorView> tensor_view)
: TensorType(std::move(shape), dtype, std::move(memref), std::move(tensor_view)) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype)
: TensorType(shape, dtype, std::nullopt) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref)
: TensorType(shape, dtype, std::move(memref)) {}

DistributedTensorType(const std::vector<int64_t>& shape, DataType dtype, std::optional<MemRefPtr> memref,
std::optional<TensorView> tensor_view)
: TensorType(shape, dtype, std::move(memref), std::move(tensor_view)) {}

[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::DistributedTensorType; }
[[nodiscard]] std::string TypeName() const override { return "DistributedTensorType"; }

static constexpr auto GetFieldDescriptors() { return TensorType::GetFieldDescriptors(); }
};

using DistributedTensorTypePtr = std::shared_ptr<const DistributedTensorType>;

/**
* @brief Tile type representation
*