diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index 0b90f0b6..d9964b43 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -5,6 +5,7 @@ add_ptr, alloc_tile, aic_initialize_pipe, + bitcast, aiv_initialize_pipe, as_tensor, call, @@ -22,6 +23,8 @@ reserve_buffer, slice_view, store, + talloc_to_aic, + talloc_to_aiv, tfree_from_aic, tfree_from_aiv, tpop_from_aic, @@ -71,6 +74,7 @@ "set_ffts", "add_ptr", "as_tensor", + "bitcast", "slice_view", "vector_section", "cube_section", @@ -87,6 +91,8 @@ "load_scalar", "load", "store", + "talloc_to_aic", + "talloc_to_aiv", "tpush_to_aiv", "tpush_to_aic", "tpop_from_aic", diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index a7f25593..6d5e57c8 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -1,7 +1,15 @@ from contextlib import contextmanager from mlir.dialects import pto as _pto -from mlir.ir import FlatSymbolRefAttr, InsertionPoint, Operation +from mlir.ir import ( + Attribute, + DenseI32ArrayAttr, + FlatSymbolRefAttr, + InsertionPoint, + IntegerAttr, + IntegerType, + Operation, +) from .scalar import Value, _unwrap @@ -44,6 +52,27 @@ def _resolve_peer_func_attr(peer_func): return peer_func +def _int_attr(value, bits): + if isinstance(value, Attribute): + return value + return IntegerAttr.get(IntegerType.get_signless(bits), value) + + +def _maybe_int_attr(attrs, name, value, bits): + if value is not None: + attrs[name] = _int_attr(value, bits) + + +def _bool_attr(value): + if isinstance(value, Attribute): + return value + return Attribute.parse("true" if value else "false") + + +def _operand_segment_attr(*sizes): + return DenseI32ArrayAttr.get(sizes) + + def call(callee, *args): return Operation.create( "func.call", @@ -64,6 +93,11 @@ def add_ptr(ptr, offset): return _pto.AddPtrOp(_unwrap(ptr), _unwrap(offset)).result +def bitcast(result_type, src): + """Reinterpret-cast a value, e.g. `!pto.ptr` -> `!pto.ptr`.""" + return _pto.BitcastOp(result_type, _unwrap(src)).result + + def as_tensor(tensor_type, *, ptr, shape, strides, layout=None): shape_vals = [_unwrap(v) for v in shape] stride_vals = [_unwrap(v) for v in strides] @@ -150,19 +184,40 @@ def aic_initialize_pipe( *, dir_mask, slot_size, - gm_slot_buffer=None, # only needed on a2/a3? - c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_buffer=None, # !pto.ptr<...> address-only slot + gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606) + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, # TPipe template arg (default 2 in C++); ptoas expands + # to SlotNum if absent, which inflates the FFTS event + # multiplex (see manual fa_performance_kernel.cpp). nosplit=None, ): + if gm_slot_tensor is not None: + attrs = { + "dir_mask": _int_attr(dir_mask, 8), + "operandSegmentSizes": _operand_segment_attr(1), + "slot_size": _int_attr(slot_size, 32), + } + _maybe_int_attr(attrs, "id", id, 32) + _maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32) + if nosplit is not None: + attrs["nosplit"] = _bool_attr(nosplit) + return Operation.create( + "pto.aic_initialize_pipe", + operands=[_unwrap(gm_slot_tensor)], + attributes=attrs, + ) return _pto.AicInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), + c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None, + v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None, + gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None, + gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None, id=id, + local_slot_num=local_slot_num, nosplit=nosplit, ) @@ -176,19 +231,38 @@ def aiv_initialize_pipe( *, dir_mask, slot_size, - gm_slot_buffer=None, # only needed on a2/a3 - c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_buffer=None, # !pto.ptr<...> address-only slot + gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606) + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, # mirrors aic_initialize_pipe; see comment there. nosplit=None, ): + if gm_slot_tensor is not None: + attrs = { + "dir_mask": _int_attr(dir_mask, 8), + "operandSegmentSizes": _operand_segment_attr(1), + "slot_size": _int_attr(slot_size, 32), + } + _maybe_int_attr(attrs, "id", id, 32) + _maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32) + if nosplit is not None: + attrs["nosplit"] = _bool_attr(nosplit) + return Operation.create( + "pto.aiv_initialize_pipe", + operands=[_unwrap(gm_slot_tensor)], + attributes=attrs, + ) return _pto.AivInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), + c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None, + v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None, + gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None, + gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None, id=id, + local_slot_num=local_slot_num, nosplit=nosplit, ) @@ -217,7 +291,7 @@ def initialize_l2g2l_pipe( slot_size, slot_num, _unwrap(gm_addr), - _unwrap(local_addr), + local_addr=_unwrap(local_addr), local_slot_num=local_slot_num, flag_base=flag_base, peer_local_addr=( @@ -283,13 +357,50 @@ def tpop_from_aiv(tile_type, split, *, id=None): return _pto.TPopFromAivOp(tile_type, split, id=id).result +# %entry = pto.talloc_to_aiv {split = 1} -> !pto.tensor_view<...> +def talloc_to_aiv(view_type, split, *, id=None): + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.talloc_to_aiv", + results=[view_type], + attributes=attrs, + ).results[0] + + +def talloc_to_aic(view_type, split, *, id=None): + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.talloc_to_aic", + results=[view_type], + attributes=attrs, + ).results[0] + + # pto.tfree_from_aic {split = 0} -def tfree_from_aic(split, *, id=None): - return _pto.TFreeFromAicOp(split, id=id) +def tfree_from_aic(entry_or_split, split=None, *, id=None): + if split is None: + return _pto.TFreeFromAicOp(entry_or_split, id=id) + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.tfree_from_aic", + operands=[_unwrap(entry_or_split)], + attributes=attrs, + ) -def tfree_from_aiv(split, *, id=None): - return _pto.TFreeFromAivOp(split, id=id) +def tfree_from_aiv(entry_or_split, split=None, *, id=None): + if split is None: + return _pto.TFreeFromAivOp(entry_or_split, id=id) + attrs = {"split": _int_attr(split, 8)} + _maybe_int_attr(attrs, "id", id, 32) + return Operation.create( + "pto.tfree_from_aiv", + operands=[_unwrap(entry_or_split)], + attributes=attrs, + ) def load_scalar(result_type, ptr, offset): diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py index a379e952..cfa4da38 100644 --- a/ptodsl/api/type_def.py +++ b/ptodsl/api/type_def.py @@ -28,7 +28,20 @@ def PtrType(dtype): return _pto.PtrType.get(dtype) -def TensorType(*, rank, dtype): +def TensorType(*, rank=None, shape=None, dtype): + """Build a `!pto.tensor_view`. + + Pass ``rank`` for a dynamic-shape view (``?x?xfp32``) or ``shape`` for a + statically-shaped one (``128x32xfp32``). The static form is required when + the lowered C++ runtime needs concrete `pto::Shape<...>` template params, + e.g. for the address-based slot model added by PR #606. + """ + if shape is not None and rank is not None: + raise ValueError("TensorType: pass either rank or shape, not both") + if shape is not None: + return _pto.TensorViewType.get(list(shape), dtype) + if rank is None: + raise ValueError("TensorType: pass rank or shape") return _pto.TensorViewType.get(rank, dtype) diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py index 0820e9f5..4e7a500c 100644 --- a/ptodsl/compiler/ir.py +++ b/ptodsl/compiler/ir.py @@ -1,4 +1,5 @@ import inspect +import os from mlir.dialects import func, pto as _pto from mlir.ir import Attribute, Context, InsertionPoint, Location, Module, UnitAttr @@ -160,7 +161,8 @@ def decorator(fn): else: _define(ir_module, ctx, meta_map, fn) - ir_module.operation.verify() + if os.environ.get("PTODSL_SKIP_VERIFY") not in ("1", "true", "TRUE", "yes", "YES"): + ir_module.operation.verify() return ir_module return decorator