From 2edfc698b44bb94ffd721a7513fa343fa5c56540 Mon Sep 17 00:00:00 2001 From: chenshengxin Date: Sat, 9 May 2026 19:20:53 +0800 Subject: [PATCH] feat(api): support address-based pipe slot model + tfree(entry, split) Extend the Python frontend so kernels can express the address-based pipe shape that ptoas and pto-isa already accept (PR #606 / hw-native-sys/pto-isa "update TALLOC/TPUSH/TPOP/TFREE to support push or pop GlobalTensor"): API additions (ptodsl/api/pto_general.py, ptodsl/api/pto.py): - aic_initialize_pipe / aiv_initialize_pipe: accept gm_slot_tensor (!pto.tensor_view<...>) instead of the legacy gm_slot_buffer + c2v/v2c_consumer_buf triplet, plus a local_slot_num attribute that mirrors the C++ TPipe template arg. The c2v_consumer_buf and v2c_consumer_buf operands become optional so kernels using the address-based form do not need to supply them. Falls back to the generic Operation.create path while the installed mlir.dialects.pto._pto_ops_gen binding still predates the new operand shape. - talloc_to_aic / talloc_to_aiv: emit the address-based slot allocation ops (each returns a tensor_view that subsequent tpush/tfree consume). - tfree_from_aic(entry, split=...) / tfree_from_aiv(entry, split=...): new entry+split overload that carries the popped tensor_view to the free op. Existing tfree(split, id=...) callers keep working unchanged. This is the operand whose absence turned address-based TFREE into a pipe-only no-op in ptoas-generated C++; carrying the entry restores real free notifications and unblocks long sequences (S1>=4096) that previously hung on slot exhaustion. - bitcast: thin wrapper around _pto.BitcastOp (e.g. for cross-cast reinterpreting a !pto.ptr<T> region as !pto.ptr<U>). Type system (ptodsl/api/type_def.py): - TensorType now accepts shape=[d0, d1, ...] for static-shape views, alongside the existing rank=N dynamic form. Static shape is required by gm_slot_tensor pipe init because the lowered C++ runtime templates on a concrete pto::Shape<...>.
Compiler (ptodsl/compiler/ir.py): - New PTODSL_SKIP_VERIFY env knob that suppresses the post-build ir_module.operation.verify(). This is a transitional escape hatch while the installed mlir-dialect verifier still rejects the address-based gm_slot_tensor init shape; it is meant to be removed once the dialect verifier/binding catches up to ptoas. Backwards compatibility: - All legacy callers (gm_slot_buffer + c2v/v2c, split-only tfree, rank-only TensorType) keep working without modification. The new API surface is purely additive. Motivating downstream: - hw-native-sys/pto-isa#117 (PTO-DSL Flash Attention performance kernel) consumes this API to align with the manual fa_performance_kernel.cpp. --- ptodsl/api/pto.py | 6 ++ ptodsl/api/pto_general.py | 147 +++++++++++++++++++++++++++++++++----- ptodsl/api/type_def.py | 15 +++- ptodsl/compiler/ir.py | 4 +- 4 files changed, 152 insertions(+), 20 deletions(-) diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index 0b90f0b6..d9964b43 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -5,6 +5,7 @@ add_ptr, alloc_tile, aic_initialize_pipe, + bitcast, aiv_initialize_pipe, as_tensor, call, @@ -22,6 +23,8 @@ reserve_buffer, slice_view, store, + talloc_to_aic, + talloc_to_aiv, tfree_from_aic, tfree_from_aiv, tpop_from_aic, @@ -71,6 +74,7 @@ "set_ffts", "add_ptr", "as_tensor", + "bitcast", "slice_view", "vector_section", "cube_section", @@ -87,6 +91,8 @@ "load_scalar", "load", "store", + "talloc_to_aic", + "talloc_to_aiv", "tpush_to_aiv", "tpush_to_aic", "tpop_from_aic", diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index a7f25593..6d5e57c8 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -1,7 +1,15 @@ from contextlib import contextmanager from mlir.dialects import pto as _pto -from mlir.ir import FlatSymbolRefAttr, InsertionPoint, Operation +from mlir.ir import ( + Attribute, + DenseI32ArrayAttr, + FlatSymbolRefAttr, + InsertionPoint, + IntegerAttr, + IntegerType, + Operation, 
+) from .scalar import Value, _unwrap @@ -44,6 +52,27 @@ def _resolve_peer_func_attr(peer_func): return peer_func +def _int_attr(value, bits): + if isinstance(value, Attribute): + return value + return IntegerAttr.get(IntegerType.get_signless(bits), value) + + +def _maybe_int_attr(attrs, name, value, bits): + if value is not None: + attrs[name] = _int_attr(value, bits) + + +def _bool_attr(value): + if isinstance(value, Attribute): + return value + return Attribute.parse("true" if value else "false") + + +def _operand_segment_attr(*sizes): + return DenseI32ArrayAttr.get(sizes) + + def call(callee, *args): return Operation.create( "func.call", @@ -64,6 +93,11 @@ def add_ptr(ptr, offset): return _pto.AddPtrOp(_unwrap(ptr), _unwrap(offset)).result +def bitcast(result_type, src): + """Reinterpret-cast a value, e.g. `!pto.ptr` -> `!pto.ptr`.""" + return _pto.BitcastOp(result_type, _unwrap(src)).result + + def as_tensor(tensor_type, *, ptr, shape, strides, layout=None): shape_vals = [_unwrap(v) for v in shape] stride_vals = [_unwrap(v) for v in strides] @@ -150,19 +184,40 @@ def aic_initialize_pipe( *, dir_mask, slot_size, - gm_slot_buffer=None, # only needed on a2/a3? - c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_buffer=None, # !pto.ptr<...> address-only slot + gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606) + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, # TPipe template arg (default 2 in C++); ptoas expands + # to SlotNum if absent, which inflates the FFTS event + # multiplex (see manual fa_performance_kernel.cpp). 
nosplit=None, ): + if gm_slot_tensor is not None: + attrs = { + "dir_mask": _int_attr(dir_mask, 8), + "operandSegmentSizes": _operand_segment_attr(1), + "slot_size": _int_attr(slot_size, 32), + } + _maybe_int_attr(attrs, "id", id, 32) + _maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32) + if nosplit is not None: + attrs["nosplit"] = _bool_attr(nosplit) + return Operation.create( + "pto.aic_initialize_pipe", + operands=[_unwrap(gm_slot_tensor)], + attributes=attrs, + ) return _pto.AicInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), + c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None, + v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None, + gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None, + gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None, id=id, + local_slot_num=local_slot_num, nosplit=nosplit, ) @@ -176,19 +231,38 @@ def aiv_initialize_pipe( *, dir_mask, slot_size, - gm_slot_buffer=None, # only needed on a2/a3 - c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_buffer=None, # !pto.ptr<...> address-only slot + gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606) + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, # mirrors aic_initialize_pipe; see comment there. 
nosplit=None, ): + if gm_slot_tensor is not None: + attrs = { + "dir_mask": _int_attr(dir_mask, 8), + "operandSegmentSizes": _operand_segment_attr(1), + "slot_size": _int_attr(slot_size, 32), + } + _maybe_int_attr(attrs, "id", id, 32) + _maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32) + if nosplit is not None: + attrs["nosplit"] = _bool_attr(nosplit) + return Operation.create( + "pto.aiv_initialize_pipe", + operands=[_unwrap(gm_slot_tensor)], + attributes=attrs, + ) return _pto.AivInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), + c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None, + v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None, + gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None, + gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None, id=id, + local_slot_num=local_slot_num, nosplit=nosplit, ) @@ -217,7 +291,7 @@ def initialize_l2g2l_pipe( slot_size, slot_num, _unwrap(gm_addr), - _unwrap(local_addr), + local_addr=_unwrap(local_addr), local_slot_num=local_slot_num, flag_base=flag_base, peer_local_addr=( @@ -283,13 +357,50 @@ def tpop_from_aiv(tile_type, split, *, id=None): return _pto.TPopFromAivOp(tile_type, split, id=id).result +# %entry = pto.talloc_to_aiv {split = 1} -> !pto.tensor_view<...> +def talloc_to_aiv(view_type, split, *, id=None): + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.talloc_to_aiv", + results=[view_type], + attributes=attrs, + ).results[0] + + +def talloc_to_aic(view_type, split, *, id=None): + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.talloc_to_aic", + results=[view_type], + attributes=attrs, + ).results[0] + + # pto.tfree_from_aic {split = 0} 
-def tfree_from_aic(split, *, id=None): - return _pto.TFreeFromAicOp(split, id=id) +def tfree_from_aic(entry_or_split, split=None, *, id=None): + if split is None: + return _pto.TFreeFromAicOp(entry_or_split, id=id) + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.tfree_from_aic", + operands=[_unwrap(entry_or_split)], + attributes=attrs, + ) -def tfree_from_aiv(split, *, id=None): - return _pto.TFreeFromAivOp(split, id=id) +def tfree_from_aiv(entry_or_split, split=None, *, id=None): + if split is None: + return _pto.TFreeFromAivOp(entry_or_split, id=id) + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.tfree_from_aiv", + operands=[_unwrap(entry_or_split)], + attributes=attrs, + ) def load_scalar(result_type, ptr, offset): diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py index a379e952..cfa4da38 100644 --- a/ptodsl/api/type_def.py +++ b/ptodsl/api/type_def.py @@ -28,7 +28,20 @@ def PtrType(dtype): return _pto.PtrType.get(dtype) -def TensorType(*, rank, dtype): +def TensorType(*, rank=None, shape=None, dtype): + """Build a `!pto.tensor_view`. + + Pass ``rank`` for a dynamic-shape view (``?x?xfp32``) or ``shape`` for a + statically-shaped one (``128x32xfp32``). The static form is required when + the lowered C++ runtime needs concrete `pto::Shape<...>` template params, + e.g. for the address-based slot model added by PR #606. 
+ """ + if shape is not None and rank is not None: + raise ValueError("TensorType: pass either rank or shape, not both") + if shape is not None: + return _pto.TensorViewType.get(list(shape), dtype) + if rank is None: + raise ValueError("TensorType: pass rank or shape") return _pto.TensorViewType.get(rank, dtype) diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py index 0820e9f5..4e7a500c 100644 --- a/ptodsl/compiler/ir.py +++ b/ptodsl/compiler/ir.py @@ -1,4 +1,5 @@ import inspect +import os from mlir.dialects import func, pto as _pto from mlir.ir import Attribute, Context, InsertionPoint, Location, Module, UnitAttr @@ -160,7 +161,8 @@ def decorator(fn): else: _define(ir_module, ctx, meta_map, fn) - ir_module.operation.verify() + if os.environ.get("PTODSL_SKIP_VERIFY") not in ("1", "true", "TRUE", "yes", "YES"): + ir_module.operation.verify() return ir_module return decorator