[RFC] L3 Distributed Programming Interface Design #1127

@YunjiQin

Summary

Propose the L3 (HOST-level) distributed programming interface for PyPTO. This RFC defines a three-level execution hierarchy (HOST → CHIP → CORE_GROUP) with two complementary writing styles: functional (separate @pl.function methods) and inline (pl.at() context manager). The goal is to collect feedback on API ergonomics, naming, and design trade-offs before the implementation lands.

Motivation

PyPTO currently supports single-device programming (InCore kernels + Chip Orchestration). To enable multi-device and HOST-level workflows — such as dispatching chip tasks, verifying results on CPU, or aggregating outputs from multiple devices — we need a distributed programming interface.

Key requirements:

  • Express multi-level hierarchies (HOST → CHIP → CORE_GROUP) in the DSL
  • Support SubWorkers: HOST-level pure-Python tasks that run after chip tasks complete (verification, reduction, post-processing)
  • Provide both a modular (functional) and a compact (inline) writing style
  • Integrate with the existing ir.compile() pipeline and simpler's Worker(level=3) runtime

Design

Hierarchy Model

Three execution levels, each with an optional role:

| Level | Keyword | Role | Description |
|-------|---------|------|-------------|
| L0–L1 | pl.Level.CORE_GROUP | | Tile-level vector operations (InCore) |
| L2 | pl.Level.CHIP | Orchestrator | Device-level task dispatch |
| L3 | pl.Level.HOST | Orchestrator / Worker | HOST-level distributed scheduling |

Roles (L3+ only):

  • pl.Role.Orchestrator — Builds DAG, dispatches chip tasks and SubWorkers
  • pl.Role.Worker — Executes pure-Python compute/validation (SubWorker)
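
To make the level/role vocabulary concrete, here is a minimal, self-contained sketch of how the two enumerations could be modelled. It does not use PyPTO itself: the class names `Level` and `Role` mirror `pl.Level` and `pl.Role`, but the integer values and the `is_sub_worker` helper are assumptions made for illustration. The SubWorker predicate follows the RFC's description of a SubWorker as a Worker-role function at HOST level or above.

```python
from enum import Enum, IntEnum
from typing import Optional


class Level(IntEnum):
    """Hierarchy levels; the integer values are assumed for this sketch."""
    CORE_GROUP = 1  # stands in for L0-L1 in the table above
    CHIP = 2
    HOST = 3


class Role(Enum):
    ORCHESTRATOR = "orchestrator"  # builds the DAG and dispatches work
    WORKER = "worker"              # runs pure-Python compute or validation


def is_sub_worker(level: Level, role: Optional[Role]) -> bool:
    """Hypothetical helper: a SubWorker is a Worker at HOST level or above."""
    return role is Role.WORKER and level >= Level.HOST
```

Using an ordered enum for levels keeps a check such as "level >= HOST" a plain comparison, which may matter if levels above L3 are added later.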

Style 1: Functional (Separate Functions)

Each hierarchy level is a separate @pl.function-decorated method. Explicit, modular, reusable.

import pypto.language as pl
import torch
from pypto import ir
from pypto.ir.distributed_compiled_program import DistributedConfig

@pl.program
class L3DependencyProgram:
    """L3: HOST orch → CHIP worker (a + b) → SubWorker (verify)."""

    @pl.function(type=pl.FunctionType.InCore)
    def tile_add(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        tile_a = pl.load(a, [0, 0], [128, 128])
        tile_b = pl.load(b, [0, 0], [128, 128])
        tile_f = pl.add(tile_a, tile_b)
        out_f = pl.store(tile_f, [0, 0], f)
        return out_f

    @pl.function(type=pl.FunctionType.Orchestration)
    def chip_orch(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        out_f = self.tile_add(a, b, f)
        return out_f

    @pl.function(level=pl.Level.HOST, role=pl.Role.Worker)
    def verify(self, f: pl.Tensor[[128, 128], pl.FP32]):
        expected = torch.full((128, 128), 5.0, dtype=torch.float32)
        if not torch.allclose(f, expected, rtol=1e-5, atol=1e-5):
            raise AssertionError(
                f"SubWorker verify failed: expected 5.0, "
                f"got max={f.max().item()}, min={f.min().item()}"
            )

    @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator)
    def host_orch(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        out_f: pl.Tensor[[128, 128], pl.FP32] = self.chip_orch(a, b, f)
        self.verify(out_f)
        return out_f

Key characteristics:

  • Each level is a separate method — clear, testable in isolation
  • level= / role= on @pl.function specifies the hierarchy position
  • SubWorker body (verify) is pure Python — can use torch, numpy, etc.
  • SubWorker inputs/outputs defined via standard function signature with pl.Tensor type annotations
  • HOST Orchestrator (host_orch) dispatches chip work via self.chip_orch(...) and SubWorker via self.verify(...)

Style 2: Inline (pl.at() Context Manager)

All hierarchy levels are inlined into a single @pl.function using nested with pl.at(...) blocks. Compact, self-contained.

@pl.program
class L3DependencyInlineProgram:
    """L3: all levels inlined into host_orch via pl.at() using tensor API."""

    @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator)
    def host_orch(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        # CHIP orchestration scope
        with pl.at(level=pl.Level.CHIP, role=pl.Role.Orchestrator):
            # CORE_GROUP scope (tile-level computation)
            with pl.at(level=pl.Level.CORE_GROUP):
                out = pl.add(a, b)
                out_f = pl.assemble(f, out, [0, 0])

        # SubWorker scope (verification on HOST)
        with pl.at(level=pl.Level.HOST, role=pl.Role.Worker):
            expected = torch.full((128, 128), 5.0, dtype=torch.float32)
            if not torch.allclose(out_f, expected, rtol=1e-5, atol=1e-5):
                raise AssertionError(
                    f"SubWorker verify failed: expected 5.0, "
                    f"got max={out_f.max().item()}, min={out_f.min().item()}"
                )

        return out_f

Key characteristics:

  • Single method — all hierarchy levels visible at a glance
  • Nesting reflects the execution hierarchy: HOST → CHIP → CORE_GROUP
  • SubWorker body via with pl.at(level=pl.Level.HOST, role=pl.Role.Worker) — pure Python, source captured from AST
  • SubWorker inputs: variables referenced from parent scope are auto-detected as inputs
  • SubWorker outputs: declared via annotated assignments with a pl.Tensor type (e.g., result: pl.Tensor[[128, 128], pl.FP32] = ...); the parser uses these annotations to identify the output variables of the SubWorker scope
  • Variables flow across scope boundaries (e.g., out_f used in SubWorker block)

Comparison: Functional vs Inline

| Aspect | Functional | Inline (pl.at()) |
|--------|------------|------------------|
| Structure | Separate methods | Nested with blocks |
| Reusability | Methods reusable across programs | Scopes are local |
| Readability | Modular, each level self-contained | Compact, full picture in one function |
| SubWorker declaration | @pl.function(level=HOST, role=Worker) | with pl.at(level=HOST, role=Worker) |
| SubWorker inputs | Function parameters with pl.Tensor types | Auto-detected from parent scope variable references |
| SubWorker outputs | Function parameters with pl.Out[pl.Tensor[...]] | Annotated assignments: var: pl.Tensor[...] = expr |
| Best for | Complex programs with shared kernels | Simple pipelines, prototyping |

Both styles produce identical compiled output and runtime behavior.


Advanced Pattern: Parallel Chip Tasks + SubWorker Reduce

Multiple independent chip tasks dispatched by the HOST orchestrator, with a SubWorker aggregating results.

@pl.program
class L3ParallelReduceProgram:
    """L3: HOST orch → 2 independent chip workers → SubWorker (reduce)."""

    @pl.function(type=pl.FunctionType.InCore)
    def tile_add(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        tile_a = pl.load(a, [0, 0], [128, 128])
        tile_b = pl.load(b, [0, 0], [128, 128])
        tile_f = pl.add(tile_a, tile_b)
        return pl.store(tile_f, [0, 0], f)

    @pl.function(type=pl.FunctionType.InCore)
    def tile_sub(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        tile_a = pl.load(a, [0, 0], [128, 128])
        tile_b = pl.load(b, [0, 0], [128, 128])
        tile_f = pl.sub(tile_a, tile_b)
        return pl.store(tile_f, [0, 0], f)

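    @pl.function(type=pl.FunctionType.Orchestration)
    def chip_orch_add(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        return self.tile_add(a, b, f)

    @pl.function(type=pl.FunctionType.Orchestration)
    def chip_orch_sub(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        return self.tile_sub(a, b, f)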

    @pl.function(level=pl.Level.HOST, role=pl.Role.Worker)
    def reduce_sum(
        self,
        sum_ab: pl.Tensor[[128, 128], pl.FP32],
        diff_ab: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ):
        result = sum_ab + diff_ab  # f = (a+b) + (a-b) = 2a
        f[:] = result

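    @pl.function(level=pl.Level.HOST, role=pl.Role.Orchestrator)
    def host_orch(
        self,
        a: pl.Tensor[[128, 128], pl.FP32],
        b: pl.Tensor[[128, 128], pl.FP32],
        f: pl.Out[pl.Tensor[[128, 128], pl.FP32]],
    ) -> pl.Tensor[[128, 128], pl.FP32]:
        # Intermediate HOST tensors that receive the two chip outputs
        sum_ab = pl.create_tensor([128, 128], dtype=pl.FP32)
        diff_ab = pl.create_tensor([128, 128], dtype=pl.FP32)
        # Independent chip tasks: neither consumes the other's output
        out_sum = self.chip_orch_add(a, b, sum_ab)
        out_diff = self.chip_orch_sub(a, b, diff_ab)
        # SubWorker depends on both chip outputs
        self.reduce_sum(out_sum, out_diff, f)
        return f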

Key points:

  • chip_orch_add and chip_orch_sub are independent → runtime can parallelize them
  • reduce_sum (SubWorker) runs after both chip tasks complete
  • pl.create_tensor() allocates intermediate HOST tensors
  • SubWorker receives multiple chip outputs and writes the final result
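
The dependency shape described above (two independent producers, one consumer) can be illustrated without PyPTO. The following is a plain-Python sketch, not the runtime's scheduler: `chip_add`, `chip_sub`, and `reduce_sum` are stand-ins operating on lists, and the thread pool is an assumption chosen only to show that the reduce step waits on both inputs.

```python
from concurrent.futures import ThreadPoolExecutor


def chip_add(a, b):
    """Stand-in for the chip task computing a + b element-wise."""
    return [x + y for x, y in zip(a, b)]


def chip_sub(a, b):
    """Stand-in for the chip task computing a - b element-wise."""
    return [x - y for x, y in zip(a, b)]


def reduce_sum(sum_ab, diff_ab):
    """Stand-in for the SubWorker: (a + b) + (a - b) = 2a."""
    return [s + d for s, d in zip(sum_ab, diff_ab)]


def host_orch(a, b):
    """Submit both independent tasks, then reduce once both are done."""
    with ThreadPoolExecutor(max_workers=2) as pool:
        fut_sum = pool.submit(chip_add, a, b)
        fut_diff = pool.submit(chip_sub, a, b)
        # .result() blocks, so the reduce starts only after both finish
        return reduce_sum(fut_sum.result(), fut_diff.result())


result = host_orch([2.0, 2.0], [3.0, 3.0])
```

With a = 2.0 and b = 3.0 the reduce yields 4.0 per element, matching the 2a identity noted in the reduce_sum comment.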

API Surface

@pl.function Decorator Parameters

@pl.function(
    type: pl.FunctionType = ...,      # InCore, Orchestration
    level: pl.Level | None = None,    # HOST, CHIP, CORE_GROUP, ...
    role: pl.Role | None = None,      # Orchestrator, Worker
)

For L3 distributed functions, level= and role= are passed instead of type=:

  • level=pl.Level.HOST, role=pl.Role.Orchestrator → HOST entry point
  • level=pl.Level.HOST, role=pl.Role.Worker → SubWorker

pl.at() Context Manager

pl.at(
    level: pl.Level,                          # Target hierarchy level
    role: pl.Role | None = None,              # Function role (L3+ only)
    *,
    optimizations: list[Optimization] | None = None,
    name_hint: str = "",                      # Outlined function name
) -> AtContext

DistributedConfig

from pypto.ir.distributed_compiled_program import DistributedConfig

DistributedConfig(
    device_ids: list[int] = [0],        # Device IDs to use
    num_sub_workers: int = 0,           # SubWorker thread count
    runtime: str = "tensormap_and_ringbuffer",
    block_dim: int = 1,
    aicpu_thread_num: int = 1,
)
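
One way to model these fields is a dataclass, sketched below under the assumption that DistributedConfig behaves like a plain record of settings. The class name `DistributedConfigSketch` and both validation rules are hypothetical. The sketch uses `default_factory` for `device_ids` so that instances do not share one default list, which is a general Python concern with list defaults rather than a statement about the real class.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DistributedConfigSketch:
    """Hypothetical mirror of the DistributedConfig fields listed above."""
    device_ids: List[int] = field(default_factory=lambda: [0])
    num_sub_workers: int = 0
    runtime: str = "tensormap_and_ringbuffer"
    block_dim: int = 1
    aicpu_thread_num: int = 1

    def __post_init__(self):
        # Assumed sanity checks; the real class may validate differently.
        if not self.device_ids:
            raise ValueError("device_ids must not be empty")
        if self.num_sub_workers < 0:
            raise ValueError("num_sub_workers must be non-negative")


cfg = DistributedConfigSketch(
    device_ids=[7], num_sub_workers=1, block_dim=3, aicpu_thread_num=4
)
```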

Compilation and Execution

compiled = ir.compile(
    L3DependencyProgram,
    platform="a2a3sim",
    distributed_config=DistributedConfig(
        device_ids=[7],
        num_sub_workers=1,
        block_dim=3,
        aicpu_thread_num=4,
    ),
)

a = torch.full((128, 128), 2.0, dtype=torch.float32)
b = torch.full((128, 128), 3.0, dtype=torch.float32)
f = torch.zeros((128, 128), dtype=torch.float32)

compiled(a, b, f)          # In-place style
# or
f = compiled(a, b)         # Return style (auto-allocates outputs)
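
The two calling conventions can be pictured with a small stand-in callable. This is only a sketch of the calling pattern on Python lists: `make_compiled` is a hypothetical helper, and how the real compiled program allocates outputs is not specified here.

```python
def make_compiled(length):
    """Return a hypothetical compiled callable computing f = a + b."""
    def compiled(a, b, f=None):
        if f is None:
            f = [0.0] * length      # return style: allocate the output
        for i in range(length):
            f[i] = a[i] + b[i]      # write into the output buffer in place
        return f
    return compiled


compiled_stub = make_compiled(2)

out = [0.0, 0.0]
compiled_stub([2.0, 2.0], [3.0, 3.0], out)         # in-place style
fresh = compiled_stub([2.0, 2.0], [3.0, 3.0])      # return style
```

Both calls produce the same values; the difference is only who owns the output buffer.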

SubWorker Mechanism

SubWorkers (level >= HOST, role = Worker) are pure-Python functions executed on the HOST after chip tasks complete.

Source capture:

  • Functional style: inspect.getsource() at @pl.program creation time
  • Inline style: AST-level source extraction from with pl.at() body
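
For the inline style, the extraction step might look roughly like the sketch below, which uses only Python's standard `ast` module. It is not the code in ast_parser.py: `is_pl_at` and `capture_bodies` are hypothetical helpers, and the sample source is parsed from a string, so `pl` and `check` never need to exist.

```python
import ast
import textwrap

SOURCE = textwrap.dedent('''
    def host_orch(out_f):
        with pl.at(level=pl.Level.HOST, role=pl.Role.Worker):
            expected = 5.0
            check(out_f, expected)
        return out_f
''')


def is_pl_at(with_node):
    """True when the with-statement's first item is a call to pl.at(...)."""
    call = with_node.items[0].context_expr
    return (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Attribute)
        and call.func.attr == "at"
        and isinstance(call.func.value, ast.Name)
        and call.func.value.id == "pl"
    )


def capture_bodies(source):
    """Return the dedented source text of every `with pl.at(...)` body."""
    lines = source.splitlines()
    bodies = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.With) and is_pl_at(node):
            first, last = node.body[0], node.body[-1]
            # lineno and end_lineno are 1-based and inclusive
            text = "\n".join(lines[first.lineno - 1:last.end_lineno])
            bodies.append(textwrap.dedent(text))
    return bodies


captured = capture_bodies(SOURCE)
```

The captured text is plain Python that could be written to a generated module; a real implementation would also need to filter on the level and role arguments.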

Input/output detection in inline style:

  • Inputs: AST-analyzed references to parent-scope variables
  • Outputs: Annotated assignments with pl.Tensor type (e.g., result: pl.Tensor[[128, 128], pl.FP32] = expr)
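
The two detection rules above can be approximated with the standard `ast` module, as in the sketch below. The `detect_io` helper and the sample body are hypothetical, and the sketch deliberately ignores edge cases such as a parent variable being reassigned inside the scope before it is read.

```python
import ast

BODY = '''
scaled = out_f * scale
result: pl.Tensor[[128, 128], pl.FP32] = scaled + bias
'''


def detect_io(body_source, parent_vars):
    """Return (inputs, outputs) for a SubWorker body.

    inputs:  names read in the body that exist in the parent scope
    outputs: targets of annotated assignments typed as pl.Tensor[...]
    """
    inputs, outputs = set(), []
    for node in ast.walk(ast.parse(body_source)):
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            if node.id in parent_vars:
                inputs.add(node.id)
        elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            # ast.unparse (Python 3.9+) turns the annotation back into text
            if ast.unparse(node.annotation).startswith("pl.Tensor"):
                outputs.append(node.target.id)
    return sorted(inputs), outputs


inputs, outputs = detect_io(BODY, parent_vars={"out_f", "scale", "bias"})
```

Here `scaled` is neither an input nor an output: it is assigned locally without a pl.Tensor annotation, so it stays private to the scope.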

Generated output structure:

output_dir/
├── orchestration/
│   └── host_orch.py           # Generated HOST orchestrator
├── next_levels/
│   └── chip_orch/             # Compiled chip-level task
│       ├── kernels/
│       └── kernel_config.py
└── sub_workers/
    └── verify.py              # Generated SubWorker callable

Runtime execution:

  1. Compile chip-level tasks → simpler.ChipCallable
  2. Load generated SubWorker Python modules
  3. Create simpler.Worker(level=3), register SubWorkers
  4. Execute HOST orchestration closure via worker.run(Task(...))
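
Step 2 (loading generated SubWorker modules) can be done with the standard `importlib` machinery. The sketch below writes a stand-in sub_workers/verify.py into a temporary directory and imports it by path. The `load_sub_worker` helper and the contents of the generated file are assumptions; the simpler Worker and Task APIs are intentionally not modelled here.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_sub_worker(path, func_name):
    """Import a generated module from a file path and return one callable."""
    spec = importlib.util.spec_from_file_location(path.stem, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, func_name)


with tempfile.TemporaryDirectory() as out_dir:
    sub_dir = Path(out_dir) / "sub_workers"
    sub_dir.mkdir()
    # Hypothetical generated SubWorker; the real generated code may differ.
    (sub_dir / "verify.py").write_text(
        "def verify(f):\n"
        "    if any(abs(x - 5.0) > 1e-5 for x in f):\n"
        "        raise AssertionError('SubWorker verify failed')\n"
        "    return True\n"
    )
    verify = load_sub_worker(sub_dir / "verify.py", "verify")

ok = verify([5.0, 5.0])
```

The loaded callable stays usable after the directory is removed because the module has already been executed into memory.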

Affected Files

Key implementation files:

| File | Purpose |
|------|---------|
| python/pypto/language/dsl_api.py | pl.at() context manager |
| python/pypto/language/parser/decorator.py | @pl.function, @pl.program, SubWorker registry |
| python/pypto/language/parser/ast_parser.py | AST parsing, _parse_sub_worker_scope() |
| python/pypto/ir/compile.py | Distributed program detection, compilation entry |
| python/pypto/ir/distributed_compiled_program.py | DistributedConfig, DistributedCompiledProgram |
| python/pypto/backend/pto_backend.py | Distributed code generation |
| python/pypto/runtime/distributed_runner.py | Distributed execution via simpler Worker |

Open Questions

  1. Naming: Is pl.Role.Worker clear enough for SubWorker semantics? Should it be pl.Role.SubWorker to avoid confusion with the runtime Worker class?

  2. SubWorker constraints: Should SubWorkers be restricted to pure-Python (no DSL ops), or should they support a subset of DSL operations for HOST-level tensor manipulation?

  3. Variable flow in inline style: Variables cross with pl.at() boundaries (e.g., out_f computed inside CHIP scope, used in SubWorker scope). Is implicit data flow clear enough, or should we require explicit declarations?

  4. Output annotation convention: In the inline style, SubWorker outputs use annotated assignments (var: pl.Tensor[...] = expr). Is this convention intuitive, or should there be a more explicit pl.output(...) marker?

  5. Error handling: How should errors in SubWorkers be surfaced? Currently they raise Python exceptions — should there be a structured error reporting mechanism?

  6. Multi-device: The current design uses a single device_ids list. How should we express per-chip-task device affinity for multi-device workflows?

  7. Inline style limitations: Should the inline style support the advanced pattern (multiple chip tasks + SubWorker reduce), or is the functional style sufficient for complex workflows?
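
To give question 5 something concrete to react to, here is one possible shape for structured SubWorker error reporting. It is a discussion aid, not a proposal that exists in the codebase: `SubWorkerError` and `run_sub_worker` are hypothetical names, and the choice to return errors instead of re-raising is only one of several options.

```python
import traceback
from dataclasses import dataclass


@dataclass
class SubWorkerError:
    """Hypothetical structured record of a SubWorker failure."""
    name: str       # which SubWorker failed
    exc_type: str   # exception class name
    message: str    # str(exception)
    trace: str      # formatted traceback for logs


def run_sub_worker(name, fn, *args):
    """Run fn and return (result, None) or (None, SubWorkerError)."""
    try:
        return fn(*args), None
    except Exception as exc:
        return None, SubWorkerError(
            name, type(exc).__name__, str(exc), traceback.format_exc()
        )


def failing_verify(f):
    raise AssertionError(f"expected 5.0, got {f}")


value, error = run_sub_worker("verify", failing_verify, 4.0)
```

A record like this lets the HOST orchestrator attach the failing SubWorker's name to the report, which a bare Python exception does not carry by default.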
