Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ptodsl/api/pto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
add_ptr,
alloc_tile,
aic_initialize_pipe,
bitcast,
aiv_initialize_pipe,
as_tensor,
call,
Expand All @@ -22,6 +23,8 @@
reserve_buffer,
slice_view,
store,
talloc_to_aic,
talloc_to_aiv,
tfree_from_aic,
tfree_from_aiv,
tpop_from_aic,
Expand Down Expand Up @@ -71,6 +74,7 @@
"set_ffts",
"add_ptr",
"as_tensor",
"bitcast",
"slice_view",
"vector_section",
"cube_section",
Expand All @@ -87,6 +91,8 @@
"load_scalar",
"load",
"store",
"talloc_to_aic",
"talloc_to_aiv",
"tpush_to_aiv",
"tpush_to_aic",
"tpop_from_aic",
Expand Down
147 changes: 129 additions & 18 deletions ptodsl/api/pto_general.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from contextlib import contextmanager

from mlir.dialects import pto as _pto
from mlir.ir import FlatSymbolRefAttr, InsertionPoint, Operation
from mlir.ir import (
Attribute,
DenseI32ArrayAttr,
FlatSymbolRefAttr,
InsertionPoint,
IntegerAttr,
IntegerType,
Operation,
)

from .scalar import Value, _unwrap

Expand Down Expand Up @@ -44,6 +52,27 @@ def _resolve_peer_func_attr(peer_func):
return peer_func


def _int_attr(value, bits):
if isinstance(value, Attribute):
return value
return IntegerAttr.get(IntegerType.get_signless(bits), value)


def _maybe_int_attr(attrs, name, value, bits):
if value is not None:
attrs[name] = _int_attr(value, bits)


def _bool_attr(value):
if isinstance(value, Attribute):
return value
return Attribute.parse("true" if value else "false")


def _operand_segment_attr(*sizes):
return DenseI32ArrayAttr.get(sizes)


def call(callee, *args):
return Operation.create(
"func.call",
Expand All @@ -64,6 +93,11 @@ def add_ptr(ptr, offset):
return _pto.AddPtrOp(_unwrap(ptr), _unwrap(offset)).result


def bitcast(result_type, src):
"""Reinterpret-cast a value, e.g. `!pto.ptr<f32>` -> `!pto.ptr<f16>`."""
return _pto.BitcastOp(result_type, _unwrap(src)).result


def as_tensor(tensor_type, *, ptr, shape, strides, layout=None):
shape_vals = [_unwrap(v) for v in shape]
stride_vals = [_unwrap(v) for v in strides]
Expand Down Expand Up @@ -150,19 +184,40 @@ def aic_initialize_pipe(
*,
dir_mask,
slot_size,
gm_slot_buffer=None, # only needed on a2/a3?
c2v_consumer_buf,
v2c_consumer_buf,
gm_slot_buffer=None, # !pto.ptr<...> address-only slot
gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606)
c2v_consumer_buf=None,
v2c_consumer_buf=None,
id=None,
local_slot_num=None, # TPipe template arg (default 2 in C++); ptoas expands
# to SlotNum if absent, which inflates the FFTS event
# multiplex (see manual fa_performance_kernel.cpp).
nosplit=None,
):
if gm_slot_tensor is not None:
attrs = {
"dir_mask": _int_attr(dir_mask, 8),
"operandSegmentSizes": _operand_segment_attr(1),
"slot_size": _int_attr(slot_size, 32),
}
_maybe_int_attr(attrs, "id", id, 32)
_maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32)
if nosplit is not None:
attrs["nosplit"] = _bool_attr(nosplit)
return Operation.create(
"pto.aic_initialize_pipe",
operands=[_unwrap(gm_slot_tensor)],
attributes=attrs,
)
return _pto.AicInitializePipeOp(
dir_mask,
slot_size,
c2v_consumer_buf=_unwrap(c2v_consumer_buf),
v2c_consumer_buf=_unwrap(v2c_consumer_buf),
gm_slot_buffer=_unwrap(gm_slot_buffer),
c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None,
v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None,
gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None,
gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None,
id=id,
local_slot_num=local_slot_num,
nosplit=nosplit,
)

Expand All @@ -176,19 +231,38 @@ def aiv_initialize_pipe(
*,
dir_mask,
slot_size,
gm_slot_buffer=None, # only needed on a2/a3
c2v_consumer_buf,
v2c_consumer_buf,
gm_slot_buffer=None, # !pto.ptr<...> address-only slot
gm_slot_tensor=None, # !pto.tensor_view<...> for address-based slot model (PR #606)
c2v_consumer_buf=None,
v2c_consumer_buf=None,
id=None,
local_slot_num=None, # mirrors aic_initialize_pipe; see comment there.
nosplit=None,
):
if gm_slot_tensor is not None:
attrs = {
"dir_mask": _int_attr(dir_mask, 8),
"operandSegmentSizes": _operand_segment_attr(1),
"slot_size": _int_attr(slot_size, 32),
}
_maybe_int_attr(attrs, "id", id, 32)
_maybe_int_attr(attrs, "local_slot_num", local_slot_num, 32)
if nosplit is not None:
attrs["nosplit"] = _bool_attr(nosplit)
return Operation.create(
"pto.aiv_initialize_pipe",
operands=[_unwrap(gm_slot_tensor)],
attributes=attrs,
)
return _pto.AivInitializePipeOp(
dir_mask,
slot_size,
c2v_consumer_buf=_unwrap(c2v_consumer_buf),
v2c_consumer_buf=_unwrap(v2c_consumer_buf),
gm_slot_buffer=_unwrap(gm_slot_buffer),
c2v_consumer_buf=_unwrap(c2v_consumer_buf) if c2v_consumer_buf is not None else None,
v2c_consumer_buf=_unwrap(v2c_consumer_buf) if v2c_consumer_buf is not None else None,
gm_slot_buffer=_unwrap(gm_slot_buffer) if gm_slot_buffer is not None else None,
gm_slot_tensor=_unwrap(gm_slot_tensor) if gm_slot_tensor is not None else None,
id=id,
local_slot_num=local_slot_num,
nosplit=nosplit,
)

Expand Down Expand Up @@ -217,7 +291,7 @@ def initialize_l2g2l_pipe(
slot_size,
slot_num,
_unwrap(gm_addr),
_unwrap(local_addr),
local_addr=_unwrap(local_addr),
local_slot_num=local_slot_num,
flag_base=flag_base,
peer_local_addr=(
Expand Down Expand Up @@ -283,13 +357,50 @@ def tpop_from_aiv(tile_type, split, *, id=None):
return _pto.TPopFromAivOp(tile_type, split, id=id).result


# %entry = pto.talloc_to_aiv {split = 1} -> !pto.tensor_view<...>
def talloc_to_aiv(view_type, split, *, id=None):
attrs = {"split": _int_attr(split, 8)}
_maybe_int_attr(attrs, "id", id, 32)
return Operation.create(
"pto.talloc_to_aiv",
results=[view_type],
attributes=attrs,
).results[0]


def talloc_to_aic(view_type, split, *, id=None):
attrs = {"split": _int_attr(split, 8)}
_maybe_int_attr(attrs, "id", id, 32)
return Operation.create(
"pto.talloc_to_aic",
results=[view_type],
attributes=attrs,
).results[0]


# pto.tfree_from_aic {split = 0}
def tfree_from_aic(split, *, id=None):
return _pto.TFreeFromAicOp(split, id=id)
def tfree_from_aic(entry_or_split, split=None, *, id=None):
if split is None:
return _pto.TFreeFromAicOp(entry_or_split, id=id)
attrs = {"split": _int_attr(split, 8)}
_maybe_int_attr(attrs, "id", id, 32)
return Operation.create(
"pto.tfree_from_aic",
operands=[_unwrap(entry_or_split)],
attributes=attrs,
)


def tfree_from_aiv(split, *, id=None):
return _pto.TFreeFromAivOp(split, id=id)
def tfree_from_aiv(entry_or_split, split=None, *, id=None):
if split is None:
return _pto.TFreeFromAivOp(entry_or_split, id=id)
attrs = {"split": _int_attr(split, 8)}
_maybe_int_attr(attrs, "id", id, 32)
return Operation.create(
"pto.tfree_from_aiv",
operands=[_unwrap(entry_or_split)],
attributes=attrs,
)


def load_scalar(result_type, ptr, offset):
Expand Down
15 changes: 14 additions & 1 deletion ptodsl/api/type_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,20 @@ def PtrType(dtype):
return _pto.PtrType.get(dtype)


def TensorType(*, rank, dtype):
def TensorType(*, rank=None, shape=None, dtype):
"""Build a `!pto.tensor_view`.

Pass ``rank`` for a dynamic-shape view (``?x?xfp32``) or ``shape`` for a
statically-shaped one (``128x32xfp32``). The static form is required when
the lowered C++ runtime needs concrete `pto::Shape<...>` template params,
e.g. for the address-based slot model added by PR #606.
"""
if shape is not None and rank is not None:
raise ValueError("TensorType: pass either rank or shape, not both")
if shape is not None:
return _pto.TensorViewType.get(list(shape), dtype)
if rank is None:
raise ValueError("TensorType: pass rank or shape")
return _pto.TensorViewType.get(rank, dtype)


Expand Down
4 changes: 3 additions & 1 deletion ptodsl/compiler/ir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os

from mlir.dialects import func, pto as _pto
from mlir.ir import Attribute, Context, InsertionPoint, Location, Module, UnitAttr
Expand Down Expand Up @@ -160,7 +161,8 @@ def decorator(fn):
else:
_define(ir_module, ctx, meta_map, fn)

ir_module.operation.verify()
if os.environ.get("PTODSL_SKIP_VERIFY") not in ("1", "true", "TRUE", "yes", "YES"):
ir_module.operation.verify()
return ir_module

return decorator
Expand Down