diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py index 7f4e9d54..93b2770a 100644 --- a/ptodsl/api/scalar.py +++ b/ptodsl/api/scalar.py @@ -3,7 +3,7 @@ def _unwrap(value): - if isinstance(value, Value): + if isinstance(value, Value) or hasattr(value, "raw"): return value.raw return value @@ -79,6 +79,8 @@ def __getattr__(self, item): def wrap_value(value): if isinstance(value, Value): return value + if hasattr(value, "raw"): + return Value(value.raw) return Value(value) diff --git a/ptodsl/language.py b/ptodsl/language.py index de97b725..5e9cab1d 100644 --- a/ptodsl/language.py +++ b/ptodsl/language.py @@ -8,11 +8,18 @@ def _unwrap(value): - if isinstance(value, Value): + if isinstance(value, Value) or hasattr(value, "raw"): return value.raw return value +def _unwrap_index(value): + value = _unwrap(value) + if isinstance(value, int): + return arith.ConstantOp(IndexType.get(), value).result + return value + + class Value: # TODO: generalize to more comprehensive wrappers like https://github.com/makslevental/mlir-python-extras/blob/0.0.8.2/mlir/extras/dialects/ext/arith.py def __init__(self, raw): @@ -83,6 +90,8 @@ def __getattr__(self, item): def wrap_value(value): if isinstance(value, Value): return value + if hasattr(value, "raw"): + return Value(value.raw) return Value(value) @@ -322,16 +331,16 @@ def index_cast(value, index_type=IndexType): def as_tensor(tensor_type, *, ptr, shape, strides): - shape_vals = [_unwrap(v) for v in shape] - stride_vals = [_unwrap(v) for v in strides] + shape_vals = [_unwrap_index(v) for v in shape] + stride_vals = [_unwrap_index(v) for v in strides] return pto.MakeTensorViewOp( tensor_type, _unwrap(ptr), shape_vals, stride_vals ).result def slice_view(subtensor_type, *, source, offsets, sizes): - offset_vals = [_unwrap(v) for v in offsets] - size_vals = [_unwrap(v) for v in sizes] + offset_vals = [_unwrap_index(v) for v in offsets] + size_vals = [_unwrap_index(v) for v in sizes] return pto.PartitionViewOp( subtensor_type, source, offsets=offset_vals, sizes=size_vals ).result diff --git a/ptodsl/lib/a5/README.md b/ptodsl/lib/a5/README.md index baa2c9ee..204bb8e7 100644 --- a/ptodsl/lib/a5/README.md +++ b/ptodsl/lib/a5/README.md @@ -1,18 +1,28 @@ # A5 Library Layer -This directory contains a first PTODSL library-style translation layer for the -`pto-isa/include/pto/npu/a5` surface. +This directory contains a PTODSL library-style translation layer for the +`pto-isa/include/pto/npu/a5` surface, organized around the PTO tile opcodes that +each file re-expresses with PTO micro instructions. 
-The scope of this pass is: +The scope of this layout is: -- Pythonic wrappers over PTO tile ops and selected micro instructions -- A5-flavored compatibility aliases such as `TLoad`, `TAdd`, `TMatmul`, and `TStore` -- Translated builder kernels that emit `.pto` through PTODSL +- Small, readable files that show how a tile helper is written from PTO micro + opcodes such as `pto.vlds`, `pto.vadd`, and `pto.vsts` +- A5-flavored aliases such as `TLoad`, `TAdd`, `TMatmul`, and `TStore` +- Example builder kernels that emit `.pto` through PTODSL - A checked-in generation flow for reproducible `.pto` artifacts Entry points: -- [`ops.py`](./ops.py): reusable A5-style helpers built on PTODSL and PTO dialect ops +- [`tbinary.py`](./tbinary.py): tile binary helpers such as `tadd`, `tsub`, `tmul`, + `tdiv`, and `tor_`, written with PTO vector micro ops +- [`tunary.py`](./tunary.py): tile unary helpers such as `texp`, `tlog`, `trelu`, + `tsqrt`, `trsqrt`, and `trecip` +- [`texpand.py`](./texpand.py): row and column broadcast helpers +- [`treduce.py`](./treduce.py): row and column reduction helpers +- [`tsort.py`](./tsort.py): gather and sort helpers +- [`native.py`](./native.py): helpers that still map directly to tile/cube ops +- [`ops.py`](./ops.py): the public A5 surface that re-exports the split helpers - [`kernels.py`](./kernels.py): translated example kernels - [`generated`](./generated): emitted `.pto` artifacts from `scripts/generate_a5_pto.py` diff --git a/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md index 42f4175f..cd7cc994 100644 --- a/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md +++ b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md @@ -1,43 +1,45 @@ # Tile Micro Coverage -- Total public tile ops: `32` +- Total public tile ops: `34` - Implemented: `26` - Partial: `1` - Pending: `0` -- Blocked: `4` +- Blocked: `6` - Not applicable: `1` -| tile op | status | helper | note | -| --- | --- | --- | --- | -| `mov` | `implemented` | `mov_micro` | UB stage + vlds/vsts copy loop. | -| `add` | `implemented` | `add_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering. | -| `sub` | `implemented` | `sub_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering. | -| `div` | `implemented` | `div_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering. | -| `mul` | `implemented` | `mul_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering. | -| `or_` | `implemented` | `or_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering. | -| `gather` | `partial` | `gather_micro` | Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs unsupported vsqz-style micro support. | -| `exp` | `implemented` | `exp_micro` | UB stage + vlds/vexp/vsts loop. | -| `log` | `implemented` | `log_micro` | UB stage + vlds/vln/vsts loop. | -| `relu` | `implemented` | `relu_micro` | UB stage + vlds/vrelu/vsts loop. | -| `abs` | `implemented` | `abs_micro` | UB stage + vlds/vabs/vsts loop. | -| `sqrt` | `implemented` | `sqrt_micro` | UB stage + vlds/vsqrt/vsts loop. | -| `rsqrt` | `implemented` | `rsqrt_micro` | UB stage + vsqrt/vrec micro sequence. | -| `reciprocal` | `implemented` | `reciprocal_micro` | UB stage + vlds/vrec/vsts loop. | -| `matmul` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | -| `matmul_bias` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. 
| -| `matmul_acc` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | -| `extract` | `blocked` | `-` | Layout/L0 extraction op, not a vector-micro compute rewrite. | -| `row_sum` | `implemented` | `row_sum_micro` | Static-shape row reduction via vcadd + point-store. | -| `row_min` | `implemented` | `row_min_micro` | Static-shape row reduction via vcmin + point-store. | -| `row_max` | `implemented` | `row_max_micro` | Static-shape row reduction via vcmax + point-store. | -| `row_expand` | `implemented` | `row_expand_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. | -| `row_expand_sub` | `implemented` | `row_expand_sub_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. | -| `row_expand_div` | `implemented` | `row_expand_div_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. | -| `row_expand_mul` | `implemented` | `row_expand_mul_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. | -| `col_sum` | `implemented` | `col_sum_micro` | Static-shape TColReduceOps-style column reduction via vadd. | -| `col_min` | `implemented` | `col_min_micro` | Static-shape TColReduceOps-style column reduction via vmin. | -| `col_max` | `implemented` | `col_max_micro` | Static-shape TColReduceOps-style column reduction via vmax. | -| `col_expand` | `implemented` | `col_expand_micro` | Static-shape canonical broadcast via vlds/vsts replication. | -| `mrgsort` | `implemented` | `mrgsort_micro` | Single-list row-major merge sort via vmrgsort4. | -| `sort32` | `implemented` | `sort32_micro` | Static-shape block sort via vbitsort. | -| `subset` | `not_applicable` | `-` | View helper only, not a tile compute op. | +| tile op | helper | note | +| --- | --- | --- | +| `mov` | `tmov` | UB stage + vlds/vsts copy loop. | +| `add` | `tadd` | UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering. | +| `sub` | `tsub` | UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering. | +| `div` | `tdiv` | UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering. | +| `mul` | `tmul` | UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering. | +| `or_` | `tor_` | UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering. | +| `gather` | `tgather` | Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs vsqz-style micro support, which is not yet available. | +| `exp` | `texp` | UB stage + vlds/vexp/vsts loop. | +| `log` | `tlog` | UB stage + vlds/vln/vsts loop. | +| `relu` | `trelu` | UB stage + vlds/vrelu/vsts loop. | +| `abs` | `tabs` | UB stage + vlds/vabs/vsts loop. | +| `sqrt` | `tsqrt` | UB stage + vlds/vsqrt/vsts loop. | +| `rsqrt` | `trsqrt` | UB stage + vsqrt/vrec sequence. | +| `reciprocal` | `trecip` | UB stage + vlds/vrec/vsts loop. | +| `matmul` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `matmul_bias` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `matmul_acc` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `extract` | `-` | Layout/L0 extraction op, not a vector-micro compute rewrite. | +| `row_sum` | `trow_sum` | Static-shape row reduction via vcadd + point-store. | +| `row_min` | `trow_min` | Static-shape row reduction via vcmin + point-store. | +| `row_max` | `trow_max` | Static-shape row reduction via vcmax + point-store. | +| `row_prod` | `-` | No row-product micro lowering is wired yet. 
| +| `row_expand` | `trow_expand` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. | +| `row_expand_sub` | `trow_expand_sub` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. | +| `row_expand_div` | `trow_expand_div` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. | +| `row_expand_mul` | `trow_expand_mul` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. | +| `col_sum` | `tcol_sum` | Static-shape TColReduceOps-style column reduction via vadd. | +| `col_min` | `tcol_min` | Static-shape TColReduceOps-style column reduction via vmin. | +| `col_max` | `tcol_max` | Static-shape TColReduceOps-style column reduction via vmax. | +| `col_prod` | `-` | No column-product micro lowering is wired yet. | +| `col_expand` | `tcol_expand` | Static-shape canonical broadcast via vlds/vsts replication. | +| `mrgsort` | `tmrgsort` | Single-list row-major merge sort via vmrgsort4. | +| `sort32` | `tsort32` | Static-shape block sort via vbitsort. | +| `subset` | `-` | View helper only, not a tile compute op. | diff --git a/ptodsl/lib/a5/__init__.py b/ptodsl/lib/a5/__init__.py index 61670f55..3b273e59 100644 --- a/ptodsl/lib/a5/__init__.py +++ b/ptodsl/lib/a5/__init__.py @@ -1,11 +1,11 @@ -from . import ops +from . import native, ops, tbinary, texpand, treduce, tsort, tunary from .kernels import ( KERNEL_BUILDERS, build_cube_matmul, build_elementwise_add, - build_micro_vector_copy, build_mxfp8_matmul, build_templated_elementwise_add, + build_vector_copy, ) from .ops import * from .tile_micro_coverage import ( @@ -19,9 +19,15 @@ "TILE_MICRO_COVERAGE", "build_cube_matmul", "build_elementwise_add", - "build_micro_vector_copy", "build_mxfp8_matmul", "build_templated_elementwise_add", + "build_vector_copy", "coverage_markdown", "coverage_summary", + "native", + "tbinary", + "texpand", + "treduce", + "tsort", + "tunary", ] diff --git a/ptodsl/lib/a5/_common.py b/ptodsl/lib/a5/_common.py new file mode 100644 index 00000000..9114639f --- /dev/null +++ b/ptodsl/lib/a5/_common.py @@ -0,0 +1,628 @@ +"""Shared A5 helpers for writing tile-style kernels with PTO micro instructions.""" + +import builtins +import re + +from mlir.dialects import arith, pto +from mlir.ir import IntegerAttr, IntegerType + +from ... 
import const_expr, language as dsl, range_constexpr, scalar as s +from ...api.scalar import _unwrap + +VF_IMPL_DEFAULT = "default" +VF_IMPL_1D_NO_POST_UPDATE = "1d_no_post_update" +VF_IMPL_1D_POST_UPDATE = "1d_post_update" +VF_IMPL_2D_NO_POST_UPDATE = "2d_no_post_update" +VF_IMPL_2D_POST_UPDATE = "2d_post_update" + +_DTYPE_ALIAS_GROUPS = { + "f32": {"f32", "float32"}, + "f16": {"f16", "float16", "half"}, + "bf16": {"bf16", "bfloat16"}, + "u32": {"u32", "ui32", "uint32"}, + "u16": {"u16", "uint16"}, + "u8": {"u8", "uint8"}, + "i32": {"i32", "int32"}, + "i16": {"i16", "int16"}, + "i8": {"i8", "int8"}, +} + + +def _call(op, *args, **kwargs): + return op( + *(_unwrap(arg) for arg in args), + **{name: _unwrap(value) for name, value in kwargs.items()}, + ) + + +def raw(value): + return _unwrap(value) + + +def _space_enum(space): + return getattr(pto.AddressSpace, str(space).upper()) + + +def ptr(dtype, *, space="GM"): + return pto.PtrType.get(dtype, _space_enum(space)) + + +def vreg_type(lanes, dtype): + return pto.VRegType.get(lanes, dtype) + + +def mask_type(): + return pto.MaskType.get() + + +def align_type(): + return pto.AlignType.get() + + +def uint32_type(): + return IntegerType.get_unsigned(32) + + +def const_i64(value): + i64 = IntegerType.get_signless(64) + return arith.ConstantOp(i64, IntegerAttr.get(i64, value)).result + + +def const_i32(value): + i32 = IntegerType.get_signless(32) + return arith.ConstantOp(i32, IntegerAttr.get(i32, value)).result + + +def const_float(dtype, value): + return arith.ConstantOp(dtype, value).result + + +def row_major_strides(shape): + strides = [None] * len(shape) + stride = s.const(1) + for index in range(len(shape) - 1, -1, -1): + strides[index] = stride + dim = shape[index] + stride = stride * (s.const(dim) if isinstance(dim, int) else dim) + return strides + + +def _index_value(value): + return s.const(value) if isinstance(value, int) else value + + +def make_tensor(ptr_value, *, shape, dtype): + tensor_type = dsl.TensorType(rank=len(shape), dtype=dtype) + return dsl.as_tensor( + tensor_type, + ptr=_unwrap(ptr_value), + shape=[_index_value(dim) for dim in shape], + strides=row_major_strides(shape), + ) + + +def slice_tensor(source, *, offsets, sizes, dtype): + subtensor_type = dsl.SubTensorType(shape=sizes, dtype=dtype) + return dsl.slice_view( + subtensor_type, + source=_unwrap(source), + offsets=[_index_value(offset) for offset in offsets], + sizes=[_index_value(size) for size in sizes], + ) + + +def alloc_tile_buffer( + dtype, + shape, + *, + space="VEC", + valid_shape=None, + config=None, + addr=None, + valid_row=None, + valid_col=None, +): + tile_type = dsl.TileBufType( + shape=shape, + dtype=dtype, + memory_space=space, + valid_shape=valid_shape, + config=config, + ) + kwargs = {} + if addr is not None: + kwargs["addr"] = _unwrap(addr) + if valid_row is not None: + kwargs["valid_row"] = _unwrap(valid_row) + if valid_col is not None: + kwargs["valid_col"] = _unwrap(valid_col) + return pto.AllocTileOp(tile_type, **kwargs).result + + +def load_view(source, dest): + pto.TLoadOp(None, source, dest) + return dest + + +def store_view(source, dest): + pto.TStoreOp(None, source, dest) + return dest + + +def move_tile(source, dest): + pto.TMovOp(None, source, dest) + return dest + + +def load_tile( + view, + tile_buffer=None, + *, + dtype=None, + shape=None, + space="VEC", + valid_shape=None, + config=None, +): + if tile_buffer is None: + if dtype is None or shape is None: + raise ValueError( + "`load_tile(...)` requires either `tile_buffer=` or both 
`dtype=` and `shape=`." + ) + tile_buffer = alloc_tile_buffer( + dtype, + shape, + space=space, + valid_shape=valid_shape, + config=config, + ) + load_view(view, tile_buffer) + return tile_buffer + + +def store_tile(tile_buffer, view): + store_view(tile_buffer, view) + return view + + +def dtype_token(dtype): + text = str(dtype).lower() + for canonical, aliases in _DTYPE_ALIAS_GROUPS.items(): + if any(alias in text for alias in aliases): + return canonical + raise ValueError(f"Unsupported dtype token for '{dtype}'.") + + +def dtype_byte_width(dtype): + token = dtype_token(dtype) + if token in {"f32", "i32", "u32"}: + return 4 + if token in {"f16", "bf16", "i16", "u16"}: + return 2 + if token in {"i8", "u8"}: + return 1 + raise ValueError(f"Unsupported dtype byte width for '{dtype}'.") + + +def micro_lane_count(dtype): + return 256 // dtype_byte_width(dtype) + + +def resolve_lanes(dtype, lanes): + return micro_lane_count(dtype) if lanes is None else lanes + + +def extract_static_tensor_shape(value): + raw = _unwrap(value) + type_obj = getattr(raw, "type", None) + if type_obj is None: + return None + text = str(type_obj) + match = re.search( + r"!pto\.(?:partition_)?tensor_view<(?P<payload>[^>]+)>|!pto\.tile_buf<[^,]+,\s*(?P<tile_payload>[^>]+)>", + text, + ) + if not match: + return None + payload = match.group("payload") or match.group("tile_payload") + dims = re.findall(r"(\?|\d+)x", payload) + if not dims: + return None + shape = [] + for dim in dims: + if dim == "?": + return None + shape.append(int(dim)) + return shape + + +def extract_tensor_dtype_token(value): + raw = _unwrap(value) + type_obj = getattr(raw, "type", None) + if type_obj is None: + return None + text = str(type_obj).lower() + for canonical, aliases in _DTYPE_ALIAS_GROUPS.items(): + if any(alias in text for alias in aliases): + return canonical + return None + + +def require_supported_dtype(dtype, *, allowed, message): + try: + token = dtype_token(dtype) + except ValueError as exc: + raise ValueError(message) from exc + if token not in allowed: + raise ValueError(message) + return token + + +def require_view_shape(view, expected_shape, *, message): + actual_shape = extract_static_tensor_shape(view) + if actual_shape is None: + return + if list(actual_shape) != list(expected_shape): + raise ValueError(f"{message} Expected {expected_shape}, got {actual_shape}.") + + +def require_view_dtype(view, dtype, *, message): + actual_token = extract_tensor_dtype_token(view) + if actual_token is None: + return + if actual_token != dtype_token(dtype): + raise ValueError(message) + + +def require_static_matrix_shape(shape, *, context): + if len(shape) != 2 or any(not isinstance(dim, int) for dim in shape): + raise ValueError(f"{context} currently requires a static rank-2 integer shape.") + rows, cols = shape + if rows <= 0 or cols <= 0: + raise ValueError(f"{context} requires positive row/column sizes.") + return rows, cols + + +def full_mask(dtype): + width = dtype_byte_width(dtype) + if width == 4: + return pto.pset_b32(mask_type(), "PAT_ALL") + if width == 2: + return pto.pset_b16(mask_type(), "PAT_ALL") + if width == 1: + return pto.pset_b8(mask_type(), "PAT_ALL") + raise ValueError(f"Unsupported dtype mask width for '{dtype}'.") + + +def tail_mask(dtype, active_lanes): + i32 = IntegerType.get_signless(32) + active = const_i32(active_lanes) + width = dtype_byte_width(dtype) + if width == 4: + mask, _ = pto.plt_b32(mask_type(), i32, active) + return mask + if width == 2: + mask, _ = pto.plt_b16(mask_type(), i32, active) + return mask + if width == 1: 
mask, _ = pto.plt_b8(mask_type(), i32, active) + return mask + raise ValueError(f"Unsupported dtype tail mask width for '{dtype}'.") + + +def mask_for_chunk(dtype, active_lanes): + lanes = micro_lane_count(dtype) + if active_lanes == lanes: + return full_mask(dtype) + return tail_mask(dtype, active_lanes) + + +def onept_dist(dtype): + width = dtype_byte_width(dtype) + if width == 4: + return "ONEPT_B32" + if width == 2: + return "ONEPT_B16" + if width == 1: + return "ONEPT_B8" + raise ValueError(f"Unsupported dtype point-store width for '{dtype}'.") + + +def normalize_vf_impl_kind(impl): + if impl is None: + return VF_IMPL_DEFAULT + + normalized = str(impl).strip().lower() + aliases = { + "default": VF_IMPL_DEFAULT, + "vfimpl_default": VF_IMPL_DEFAULT, + "1d_no_post_update": VF_IMPL_1D_NO_POST_UPDATE, + "vfimpl_1d_no_post_update": VF_IMPL_1D_NO_POST_UPDATE, + "1d_post_update": VF_IMPL_1D_POST_UPDATE, + "vfimpl_1d_post_update": VF_IMPL_1D_POST_UPDATE, + "2d_no_post_update": VF_IMPL_2D_NO_POST_UPDATE, + "vfimpl_2d_no_post_update": VF_IMPL_2D_NO_POST_UPDATE, + "2d_post_update": VF_IMPL_2D_POST_UPDATE, + "vfimpl_2d_post_update": VF_IMPL_2D_POST_UPDATE, + } + if normalized not in aliases: + supported = ", ".join(sorted(aliases)) + raise ValueError( + f"Unsupported VF impl kind '{impl}'. Expected one of: {supported}." + ) + return aliases[normalized] + + +def check_tbinop_operands(lhs_view, rhs_view, out_view, *, dtype, shape, context): + rows, cols = require_static_matrix_shape(shape, context=context) + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message=f"Fix: {context} has invalid data type.", + ) + for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")): + require_view_shape( + view, + [rows, cols], + message=f"Fix: {context} input tile {label} valid shape mismatch.", + ) + require_view_dtype( + view, + dtype, + message=f"Fix: {context} input tile src0, src1 and dst tile data type mismatch.", + ) + return rows, cols + + +def check_row_expand_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = require_static_matrix_shape(shape, context=context) + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message=f"Fix: {context} input data type is not supported.", + ) + require_view_shape( + src_view, + [rows, 1], + message=f"Fix: {context} source valid shape must be [rows, 1].", + ) + require_view_shape( + out_view, + [rows, cols], + message=f"Fix: {context} destination valid shape mismatch.", + ) + require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + return rows, cols + + +def check_col_expand_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = require_static_matrix_shape(shape, context=context) + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message=f"Fix: {context} input data type is not supported.", + ) + require_view_shape( + src_view, + [1, cols], + message=f"Fix: {context} source valid shape must be [1, cols].", + ) + require_view_shape( + out_view, + [rows, cols], + message=f"Fix: {context} destination valid shape mismatch.", + ) + require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input data type must be 
consistent with the output data type.", + ) + require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + return rows, cols + + +def check_row_reduce_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = require_static_matrix_shape(shape, context=context) + require_supported_dtype( + dtype, + allowed={"f32", "f16"}, + message=f"Fix: {context} input data type is not supported.", + ) + require_view_shape( + src_view, + [rows, cols], + message=f"Fix: {context} source valid shape mismatch.", + ) + require_view_shape( + out_view, + [rows, 1], + message=f"Fix: {context} use a single-column output tile.", + ) + require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input and output data type mismatch.", + ) + require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input and output data type mismatch.", + ) + return rows, cols + + +def check_col_reduce_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = require_static_matrix_shape(shape, context=context) + require_supported_dtype( + dtype, + allowed={"f32", "f16"}, + message=f"Fix: {context} input data type is not supported.", + ) + require_view_shape( + src_view, + [rows, cols], + message=f"Fix: {context} source valid shape mismatch.", + ) + require_view_shape( + out_view, + [1, cols], + message=f"Fix: {context} use a single-row output tile.", + ) + require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input and output data type mismatch.", + ) + require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input and output data type mismatch.", + ) + return rows, cols + + +def check_gather_operands( + src_view, indices_view, out_view, *, dtype, index_dtype, shape +): + rows, cols = require_static_matrix_shape(shape, context="TGATHER") + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message="Fix: TGATHER source data type is not supported.", + ) + require_supported_dtype( + index_dtype, + allowed={"u32"}, + message="Fix: TGATHER index data type must be uint32.", + ) + for view, label, view_dtype in ( + (src_view, "src", dtype), + (indices_view, "indices", index_dtype), + (out_view, "dst", dtype), + ): + require_view_shape( + view, + [rows, cols], + message=f"Fix: TGATHER {label} valid shape mismatch.", + ) + require_view_dtype( + view, + view_dtype, + message=f"Fix: TGATHER {label} data type mismatch.", + ) + return rows, cols + + +def check_mrgsort_operands(src_view, out_view, *, dtype, shape, block_len): + rows, cols = require_static_matrix_shape(shape, context="TMRGSORT") + if rows != 1: + raise ValueError( + "TMRGSORT micro lowering currently requires a single input row." + ) + if cols != block_len * 4: + raise ValueError( + "TMRGSORT micro lowering currently requires shape[1] == block_len * 4." 
+ ) + require_view_shape( + src_view, + [rows, cols], + message="Fix: TMRGSORT source valid shape mismatch.", + ) + require_view_shape( + out_view, + [rows, cols], + message="Fix: TMRGSORT destination valid shape mismatch.", + ) + require_view_dtype( + src_view, + dtype, + message="Fix: TMRGSORT input and output data type mismatch.", + ) + require_view_dtype( + out_view, + dtype, + message="Fix: TMRGSORT input and output data type mismatch.", + ) + return rows, cols + + +def check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape): + rows, cols = require_static_matrix_shape(shape, context="TSORT32") + out_cols = cols * 4 if dtype_token(dtype) == "f16" else cols * 2 + for view, label, expected_shape in ( + (src_view, "src", [rows, cols]), + (idx_view, "idx", [rows, cols]), + (out_view, "dst", [rows, out_cols]), + ): + require_view_shape( + view, + expected_shape, + message=f"TSORT32 {label} shape mismatch.", + ) + require_view_dtype(src_view, dtype, message="Dst and src must be the same.") + require_view_dtype(out_view, dtype, message="Dst and src must be the same.") + require_view_dtype(idx_view, uint32_type(), message="Idx must be uint32_t.") + if cols % 32 != 0: + raise ValueError( + "TSORT32 micro lowering currently requires column count divisible by 32." + ) + return rows, cols, out_cols + + +__all__ = [ + "VF_IMPL_DEFAULT", + "VF_IMPL_1D_NO_POST_UPDATE", + "VF_IMPL_1D_POST_UPDATE", + "VF_IMPL_2D_NO_POST_UPDATE", + "VF_IMPL_2D_POST_UPDATE", + "_call", + "align_type", + "alloc_tile_buffer", + "check_col_expand_operands", + "check_col_reduce_operands", + "check_gather_operands", + "check_mrgsort_operands", + "check_row_expand_operands", + "check_row_reduce_operands", + "check_sort32_operands", + "check_tbinop_operands", + "const_expr", + "const_float", + "const_i32", + "const_i64", + "dtype_byte_width", + "dtype_token", + "full_mask", + "load_tile", + "load_view", + "make_tensor", + "mask_for_chunk", + "mask_type", + "micro_lane_count", + "move_tile", + "normalize_vf_impl_kind", + "onept_dist", + "ptr", + "range_constexpr", + "resolve_lanes", + "s", + "slice_tensor", + "store_tile", + "store_view", + "tail_mask", + "uint32_type", + "vreg_type", +] diff --git a/ptodsl/lib/a5/generated/a5_cube_matmul.pto b/ptodsl/lib/a5/generated/a5_cube_matmul.pto index 7f52f654..64bb0003 100644 --- a/ptodsl/lib/a5/generated/a5_cube_matmul.pto +++ b/ptodsl/lib/a5/generated/a5_cube_matmul.pto @@ -4,42 +4,54 @@ module { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1 = arith.constant 1 : index - %0 = pto.make_tensor_view %arg0, shape = [%c16, %c32], strides = [%c32, %c1] : !pto.tensor_view %c32_0 = arith.constant 32 : index + %0 = arith.muli %c1, %c32_0 : index %c16_1 = arith.constant 16 : index - %c1_2 = arith.constant 1 : index - %1 = pto.make_tensor_view %arg1, shape = [%c32_0, %c16_1], strides = [%c16_1, %c1_2] : !pto.tensor_view + %1 = arith.muli %0, %c16_1 : index + %2 = pto.make_tensor_view %arg0, shape = [%c16, %c32], strides = [%0, %c1] : !pto.tensor_view + %c32_2 = arith.constant 32 : index %c16_3 = arith.constant 16 : index - %c16_4 = arith.constant 16 : index - %c1_5 = arith.constant 1 : index - %2 = pto.make_tensor_view %arg2, shape = [%c16_3, %c16_4], strides = [%c16_4, %c1_5] : !pto.tensor_view + %c1_4 = arith.constant 1 : index + %c16_5 = arith.constant 16 : index + %3 = arith.muli %c1_4, %c16_5 : index + %c32_6 = arith.constant 32 : index + %4 = arith.muli %3, %c32_6 : index + %5 = pto.make_tensor_view %arg1, shape = [%c32_2, %c16_3], strides = [%3, %c1_4] : 
!pto.tensor_view + %c16_7 = arith.constant 16 : index + %c16_8 = arith.constant 16 : index + %c1_9 = arith.constant 1 : index + %c16_10 = arith.constant 16 : index + %6 = arith.muli %c1_9, %c16_10 : index + %c16_11 = arith.constant 16 : index + %7 = arith.muli %6, %c16_11 : index + %8 = pto.make_tensor_view %arg2, shape = [%c16_7, %c16_8], strides = [%6, %c1_9] : !pto.tensor_view pto.section.cube { - %c0_6 = arith.constant 0 : index - %c0_7 = arith.constant 0 : index - %c16_8 = arith.constant 16 : index - %c32_9 = arith.constant 32 : index - %3 = pto.partition_view %0, offsets = [%c0_6, %c0_7], sizes = [%c16_8, %c32_9] : !pto.tensor_view -> !pto.partition_tensor_view<16x32xf16> - %4 = pto.alloc_tile : !pto.tile_buf - pto.tload ins(%3 : !pto.partition_tensor_view<16x32xf16>) outs(%4 : !pto.tile_buf) - %c0_10 = arith.constant 0 : index - %c0_11 = arith.constant 0 : index - %c32_12 = arith.constant 32 : index - %c16_13 = arith.constant 16 : index - %5 = pto.partition_view %1, offsets = [%c0_10, %c0_11], sizes = [%c32_12, %c16_13] : !pto.tensor_view -> !pto.partition_tensor_view<32x16xf16> - %6 = pto.alloc_tile : !pto.tile_buf - pto.tload ins(%5 : !pto.partition_tensor_view<32x16xf16>) outs(%6 : !pto.tile_buf) - %7 = pto.alloc_tile : !pto.tile_buf - %8 = pto.alloc_tile : !pto.tile_buf - %9 = pto.alloc_tile : !pto.tile_buf - pto.textract ins(%4, %c0, %c0 : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) - pto.tmov ins(%6 : !pto.tile_buf) outs(%8 : !pto.tile_buf) - pto.tmatmul ins(%7, %8 : !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) - %c0_14 = arith.constant 0 : index - %c0_15 = arith.constant 0 : index - %c16_16 = arith.constant 16 : index - %c16_17 = arith.constant 16 : index - %10 = pto.partition_view %2, offsets = [%c0_14, %c0_15], sizes = [%c16_16, %c16_17] : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf32> - pto.tstore ins(%9 : !pto.tile_buf) outs(%10 : !pto.partition_tensor_view<16x16xf32>) + %c0_12 = arith.constant 0 : index + %c0_13 = arith.constant 0 : index + %c16_14 = arith.constant 16 : index + %c32_15 = arith.constant 32 : index + %9 = pto.partition_view %2, offsets = [%c0_12, %c0_13], sizes = [%c16_14, %c32_15] : !pto.tensor_view -> !pto.partition_tensor_view<16x32xf16> + %10 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%9 : !pto.partition_tensor_view<16x32xf16>) outs(%10 : !pto.tile_buf) + %c0_16 = arith.constant 0 : index + %c0_17 = arith.constant 0 : index + %c32_18 = arith.constant 32 : index + %c16_19 = arith.constant 16 : index + %11 = pto.partition_view %5, offsets = [%c0_16, %c0_17], sizes = [%c32_18, %c16_19] : !pto.tensor_view -> !pto.partition_tensor_view<32x16xf16> + %12 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%11 : !pto.partition_tensor_view<32x16xf16>) outs(%12 : !pto.tile_buf) + %13 = pto.alloc_tile : !pto.tile_buf + %14 = pto.alloc_tile : !pto.tile_buf + %15 = pto.alloc_tile : !pto.tile_buf + pto.textract ins(%10, %c0, %c0 : !pto.tile_buf, index, index) outs(%13 : !pto.tile_buf) + pto.tmov ins(%12 : !pto.tile_buf) outs(%14 : !pto.tile_buf) + pto.tmatmul ins(%13, %14 : !pto.tile_buf, !pto.tile_buf) outs(%15 : !pto.tile_buf) + %c0_20 = arith.constant 0 : index + %c0_21 = arith.constant 0 : index + %c16_22 = arith.constant 16 : index + %c16_23 = arith.constant 16 : index + %16 = pto.partition_view %8, offsets = [%c0_20, %c0_21], sizes = [%c16_22, %c16_23] : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf32> + pto.tstore ins(%15 : !pto.tile_buf) outs(%16 : !pto.partition_tensor_view<16x16xf32>) } return } diff --git 
a/ptodsl/lib/a5/generated/a5_elementwise_add.pto b/ptodsl/lib/a5/generated/a5_elementwise_add.pto index 598b5bfb..e25d712a 100644 --- a/ptodsl/lib/a5/generated/a5_elementwise_add.pto +++ b/ptodsl/lib/a5/generated/a5_elementwise_add.pto @@ -1,49 +1,269 @@ module { func.func @a5_elementwise_add(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: index, %arg4: index) { %c1 = arith.constant 1 : index - %0 = pto.make_tensor_view %arg0, shape = [%arg3, %arg4], strides = [%arg4, %c1] : !pto.tensor_view + %0 = arith.muli %c1, %arg4 : index + %1 = arith.muli %0, %arg3 : index + %2 = pto.make_tensor_view %arg0, shape = [%arg3, %arg4], strides = [%0, %c1] : !pto.tensor_view %c1_0 = arith.constant 1 : index - %1 = pto.make_tensor_view %arg1, shape = [%arg3, %arg4], strides = [%arg4, %c1_0] : !pto.tensor_view + %3 = arith.muli %c1_0, %arg4 : index + %4 = arith.muli %3, %arg3 : index + %5 = pto.make_tensor_view %arg1, shape = [%arg3, %arg4], strides = [%3, %c1_0] : !pto.tensor_view %c1_1 = arith.constant 1 : index - %2 = pto.make_tensor_view %arg2, shape = [%arg3, %arg4], strides = [%arg4, %c1_1] : !pto.tensor_view + %6 = arith.muli %c1_1, %arg4 : index + %7 = arith.muli %6, %arg3 : index + %8 = pto.make_tensor_view %arg2, shape = [%arg3, %arg4], strides = [%6, %c1_1] : !pto.tensor_view %c0 = arith.constant 0 : index %c0_2 = arith.constant 0 : index %c32 = arith.constant 32 : index %c32_3 = arith.constant 32 : index - %3 = pto.partition_view %0, offsets = [%c0, %c0_2], sizes = [%c32, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %9 = pto.partition_view %2, offsets = [%c0, %c0_2], sizes = [%c32, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> %c0_4 = arith.constant 0 : index %c0_5 = arith.constant 0 : index %c32_6 = arith.constant 32 : index %c32_7 = arith.constant 32 : index - %4 = pto.partition_view %1, offsets = [%c0_4, %c0_5], sizes = [%c32_6, %c32_7] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %10 = pto.partition_view %5, offsets = [%c0_4, %c0_5], sizes = [%c32_6, %c32_7] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> %c0_8 = arith.constant 0 : index %c0_9 = arith.constant 0 : index %c32_10 = arith.constant 32 : index %c32_11 = arith.constant 32 : index - %5 = pto.partition_view %2, offsets = [%c0_8, %c0_9], sizes = [%c32_10, %c32_11] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %11 = pto.partition_view %8, offsets = [%c0_8, %c0_9], sizes = [%c32_10, %c32_11] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> pto.section.vector { %c0_i64 = arith.constant 0 : i64 %c4096_i64 = arith.constant 4096 : i64 %c8192_i64 = arith.constant 8192 : i64 - %6 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf - %7 = pto.alloc_tile addr = %c4096_i64 : !pto.tile_buf - %8 = pto.alloc_tile addr = %c8192_i64 : !pto.tile_buf - pto.tload ins(%3 : !pto.partition_tensor_view<32x32xf32>) outs(%6 : !pto.tile_buf) - pto.tload ins(%4 : !pto.partition_tensor_view<32x32xf32>) outs(%7 : !pto.tile_buf) - %9 = pto.castptr %c0_i64 : i64 -> !pto.ptr - %10 = pto.castptr %c4096_i64 : i64 -> !pto.ptr - %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr - %12 = pto.pset_b32 "PAT_ALL" : !pto.mask + %12 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %13 = pto.alloc_tile addr = %c4096_i64 : !pto.tile_buf + %14 = pto.alloc_tile addr = %c8192_i64 : !pto.tile_buf + pto.tload ins(%9 : !pto.partition_tensor_view<32x32xf32>) outs(%12 : !pto.tile_buf) + pto.tload ins(%10 : !pto.partition_tensor_view<32x32xf32>) outs(%13 : !pto.tile_buf) + %15 = 
pto.castptr %c0_i64 : i64 -> !pto.ptr + %16 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %17 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %c32_i32 = arith.constant 32 : i32 + %mask, %scalar_out = pto.plt_b32 %c32_i32 : i32 -> !pto.mask, i32 %c0_12 = arith.constant 0 : index - %c1024 = arith.constant 1024 : index + %18 = pto.vlds %15[%c0_12] : !pto.ptr -> !pto.vreg<64xf32> + %19 = pto.vlds %16[%c0_12] : !pto.ptr -> !pto.vreg<64xf32> + %20 = pto.vadd %18, %19, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %20, %17[%c0_12], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_13 = arith.constant 32 : i32 + %mask_14, %scalar_out_15 = pto.plt_b32 %c32_i32_13 : i32 -> !pto.mask, i32 + %c32_16 = arith.constant 32 : index + %21 = pto.vlds %15[%c32_16] : !pto.ptr -> !pto.vreg<64xf32> + %22 = pto.vlds %16[%c32_16] : !pto.ptr -> !pto.vreg<64xf32> + %23 = pto.vadd %21, %22, %mask_14 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %23, %17[%c32_16], %mask_14 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_17 = arith.constant 32 : i32 + %mask_18, %scalar_out_19 = pto.plt_b32 %c32_i32_17 : i32 -> !pto.mask, i32 %c64 = arith.constant 64 : index - scf.for %arg5 = %c0_12 to %c1024 step %c64 { - %13 = pto.vlds %9[%arg5] : !pto.ptr -> !pto.vreg<64xf32> - %14 = pto.vlds %10[%arg5] : !pto.ptr -> !pto.vreg<64xf32> - %15 = pto.vadd %13, %14, %12 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> - pto.vsts %15, %11[%arg5], %12 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask - } - pto.tstore ins(%8 : !pto.tile_buf) outs(%5 : !pto.partition_tensor_view<32x32xf32>) + %24 = pto.vlds %15[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %25 = pto.vlds %16[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %26 = pto.vadd %24, %25, %mask_18 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %26, %17[%c64], %mask_18 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_20 = arith.constant 32 : i32 + %mask_21, %scalar_out_22 = pto.plt_b32 %c32_i32_20 : i32 -> !pto.mask, i32 + %c96 = arith.constant 96 : index + %27 = pto.vlds %15[%c96] : !pto.ptr -> !pto.vreg<64xf32> + %28 = pto.vlds %16[%c96] : !pto.ptr -> !pto.vreg<64xf32> + %29 = pto.vadd %27, %28, %mask_21 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %29, %17[%c96], %mask_21 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_23 = arith.constant 32 : i32 + %mask_24, %scalar_out_25 = pto.plt_b32 %c32_i32_23 : i32 -> !pto.mask, i32 + %c128 = arith.constant 128 : index + %30 = pto.vlds %15[%c128] : !pto.ptr -> !pto.vreg<64xf32> + %31 = pto.vlds %16[%c128] : !pto.ptr -> !pto.vreg<64xf32> + %32 = pto.vadd %30, %31, %mask_24 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %32, %17[%c128], %mask_24 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_26 = arith.constant 32 : i32 + %mask_27, %scalar_out_28 = pto.plt_b32 %c32_i32_26 : i32 -> !pto.mask, i32 + %c160 = arith.constant 160 : index + %33 = pto.vlds %15[%c160] : !pto.ptr -> !pto.vreg<64xf32> + %34 = pto.vlds %16[%c160] : !pto.ptr -> !pto.vreg<64xf32> + %35 = pto.vadd %33, %34, %mask_27 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %35, %17[%c160], %mask_27 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_29 = arith.constant 32 : i32 + %mask_30, %scalar_out_31 = pto.plt_b32 %c32_i32_29 : i32 -> !pto.mask, i32 + %c192 = arith.constant 192 : index + %36 = pto.vlds %15[%c192] : !pto.ptr -> !pto.vreg<64xf32> + 
%37 = pto.vlds %16[%c192] : !pto.ptr -> !pto.vreg<64xf32> + %38 = pto.vadd %36, %37, %mask_30 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %38, %17[%c192], %mask_30 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_32 = arith.constant 32 : i32 + %mask_33, %scalar_out_34 = pto.plt_b32 %c32_i32_32 : i32 -> !pto.mask, i32 + %c224 = arith.constant 224 : index + %39 = pto.vlds %15[%c224] : !pto.ptr -> !pto.vreg<64xf32> + %40 = pto.vlds %16[%c224] : !pto.ptr -> !pto.vreg<64xf32> + %41 = pto.vadd %39, %40, %mask_33 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %41, %17[%c224], %mask_33 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_35 = arith.constant 32 : i32 + %mask_36, %scalar_out_37 = pto.plt_b32 %c32_i32_35 : i32 -> !pto.mask, i32 + %c256 = arith.constant 256 : index + %42 = pto.vlds %15[%c256] : !pto.ptr -> !pto.vreg<64xf32> + %43 = pto.vlds %16[%c256] : !pto.ptr -> !pto.vreg<64xf32> + %44 = pto.vadd %42, %43, %mask_36 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %44, %17[%c256], %mask_36 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_38 = arith.constant 32 : i32 + %mask_39, %scalar_out_40 = pto.plt_b32 %c32_i32_38 : i32 -> !pto.mask, i32 + %c288 = arith.constant 288 : index + %45 = pto.vlds %15[%c288] : !pto.ptr -> !pto.vreg<64xf32> + %46 = pto.vlds %16[%c288] : !pto.ptr -> !pto.vreg<64xf32> + %47 = pto.vadd %45, %46, %mask_39 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %47, %17[%c288], %mask_39 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_41 = arith.constant 32 : i32 + %mask_42, %scalar_out_43 = pto.plt_b32 %c32_i32_41 : i32 -> !pto.mask, i32 + %c320 = arith.constant 320 : index + %48 = pto.vlds %15[%c320] : !pto.ptr -> !pto.vreg<64xf32> + %49 = pto.vlds %16[%c320] : !pto.ptr -> !pto.vreg<64xf32> + %50 = pto.vadd %48, %49, %mask_42 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %50, %17[%c320], %mask_42 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_44 = arith.constant 32 : i32 + %mask_45, %scalar_out_46 = pto.plt_b32 %c32_i32_44 : i32 -> !pto.mask, i32 + %c352 = arith.constant 352 : index + %51 = pto.vlds %15[%c352] : !pto.ptr -> !pto.vreg<64xf32> + %52 = pto.vlds %16[%c352] : !pto.ptr -> !pto.vreg<64xf32> + %53 = pto.vadd %51, %52, %mask_45 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %53, %17[%c352], %mask_45 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_47 = arith.constant 32 : i32 + %mask_48, %scalar_out_49 = pto.plt_b32 %c32_i32_47 : i32 -> !pto.mask, i32 + %c384 = arith.constant 384 : index + %54 = pto.vlds %15[%c384] : !pto.ptr -> !pto.vreg<64xf32> + %55 = pto.vlds %16[%c384] : !pto.ptr -> !pto.vreg<64xf32> + %56 = pto.vadd %54, %55, %mask_48 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %56, %17[%c384], %mask_48 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_50 = arith.constant 32 : i32 + %mask_51, %scalar_out_52 = pto.plt_b32 %c32_i32_50 : i32 -> !pto.mask, i32 + %c416 = arith.constant 416 : index + %57 = pto.vlds %15[%c416] : !pto.ptr -> !pto.vreg<64xf32> + %58 = pto.vlds %16[%c416] : !pto.ptr -> !pto.vreg<64xf32> + %59 = pto.vadd %57, %58, %mask_51 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %59, %17[%c416], %mask_51 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_53 = arith.constant 32 : i32 + %mask_54, %scalar_out_55 = pto.plt_b32 %c32_i32_53 : i32 
-> !pto.mask, i32 + %c448 = arith.constant 448 : index + %60 = pto.vlds %15[%c448] : !pto.ptr -> !pto.vreg<64xf32> + %61 = pto.vlds %16[%c448] : !pto.ptr -> !pto.vreg<64xf32> + %62 = pto.vadd %60, %61, %mask_54 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %62, %17[%c448], %mask_54 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_56 = arith.constant 32 : i32 + %mask_57, %scalar_out_58 = pto.plt_b32 %c32_i32_56 : i32 -> !pto.mask, i32 + %c480 = arith.constant 480 : index + %63 = pto.vlds %15[%c480] : !pto.ptr -> !pto.vreg<64xf32> + %64 = pto.vlds %16[%c480] : !pto.ptr -> !pto.vreg<64xf32> + %65 = pto.vadd %63, %64, %mask_57 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %65, %17[%c480], %mask_57 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_59 = arith.constant 32 : i32 + %mask_60, %scalar_out_61 = pto.plt_b32 %c32_i32_59 : i32 -> !pto.mask, i32 + %c512 = arith.constant 512 : index + %66 = pto.vlds %15[%c512] : !pto.ptr -> !pto.vreg<64xf32> + %67 = pto.vlds %16[%c512] : !pto.ptr -> !pto.vreg<64xf32> + %68 = pto.vadd %66, %67, %mask_60 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %68, %17[%c512], %mask_60 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_62 = arith.constant 32 : i32 + %mask_63, %scalar_out_64 = pto.plt_b32 %c32_i32_62 : i32 -> !pto.mask, i32 + %c544 = arith.constant 544 : index + %69 = pto.vlds %15[%c544] : !pto.ptr -> !pto.vreg<64xf32> + %70 = pto.vlds %16[%c544] : !pto.ptr -> !pto.vreg<64xf32> + %71 = pto.vadd %69, %70, %mask_63 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %71, %17[%c544], %mask_63 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_65 = arith.constant 32 : i32 + %mask_66, %scalar_out_67 = pto.plt_b32 %c32_i32_65 : i32 -> !pto.mask, i32 + %c576 = arith.constant 576 : index + %72 = pto.vlds %15[%c576] : !pto.ptr -> !pto.vreg<64xf32> + %73 = pto.vlds %16[%c576] : !pto.ptr -> !pto.vreg<64xf32> + %74 = pto.vadd %72, %73, %mask_66 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %74, %17[%c576], %mask_66 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_68 = arith.constant 32 : i32 + %mask_69, %scalar_out_70 = pto.plt_b32 %c32_i32_68 : i32 -> !pto.mask, i32 + %c608 = arith.constant 608 : index + %75 = pto.vlds %15[%c608] : !pto.ptr -> !pto.vreg<64xf32> + %76 = pto.vlds %16[%c608] : !pto.ptr -> !pto.vreg<64xf32> + %77 = pto.vadd %75, %76, %mask_69 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %77, %17[%c608], %mask_69 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_71 = arith.constant 32 : i32 + %mask_72, %scalar_out_73 = pto.plt_b32 %c32_i32_71 : i32 -> !pto.mask, i32 + %c640 = arith.constant 640 : index + %78 = pto.vlds %15[%c640] : !pto.ptr -> !pto.vreg<64xf32> + %79 = pto.vlds %16[%c640] : !pto.ptr -> !pto.vreg<64xf32> + %80 = pto.vadd %78, %79, %mask_72 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %80, %17[%c640], %mask_72 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_74 = arith.constant 32 : i32 + %mask_75, %scalar_out_76 = pto.plt_b32 %c32_i32_74 : i32 -> !pto.mask, i32 + %c672 = arith.constant 672 : index + %81 = pto.vlds %15[%c672] : !pto.ptr -> !pto.vreg<64xf32> + %82 = pto.vlds %16[%c672] : !pto.ptr -> !pto.vreg<64xf32> + %83 = pto.vadd %81, %82, %mask_75 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %83, %17[%c672], %mask_75 : !pto.vreg<64xf32>, 
!pto.ptr, !pto.mask + %c32_i32_77 = arith.constant 32 : i32 + %mask_78, %scalar_out_79 = pto.plt_b32 %c32_i32_77 : i32 -> !pto.mask, i32 + %c704 = arith.constant 704 : index + %84 = pto.vlds %15[%c704] : !pto.ptr -> !pto.vreg<64xf32> + %85 = pto.vlds %16[%c704] : !pto.ptr -> !pto.vreg<64xf32> + %86 = pto.vadd %84, %85, %mask_78 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %86, %17[%c704], %mask_78 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_80 = arith.constant 32 : i32 + %mask_81, %scalar_out_82 = pto.plt_b32 %c32_i32_80 : i32 -> !pto.mask, i32 + %c736 = arith.constant 736 : index + %87 = pto.vlds %15[%c736] : !pto.ptr -> !pto.vreg<64xf32> + %88 = pto.vlds %16[%c736] : !pto.ptr -> !pto.vreg<64xf32> + %89 = pto.vadd %87, %88, %mask_81 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %89, %17[%c736], %mask_81 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_83 = arith.constant 32 : i32 + %mask_84, %scalar_out_85 = pto.plt_b32 %c32_i32_83 : i32 -> !pto.mask, i32 + %c768 = arith.constant 768 : index + %90 = pto.vlds %15[%c768] : !pto.ptr -> !pto.vreg<64xf32> + %91 = pto.vlds %16[%c768] : !pto.ptr -> !pto.vreg<64xf32> + %92 = pto.vadd %90, %91, %mask_84 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %92, %17[%c768], %mask_84 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_86 = arith.constant 32 : i32 + %mask_87, %scalar_out_88 = pto.plt_b32 %c32_i32_86 : i32 -> !pto.mask, i32 + %c800 = arith.constant 800 : index + %93 = pto.vlds %15[%c800] : !pto.ptr -> !pto.vreg<64xf32> + %94 = pto.vlds %16[%c800] : !pto.ptr -> !pto.vreg<64xf32> + %95 = pto.vadd %93, %94, %mask_87 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %95, %17[%c800], %mask_87 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_89 = arith.constant 32 : i32 + %mask_90, %scalar_out_91 = pto.plt_b32 %c32_i32_89 : i32 -> !pto.mask, i32 + %c832 = arith.constant 832 : index + %96 = pto.vlds %15[%c832] : !pto.ptr -> !pto.vreg<64xf32> + %97 = pto.vlds %16[%c832] : !pto.ptr -> !pto.vreg<64xf32> + %98 = pto.vadd %96, %97, %mask_90 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %98, %17[%c832], %mask_90 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_92 = arith.constant 32 : i32 + %mask_93, %scalar_out_94 = pto.plt_b32 %c32_i32_92 : i32 -> !pto.mask, i32 + %c864 = arith.constant 864 : index + %99 = pto.vlds %15[%c864] : !pto.ptr -> !pto.vreg<64xf32> + %100 = pto.vlds %16[%c864] : !pto.ptr -> !pto.vreg<64xf32> + %101 = pto.vadd %99, %100, %mask_93 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %101, %17[%c864], %mask_93 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_95 = arith.constant 32 : i32 + %mask_96, %scalar_out_97 = pto.plt_b32 %c32_i32_95 : i32 -> !pto.mask, i32 + %c896 = arith.constant 896 : index + %102 = pto.vlds %15[%c896] : !pto.ptr -> !pto.vreg<64xf32> + %103 = pto.vlds %16[%c896] : !pto.ptr -> !pto.vreg<64xf32> + %104 = pto.vadd %102, %103, %mask_96 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %104, %17[%c896], %mask_96 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_98 = arith.constant 32 : i32 + %mask_99, %scalar_out_100 = pto.plt_b32 %c32_i32_98 : i32 -> !pto.mask, i32 + %c928 = arith.constant 928 : index + %105 = pto.vlds %15[%c928] : !pto.ptr -> !pto.vreg<64xf32> + %106 = pto.vlds %16[%c928] : !pto.ptr -> !pto.vreg<64xf32> + %107 = pto.vadd %105, %106, 
%mask_99 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %107, %17[%c928], %mask_99 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_101 = arith.constant 32 : i32 + %mask_102, %scalar_out_103 = pto.plt_b32 %c32_i32_101 : i32 -> !pto.mask, i32 + %c960 = arith.constant 960 : index + %108 = pto.vlds %15[%c960] : !pto.ptr -> !pto.vreg<64xf32> + %109 = pto.vlds %16[%c960] : !pto.ptr -> !pto.vreg<64xf32> + %110 = pto.vadd %108, %109, %mask_102 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %110, %17[%c960], %mask_102 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %c32_i32_104 = arith.constant 32 : i32 + %mask_105, %scalar_out_106 = pto.plt_b32 %c32_i32_104 : i32 -> !pto.mask, i32 + %c992 = arith.constant 992 : index + %111 = pto.vlds %15[%c992] : !pto.ptr -> !pto.vreg<64xf32> + %112 = pto.vlds %16[%c992] : !pto.ptr -> !pto.vreg<64xf32> + %113 = pto.vadd %111, %112, %mask_105 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %113, %17[%c992], %mask_105 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.tstore ins(%14 : !pto.tile_buf) outs(%11 : !pto.partition_tensor_view<32x32xf32>) } return } diff --git a/ptodsl/lib/a5/generated/a5_micro_vector_copy.pto b/ptodsl/lib/a5/generated/a5_vector_copy.pto similarity index 72% rename from ptodsl/lib/a5/generated/a5_micro_vector_copy.pto rename to ptodsl/lib/a5/generated/a5_vector_copy.pto index 7cda605b..7558636d 100644 --- a/ptodsl/lib/a5/generated/a5_micro_vector_copy.pto +++ b/ptodsl/lib/a5/generated/a5_vector_copy.pto @@ -1,5 +1,5 @@ module { - func.func @a5_micro_vector_copy(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: index) { + func.func @a5_vector_copy(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: index) { pto.section.vector { %0 = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<64xf32> %1 = pto.pset_b32 "PAT_ALL" : !pto.mask diff --git a/ptodsl/lib/a5/kernels.py b/ptodsl/lib/a5/kernels.py index 4bd7e531..b66e5370 100644 --- a/ptodsl/lib/a5/kernels.py +++ b/ptodsl/lib/a5/kernels.py @@ -1,17 +1,33 @@ -from mlir.dialects import pto as _raw_pto +from mlir.dialects import pto as raw_pto from mlir.ir import IndexType -from ... import Constexpr, pto, scalar as s, to_ir_module +from ... import Constexpr, language as dsl, scalar as s, to_ir_module from ...language import make_mxfp8 +from ._common import ( + VF_IMPL_DEFAULT, + alloc_tile_buffer, + load_tile, + make_tensor, + ptr, + slice_tensor, + store_tile, +) from . 
import ops -def build_elementwise_add(*, rows=32, cols=32, tile_rows=32, tile_cols=32, dtype=None): - dtype = pto.float32 if dtype is None else dtype +def _resolve_dtype(dtype, default_name): + if dtype is None: + return getattr(dsl, default_name) + if isinstance(dtype, str): + return getattr(dsl, dtype) + return dtype + +def build_elementwise_add(*, rows=32, cols=32, tile_rows=32, tile_cols=32, dtype=None): def meta_data(): + element_dtype = _resolve_dtype(dtype, "float32") return { - "ptr_t": pto.ptr(dtype), + "ptr_t": ptr(element_dtype), "index_t": IndexType.get(), } @@ -23,20 +39,36 @@ def a5_elementwise_add( n_rows: "index_t", n_cols: "index_t", ) -> None: - lhs = pto.make_tensor(src0, shape=[n_rows, n_cols], dtype=dtype) - rhs = pto.make_tensor(src1, shape=[n_rows, n_cols], dtype=dtype) - out = pto.make_tensor(dst, shape=[n_rows, n_cols], dtype=dtype) - - lhs_tile = lhs.slice([0, 0], [tile_rows, tile_cols]) - rhs_tile = rhs.slice([0, 0], [tile_rows, tile_cols]) - out_tile = out.slice([0, 0], [tile_rows, tile_cols]) - - with pto.vector_section(): - ops.add_micro( + element_dtype = _resolve_dtype(dtype, "float32") + lhs = make_tensor(src0, shape=[n_rows, n_cols], dtype=element_dtype) + rhs = make_tensor(src1, shape=[n_rows, n_cols], dtype=element_dtype) + out = make_tensor(dst, shape=[n_rows, n_cols], dtype=element_dtype) + + lhs_tile = slice_tensor( + lhs, + offsets=[0, 0], + sizes=[tile_rows, tile_cols], + dtype=element_dtype, + ) + rhs_tile = slice_tensor( + rhs, + offsets=[0, 0], + sizes=[tile_rows, tile_cols], + dtype=element_dtype, + ) + out_tile = slice_tensor( + out, + offsets=[0, 0], + sizes=[tile_rows, tile_cols], + dtype=element_dtype, + ) + + with dsl.vector_section(): + ops.tadd( lhs_tile, rhs_tile, out_tile, - dtype=dtype, + dtype=element_dtype, shape=[tile_rows, tile_cols], ) @@ -44,68 +76,71 @@ def a5_elementwise_add( def build_templated_elementwise_add(*, dtype=None): - dtype = pto.float32 if dtype is None else dtype - - def meta_data(ROWS=32, COLS=32): - return { - "ptr_t": pto.ptr(dtype), - "shape": [ROWS, COLS], - } - - @to_ir_module(meta_data=meta_data) - def a5_templated_elementwise_add( - src0: "ptr_t", - src1: "ptr_t", - dst: "ptr_t", + def specialize( + *, ROWS: Constexpr[int] = 32, COLS: Constexpr[int] = 32, - VF_IMPL: Constexpr[str] = ops.VF_IMPL_DEFAULT, - ) -> None: - lhs = pto.make_tensor(src0, shape=shape, dtype=dtype) - rhs = pto.make_tensor(src1, shape=shape, dtype=dtype) - out = pto.make_tensor(dst, shape=shape, dtype=dtype) - - with pto.vector_section(): - ops.add_micro( - lhs.slice([0, 0], shape), - rhs.slice([0, 0], shape), - out.slice([0, 0], shape), - dtype=dtype, - shape=shape, - impl=VF_IMPL, - ) - - return a5_templated_elementwise_add - - -def build_micro_vector_copy(*, lanes=64, dtype=None): - dtype = pto.float32 if dtype is None else dtype - + VF_IMPL: Constexpr[str] = VF_IMPL_DEFAULT, + ): + def meta_data(): + element_dtype = _resolve_dtype(dtype, "float32") + return { + "ptr_t": ptr(element_dtype), + "shape": [ROWS, COLS], + } + + @to_ir_module(meta_data=meta_data) + def a5_templated_elementwise_add( + src0: "ptr_t", + src1: "ptr_t", + dst: "ptr_t", + ) -> None: + element_dtype = _resolve_dtype(dtype, "float32") + lhs = make_tensor(src0, shape=shape, dtype=element_dtype) + rhs = make_tensor(src1, shape=shape, dtype=element_dtype) + out = make_tensor(dst, shape=shape, dtype=element_dtype) + + with dsl.vector_section(): + ops.tadd( + slice_tensor(lhs, offsets=[0, 0], sizes=shape, dtype=element_dtype), + slice_tensor(rhs, offsets=[0, 0], 
sizes=shape, dtype=element_dtype),
+            slice_tensor(out, offsets=[0, 0], sizes=shape, dtype=element_dtype),
+            dtype=element_dtype,
+            shape=shape,
+            impl=VF_IMPL,
+        )
+
+        return a5_templated_elementwise_add
+
+    return specialize
+
+
+def build_vector_copy(*, lanes=64, dtype=None):
     def meta_data():
+        element_dtype = _resolve_dtype(dtype, "float32")
         return {
-            "ptr_t": pto.ptr(dtype, space="VEC"),
+            "ptr_t": ptr(element_dtype, space="VEC"),
             "index_t": IndexType.get(),
         }

     @to_ir_module(meta_data=meta_data)
-    def a5_micro_vector_copy(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
-        with pto.vector_section():
-            ops.vector_copy(src, dst, offset, lanes=lanes, dtype=dtype)
+    def a5_vector_copy(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+        element_dtype = _resolve_dtype(dtype, "float32")
+        with dsl.vector_section():
+            ops.vector_copy(src, dst, offset, lanes=lanes, dtype=element_dtype)

-    return a5_micro_vector_copy
+    return a5_vector_copy


 def build_mxfp8_matmul(*, m=16, k=64, n=32, lhs_variant="e5m2", rhs_variant="e5m2"):
-    mx = make_mxfp8(lhs=lhs_variant, rhs=rhs_variant)
-    scale_k = mx.scale_k(k)
-
     def meta_data():
+        mx = make_mxfp8(lhs=lhs_variant, rhs=rhs_variant)
         return {
-            "ptr_lhs": pto.ptr(mx.lhs),
-            "ptr_rhs": pto.ptr(mx.rhs),
-            "ptr_scale": pto.ptr(mx.scale),
-            "ptr_bias": pto.ptr(mx.acc),
-            "ptr_out": pto.ptr(mx.acc),
+            "ptr_lhs": ptr(mx.lhs),
+            "ptr_rhs": ptr(mx.rhs),
+            "ptr_scale": ptr(mx.scale),
+            "ptr_bias": ptr(mx.acc),
+            "ptr_out": ptr(mx.acc),
         }

     @to_ir_module(meta_data=meta_data)
@@ -117,46 +152,67 @@ def a5_mxfp8_matmul(
         bias_ptr: "ptr_bias",
         out_ptr: "ptr_out",
     ) -> None:
-        lhs = pto.make_tensor(lhs_ptr, shape=[m, k], dtype=mx.lhs)
-        rhs = pto.make_tensor(rhs_ptr, shape=[k, n], dtype=mx.rhs)
-        lhs_scale = pto.make_tensor(lhs_scale_ptr, shape=[m, scale_k], dtype=mx.scale)
-        rhs_scale = pto.make_tensor(rhs_scale_ptr, shape=[scale_k, n], dtype=mx.scale)
-        bias = pto.make_tensor(bias_ptr, shape=[1, n], dtype=mx.acc)
-        out = pto.make_tensor(out_ptr, shape=[m, n], dtype=mx.acc)
-
-        with pto.cube_section():
-            lhs_tile = ops.load_tile(
-                lhs.slice([0, 0], [m, k]), dtype=mx.lhs, shape=[m, k], space="LEFT"
+        mx = make_mxfp8(lhs=lhs_variant, rhs=rhs_variant)
+        scale_k = mx.scale_k(k)
+        lhs = make_tensor(lhs_ptr, shape=[m, k], dtype=mx.lhs)
+        rhs = make_tensor(rhs_ptr, shape=[k, n], dtype=mx.rhs)
+        lhs_scale = make_tensor(lhs_scale_ptr, shape=[m, scale_k], dtype=mx.scale)
+        rhs_scale = make_tensor(rhs_scale_ptr, shape=[scale_k, n], dtype=mx.scale)
+        bias = make_tensor(bias_ptr, shape=[1, n], dtype=mx.acc)
+        out = make_tensor(out_ptr, shape=[m, n], dtype=mx.acc)
+
+        with dsl.cube_section():
+            lhs_tile = load_tile(
+                slice_tensor(lhs, offsets=[0, 0], sizes=[m, k], dtype=mx.lhs),
+                dtype=mx.lhs,
+                shape=[m, k],
+                space="LEFT",
             )
-            rhs_tile = ops.load_tile(
-                rhs.slice([0, 0], [k, n]), dtype=mx.rhs, shape=[k, n], space="RIGHT"
+            rhs_tile = load_tile(
+                slice_tensor(rhs, offsets=[0, 0], sizes=[k, n], dtype=mx.rhs),
+                dtype=mx.rhs,
+                shape=[k, n],
+                space="RIGHT",
             )
-            lhs_scale_tile = ops.load_tile(
-                lhs_scale.slice([0, 0], [m, scale_k]),
+            lhs_scale_tile = load_tile(
+                slice_tensor(
+                    lhs_scale,
+                    offsets=[0, 0],
+                    sizes=[m, scale_k],
+                    dtype=mx.scale,
+                ),
                 dtype=mx.scale,
                 shape=[m, scale_k],
                 space="SCALING",
-                config=pto.TileBufConfig(
+                config=dsl.TileBufConfig(
                     blayout="RowMajor",
                     slayout="RowMajor",
-                    s_fractal_size=_raw_pto.TileConfig.fractalMxSize,
+                    s_fractal_size=raw_pto.TileConfig.fractalMxSize,
                 ),
             )
-            rhs_scale_tile = ops.load_tile(
-                rhs_scale.slice([0, 0], [scale_k, n]),
+            rhs_scale_tile = load_tile(
+                slice_tensor(
+                    rhs_scale,
+                    offsets=[0, 0],
+                    sizes=[scale_k, n],
+                    dtype=mx.scale,
+                ),
                 dtype=mx.scale,
                 shape=[scale_k, n],
                 space="SCALING",
-                config=pto.TileBufConfig(
+                config=dsl.TileBufConfig(
                     blayout="ColMajor",
                     slayout="ColMajor",
-                    s_fractal_size=_raw_pto.TileConfig.fractalMxSize,
+                    s_fractal_size=raw_pto.TileConfig.fractalMxSize,
                 ),
             )
-            bias_tile = ops.load_tile(
-                bias.slice([0, 0], [1, n]), dtype=mx.acc, shape=[1, n], space="BIAS"
+            bias_tile = load_tile(
+                slice_tensor(bias, offsets=[0, 0], sizes=[1, n], dtype=mx.acc),
+                dtype=mx.acc,
+                shape=[1, n],
+                space="BIAS",
             )
-            acc_tile = pto.make_tile_buffer(mx.acc, [m, n], space="ACC").alloc()
+            acc_tile = alloc_tile_buffer(mx.acc, [m, n], space="ACC")
             ops.matmul_mx_bias(
                 lhs_tile,
                 lhs_scale_tile,
@@ -165,7 +221,9 @@ def a5_mxfp8_matmul(
                 bias_tile,
                 acc_tile,
             )
-            ops.store_tile(acc_tile, out.slice([0, 0], [m, n]))
+            store_tile(
+                acc_tile, slice_tensor(out, offsets=[0, 0], sizes=[m, n], dtype=mx.acc)
+            )

     return a5_mxfp8_matmul

@@ -173,47 +231,64 @@ def a5_mxfp8_matmul(
 def build_cube_matmul(
     *, m=16, k=32, n=16, lhs_dtype=None, rhs_dtype=None, acc_dtype=None
 ):
-    lhs_dtype = pto.float16 if lhs_dtype is None else lhs_dtype
-    rhs_dtype = pto.float16 if rhs_dtype is None else rhs_dtype
-    acc_dtype = pto.float32 if acc_dtype is None else acc_dtype
-
     def meta_data():
+        lhs_element_dtype = _resolve_dtype(lhs_dtype, "float16")
+        rhs_element_dtype = _resolve_dtype(rhs_dtype, "float16")
+        acc_element_dtype = _resolve_dtype(acc_dtype, "float32")
         return {
-            "ptr_lhs": pto.ptr(lhs_dtype),
-            "ptr_rhs": pto.ptr(rhs_dtype),
-            "ptr_out": pto.ptr(acc_dtype),
+            "ptr_lhs": ptr(lhs_element_dtype),
+            "ptr_rhs": ptr(rhs_element_dtype),
+            "ptr_out": ptr(acc_element_dtype),
         }

     @to_ir_module(meta_data=meta_data)
     def a5_cube_matmul(
         lhs_ptr: "ptr_lhs", rhs_ptr: "ptr_rhs", out_ptr: "ptr_out"
     ) -> None:
+        lhs_element_dtype = _resolve_dtype(lhs_dtype, "float16")
+        rhs_element_dtype = _resolve_dtype(rhs_dtype, "float16")
+        acc_element_dtype = _resolve_dtype(acc_dtype, "float32")
         c0 = s.const(0)
-        lhs = pto.make_tensor(lhs_ptr, shape=[m, k], dtype=lhs_dtype)
-        rhs = pto.make_tensor(rhs_ptr, shape=[k, n], dtype=rhs_dtype)
-        out = pto.make_tensor(out_ptr, shape=[m, n], dtype=acc_dtype)
-
-        with pto.cube_section():
-            lhs_mat = ops.load_tile(
-                lhs.slice([0, 0], [m, k]), dtype=lhs_dtype, shape=[m, k], space="MAT"
+        lhs = make_tensor(lhs_ptr, shape=[m, k], dtype=lhs_element_dtype)
+        rhs = make_tensor(rhs_ptr, shape=[k, n], dtype=rhs_element_dtype)
+        out = make_tensor(out_ptr, shape=[m, n], dtype=acc_element_dtype)
+
+        with dsl.cube_section():
+            lhs_mat = load_tile(
+                slice_tensor(
+                    lhs, offsets=[0, 0], sizes=[m, k], dtype=lhs_element_dtype
+                ),
+                dtype=lhs_element_dtype,
+                shape=[m, k],
+                space="MAT",
             )
-            rhs_mat = ops.load_tile(
-                rhs.slice([0, 0], [k, n]), dtype=rhs_dtype, shape=[k, n], space="MAT"
+            rhs_mat = load_tile(
+                slice_tensor(
+                    rhs, offsets=[0, 0], sizes=[k, n], dtype=rhs_element_dtype
+                ),
+                dtype=rhs_element_dtype,
+                shape=[k, n],
+                space="MAT",
             )
-            lhs_tile = pto.make_tile_buffer(lhs_dtype, [m, k], space="LEFT").alloc()
-            rhs_tile = pto.make_tile_buffer(rhs_dtype, [k, n], space="RIGHT").alloc()
-            acc_tile = pto.make_tile_buffer(acc_dtype, [m, n], space="ACC").alloc()
+            lhs_tile = alloc_tile_buffer(lhs_element_dtype, [m, k], space="LEFT")
+            rhs_tile = alloc_tile_buffer(rhs_element_dtype, [k, n], space="RIGHT")
+            acc_tile = alloc_tile_buffer(acc_element_dtype, [m, n], space="ACC")
             ops.extract(lhs_mat, c0, c0, lhs_tile)
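+            # Mirror the RIGHT operand out of MAT space, then issue the tile
+            # matmul: LEFT x RIGHT accumulates into the ACC tile buffer before
+            # the result is stored back through the output view.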
             ops.move_tile(rhs_mat, rhs_tile)
             ops.matmul(lhs_tile, rhs_tile, acc_tile)
-            ops.store_tile(acc_tile, out.slice([0, 0], [m, n]))
+            store_tile(
+                acc_tile,
+                slice_tensor(
+                    out, offsets=[0, 0], sizes=[m, n], dtype=acc_element_dtype
+                ),
+            )

     return a5_cube_matmul


 KERNEL_BUILDERS = {
     "a5_elementwise_add": build_elementwise_add,
-    "a5_micro_vector_copy": build_micro_vector_copy,
+    "a5_vector_copy": build_vector_copy,
     "a5_cube_matmul": build_cube_matmul,
 }
@@ -222,7 +297,7 @@ def a5_cube_matmul(
     "KERNEL_BUILDERS",
     "build_cube_matmul",
     "build_elementwise_add",
-    "build_micro_vector_copy",
+    "build_vector_copy",
     "build_mxfp8_matmul",
     "build_templated_elementwise_add",
 ]
diff --git a/ptodsl/lib/a5/native.py b/ptodsl/lib/a5/native.py
new file mode 100644
index 00000000..1fd2e937
--- /dev/null
+++ b/ptodsl/lib/a5/native.py
@@ -0,0 +1,264 @@
+"""Helpers that still map directly to PTO tile/cube ops or simple micro utilities."""
+
+from mlir.dialects import arith, pto
+from mlir.ir import BoolAttr, IntegerType
+
+from ... import language as dsl
+from ...api.scalar import _unwrap
+from ._common import (
+    _call,
+    full_mask,
+    load_tile,
+    load_view,
+    mask_type,
+    move_tile,
+    s,
+    store_tile,
+    store_view,
+    vreg_type,
+)
+
+
+def adds(src, scalar, out):
+    _call(pto.TAddSOp, src, scalar, out)
+    return out
+
+
+def subs(src, scalar, out):
+    _call(pto.TSubSOp, src, scalar, out)
+    return out
+
+
+def muls(src, scalar, out):
+    _call(pto.TMulSOp, src, scalar, out)
+    return out
+
+
+def divs(src, scalar, out):
+    _call(pto.TDivSOp, src, scalar, out)
+    return out
+
+
+def max(lhs, rhs, out):
+    _call(pto.TMaxOp, lhs, rhs, out)
+    return out
+
+
+def maxs(src, scalar, out):
+    _call(pto.TMaxSOp, src, scalar, out)
+    return out
+
+
+def min(lhs, rhs, out):
+    _call(pto.TMinOp, lhs, rhs, out)
+    return out
+
+
+def mins(src, scalar, out):
+    _call(pto.TMinSOp, src, scalar, out)
+    return out
+
+
+def and_(lhs, rhs, out):
+    _call(pto.TAndOp, lhs, rhs, out)
+    return out
+
+
+def xor(lhs, rhs, out):
+    _call(pto.TXorOp, lhs, rhs, out)
+    return out
+
+
+def shl(lhs, rhs, out):
+    _call(pto.TShlOp, lhs, rhs, out)
+    return out
+
+
+def shls(src, scalar, out):
+    _call(pto.TShlSOp, src, scalar, out)
+    return out
+
+
+def shr(lhs, rhs, out):
+    _call(pto.TShrOp, lhs, rhs, out)
+    return out
+
+
+def shrs(src, scalar, out):
+    _call(pto.TShrSOp, src, scalar, out)
+    return out
+
+
+def compare(src0, src1, out, *, mode):
+    cmp_mode = (
+        pto.CmpModeAttr.get(getattr(pto.CmpMode, mode.upper()))
+        if isinstance(mode, str)
+        else mode
+    )
+    _call(pto.TCmpOp, src0, src1, out, cmpMode=cmp_mode)
+    return out
+
+
+def scatter(src, indices, dst):
+    _call(pto.TScatterOp, src, indices, dst)
+    return dst
+
+
+def select(mask, lhs, rhs, out):
+    _call(pto.TSelOp, mask, lhs, rhs, out)
+    return out
+
+
+def concat(lhs, rhs, dst):
+    _call(pto.TConcatOp, lhs, rhs, dst)
+    return dst
+
+
+def extract(source, index_row, index_col, out):
+    _call(
+        pto.TExtractOp,
+        src=source,
+        indexRow=_unwrap(index_row),
+        indexCol=_unwrap(index_col),
+        dst=out,
+    )
+    return out
+
+
+def insert(source, index_row, index_col, out):
+    _call(
+        pto.TInsertOp,
+        src=source,
+        indexRow=_unwrap(index_row),
+        indexCol=_unwrap(index_col),
+        dst=out,
+    )
+    return out
+
+
+def row_prod(src, tmp, dst):
+    _call(pto.TRowProdOp, src=src, tmp=tmp, dst=dst)
+    return dst
+
+
+def col_prod(src, tmp, dst, *, is_binary=True):
+    _call(pto.TColProdOp, src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
+    return dst
+
+
+def col_expand_mul(src0, src1, dst):
+    _call(pto.TColExpandMulOp, src0=src0, src1=src1, dst=dst)
+    return dst
+
+
+def col_expand_max(src0, src1, dst):
+    _call(pto.TColExpandMaxOp, src0=src0, src1=src1, dst=dst)
+    return dst
+
+
+def col_expand_min(src0, src1, dst):
+    _call(pto.TColExpandMinOp, src0=src0, src1=src1, dst=dst)
+    return dst
+
+
+def trans(src, dst):
+    _call(pto.TTransOp, src, dst)
+    return dst
+
+
+def matmul(lhs, rhs, out):
+    _call(pto.TMatmulOp, None, lhs, rhs, out)
+    return out
+
+
+def matmul_acc(acc, lhs, rhs, out):
+    _call(pto.TMatmulAccOp, None, acc, lhs, rhs, out)
+    return out
+
+
+def matmul_bias(lhs, rhs, bias, out):
+    _call(pto.TMatmulBiasOp, None, lhs, rhs, bias, out)
+    return out
+
+
+def matmul_mx(lhs, lhs_scale, rhs, rhs_scale, out):
+    _call(pto.TMatmulMxOp, None, lhs, lhs_scale, rhs, rhs_scale, out)
+    return out
+
+
+def matmul_mx_acc(acc, lhs, lhs_scale, rhs, rhs_scale, out):
+    _call(pto.TMatmulMxAccOp, None, acc, lhs, lhs_scale, rhs, rhs_scale, out)
+    return out
+
+
+def matmul_mx_bias(lhs, lhs_scale, rhs, rhs_scale, bias, out):
+    _call(pto.TMatmulMxBiasOp, None, lhs, lhs_scale, rhs, rhs_scale, bias, out)
+    return out
+
+
+def full_mask_b32():
+    return pto.pset_b32(mask_type(), "PAT_ALL")
+
+
+def vload(ptr_value, offset, *, lanes=64, dtype=None):
+    dtype = dsl.float32 if dtype is None else dtype
+    return pto.vlds(vreg_type(lanes, dtype), _unwrap(ptr_value), _unwrap(offset))
+
+
+def vstore(vector, ptr_value, offset, *, mask=None):
+    if mask is None:
+        mask = full_mask(dsl.float32)
+    pto.vsts(_unwrap(vector), _unwrap(ptr_value), _unwrap(offset), _unwrap(mask))
+    return ptr_value
+
+
+def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None):
+    dtype = dsl.float32 if dtype is None else dtype
+    vec = vload(src_ptr, offset, lanes=lanes, dtype=dtype)
+    pto.vsts(vec, _unwrap(dst_ptr), _unwrap(offset), full_mask(dtype))
+    return vec
+
+
+__all__ = [
+    "adds",
+    "and_",
+    "col_expand_max",
+    "col_expand_min",
+    "col_expand_mul",
+    "col_prod",
+    "compare",
+    "concat",
+    "divs",
+    "extract",
+    "full_mask_b32",
+    "insert",
+    "load_tile",
+    "load_view",
+    "matmul",
+    "matmul_acc",
+    "matmul_bias",
+    "matmul_mx",
+    "matmul_mx_acc",
+    "matmul_mx_bias",
+    "max",
+    "maxs",
+    "min",
+    "mins",
+    "move_tile",
+    "muls",
+    "row_prod",
+    "scatter",
+    "select",
+    "shl",
+    "shls",
+    "shr",
+    "shrs",
+    "store_tile",
+    "store_view",
+    "subs",
+    "trans",
+    "vector_copy",
+    "vload",
+    "vstore",
+    "xor",
+]
diff --git a/ptodsl/lib/a5/ops.py b/ptodsl/lib/a5/ops.py
index 996c6723..043e240f 100644
--- a/ptodsl/lib/a5/ops.py
+++ b/ptodsl/lib/a5/ops.py
@@ -1,2029 +1,155 @@
-import builtins
-import re
-
-from mlir.dialects import arith as _arith
-from mlir.dialects import pto as _pto
-from mlir.ir import IntegerAttr, IntegerType
-
-from ... import pto as _dsl_pto
-from ... import scalar as _scalar
-from ...
import const_expr, range_constexpr -from ...api.scalar import _unwrap - -VF_IMPL_DEFAULT = "default" -VF_IMPL_1D_NO_POST_UPDATE = "1d_no_post_update" -VF_IMPL_1D_POST_UPDATE = "1d_post_update" -VF_IMPL_2D_NO_POST_UPDATE = "2d_no_post_update" -VF_IMPL_2D_POST_UPDATE = "2d_post_update" - - -_DTYPE_ALIAS_GROUPS = { - "f32": {"f32", "float32"}, - "f16": {"f16", "float16", "half"}, - "bf16": {"bf16", "bfloat16"}, - "i32": {"i32", "int32"}, - "u32": {"u32", "uint32"}, - "i16": {"i16", "int16"}, - "u16": {"u16", "uint16"}, - "i8": {"i8", "int8"}, - "u8": {"u8", "uint8"}, -} - - -def _call(op, *args, **kwargs): - return op( - *(_unwrap(arg) for arg in args), - **{name: _unwrap(value) for name, value in kwargs.items()}, - ) - - -def _cmp_mode_attr(mode): - if mode is None: - return None - if isinstance(mode, str): - return _pto.CmpModeAttr.get(getattr(_pto.CmpMode, mode.upper())) - return mode - - -def _const_i64(value): - i64 = IntegerType.get_signless(64) - return _arith.ConstantOp(i64, IntegerAttr.get(i64, value)).result - - -def _const_i32(value): - i32 = IntegerType.get_signless(32) - return _arith.ConstantOp(i32, IntegerAttr.get(i32, value)).result - - -def _const_float(dtype, value): - return _arith.ConstantOp(_scalar.resolve_type(dtype), value).result - - -def _dtype_token(dtype): - text = str(_scalar.resolve_type(dtype)).lower() - for canonical, aliases in _DTYPE_ALIAS_GROUPS.items(): - if any(alias in text for alias in aliases): - return canonical - raise ValueError(f"Unsupported dtype token for '{dtype}'.") - - -def _dtype_byte_width(dtype): - text = str(dtype) - if ( - "float32" in text - or "f32" in text - or "int32" in text - or "i32" in text - or "uint32" in text - or "u32" in text - ): - return 4 - if ( - "float16" in text - or "f16" in text - or "bfloat16" in text - or "bf16" in text - or "int16" in text - or "i16" in text - or "u16" in text - ): - return 2 - if "i8" in text or "u8" in text: - return 1 - raise ValueError(f"Unsupported dtype byte width for '{dtype}'.") - - -def _extract_static_tensor_shape(value): - raw = _unwrap(value) - type_obj = getattr(raw, "type", None) - if type_obj is None: - return None - text = str(type_obj) - match = re.search( - r"!pto\.(?:partition_)?tensor_view<(?P[^>]+)>|!pto\.tile_buf<[^,]+,\s*(?P[^>]+)>", - text, - ) - if not match: - return None - payload = match.group("payload") or match.group("tile_payload") - dims = re.findall(r"(\?|\d+)x", payload) - if not dims: - return None - shape = [] - for dim in dims: - if dim == "?": - return None - shape.append(int(dim)) - return shape - - -def _extract_tensor_dtype_token(value): - raw = _unwrap(value) - type_obj = getattr(raw, "type", None) - if type_obj is None: - return None - text = str(type_obj).lower() - for canonical, aliases in _DTYPE_ALIAS_GROUPS.items(): - if any(alias in text for alias in aliases): - return canonical - return None - - -def _require_supported_dtype(dtype, *, allowed, message): - try: - token = _dtype_token(dtype) - except ValueError as exc: - raise ValueError(message) from exc - if token not in allowed: - raise ValueError(message) - return token - - -def _require_view_shape(view, expected_shape, *, context, message): - actual_shape = _extract_static_tensor_shape(view) - if actual_shape is None: - return - if list(actual_shape) != list(expected_shape): - raise ValueError(f"{message} Expected {expected_shape}, got {actual_shape}.") - - -def _require_view_dtype(view, dtype, *, message): - actual_token = _extract_tensor_dtype_token(view) - if actual_token is None: - return - if 
actual_token != _dtype_token(dtype): - raise ValueError(message) - - -def _micro_lane_count(dtype): - return 256 // _dtype_byte_width(dtype) - - -def _resolve_lanes(dtype, lanes): - if lanes is None: - return _micro_lane_count(dtype) - return lanes - - -def _full_mask(dtype): - width = _dtype_byte_width(dtype) - if width == 4: - return _dsl_pto.pset_b32(_dsl_pto.MaskType(), "PAT_ALL") - if width == 2: - return _dsl_pto.pset_b16(_dsl_pto.MaskType(), "PAT_ALL") - if width == 1: - return _dsl_pto.pset_b8(_dsl_pto.MaskType(), "PAT_ALL") - raise ValueError(f"Unsupported dtype mask width for '{dtype}'.") - - -def _tail_mask(dtype, active_lanes): - i32 = IntegerType.get_signless(32) - width = _dtype_byte_width(dtype) - active = _const_i32(active_lanes) - if width == 4: - mask, _ = _dsl_pto.plt_b32(_dsl_pto.MaskType(), i32, active) - return mask - if width == 2: - mask, _ = _dsl_pto.plt_b16(_dsl_pto.MaskType(), i32, active) - return mask - if width == 1: - mask, _ = _dsl_pto.plt_b8(_dsl_pto.MaskType(), i32, active) - return mask - raise ValueError(f"Unsupported dtype tail mask width for '{dtype}'.") - - -def _mask_for_chunk(dtype, active_lanes): - lanes = _micro_lane_count(dtype) - if active_lanes == lanes: - return _full_mask(dtype) - return _tail_mask(dtype, active_lanes) - - -def _onept_dist(dtype): - width = _dtype_byte_width(dtype) - if width == 4: - return "ONEPT_B32" - if width == 2: - return "ONEPT_B16" - if width == 1: - return "ONEPT_B8" - raise ValueError(f"Unsupported dtype point-store width for '{dtype}'.") - - -def _normalize_vf_impl_kind(impl): - if impl is None: - return VF_IMPL_DEFAULT - - normalized = str(impl).strip().lower() - aliases = { - "default": VF_IMPL_DEFAULT, - "vfimpl_default": VF_IMPL_DEFAULT, - "1d_no_post_update": VF_IMPL_1D_NO_POST_UPDATE, - "vfimpl_1d_no_post_update": VF_IMPL_1D_NO_POST_UPDATE, - "1d_post_update": VF_IMPL_1D_POST_UPDATE, - "vfimpl_1d_post_update": VF_IMPL_1D_POST_UPDATE, - "2d_no_post_update": VF_IMPL_2D_NO_POST_UPDATE, - "vfimpl_2d_no_post_update": VF_IMPL_2D_NO_POST_UPDATE, - "2d_post_update": VF_IMPL_2D_POST_UPDATE, - "vfimpl_2d_post_update": VF_IMPL_2D_POST_UPDATE, - } - if normalized not in aliases: - supported = ", ".join(sorted(aliases)) - raise ValueError( - f"Unsupported VF impl kind '{impl}'. Expected one of: {supported}." - ) - return aliases[normalized] - - -def _alloc_like_view(view, *, dtype, shape, space, valid_shape=None, config=None): - return _dsl_pto.make_tile_buffer( - dtype, - shape, - space=space, - valid_shape=valid_shape, - config=config, - ).alloc() - - -def load_tile( - view, - tile_buffer=None, - *, - dtype=None, - shape=None, - space="VEC", - valid_shape=None, - config=None, -): - if tile_buffer is None: - if dtype is None or shape is None: - raise ValueError( - "`load_tile(...)` requires either `tile_buffer=` or both `dtype=` and `shape=`." 
- ) - tile_buffer = _alloc_like_view( - view, - dtype=dtype, - shape=shape, - space=space, - valid_shape=valid_shape, - config=config, - ) - _dsl_pto.load(view, tile_buffer) - return tile_buffer - - -def store_tile(tile_buffer, view): - _dsl_pto.store(tile_buffer, view) - return view - - -def move_tile(source, dest): - _call(_pto.TMovOp, None, source, dest) - return dest - - -def add(lhs, rhs, out): - _call(_pto.TAddOp, lhs, rhs, out) - return out - - -def add_micro( - lhs_view, - rhs_view, - out_view, - *, - dtype, - shape, - lanes=None, - base_addr=0, - impl=VF_IMPL_DEFAULT, -): - return _binary_micro( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vadd", - impl=impl, - ) - - -def sub_micro( - lhs_view, - rhs_view, - out_view, - *, - dtype, - shape, - lanes=None, - base_addr=0, - impl=VF_IMPL_DEFAULT, -): - return _binary_micro( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vsub", - impl=impl, - ) - - -def mul_micro( - lhs_view, - rhs_view, - out_view, - *, - dtype, - shape, - lanes=None, - base_addr=0, - impl=VF_IMPL_DEFAULT, -): - return _binary_micro( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vmul", - impl=impl, - ) - - -def div_micro( - lhs_view, - rhs_view, - out_view, - *, - dtype, - shape, - lanes=None, - base_addr=0, - impl=VF_IMPL_DEFAULT, -): - return _binary_micro( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vdiv", - impl=impl, - ) - - -def or_micro( - lhs_view, - rhs_view, - out_view, - *, - dtype, - shape, - lanes=None, - base_addr=0, - impl=VF_IMPL_DEFAULT, -): - return _binary_micro( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vor", - impl=impl, - ) - - -def mov_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name=None, - ) - - -def exp_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vexp", - ) - - -def log_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vln", - ) - - -def relu_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vrelu", - ) - - -def abs_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vabs", - ) - - -def sqrt_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vsqrt", - ) - - -def rsqrt_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - return _rsqrt_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - ) - - -def reciprocal_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): - 
return _unary_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - lanes=lanes, - base_addr=base_addr, - op_name="vrec", - ) - - -def gather_micro( - src_view, - indices_view, - out_view, - *, - dtype, - index_dtype, - shape, - base_addr=0, -): - return _gather_micro( - src_view, - indices_view, - out_view, - dtype=dtype, - index_dtype=index_dtype, - shape=shape, - base_addr=base_addr, - ) - - -def col_expand_micro(src_view, out_view, *, dtype, shape, base_addr=0): - rows, cols = _check_col_expand_operands( - src_view, out_view, dtype=dtype, shape=shape, context="TCOLEXPAND" - ) - lanes = _micro_lane_count(dtype) - vreg_type = _dsl_pto.VRegType(lanes, dtype) - buf_bytes = rows * cols * _dtype_byte_width(dtype) - - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer( - dtype, shape, space="VEC", valid_shape=[1, cols] - ).alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - - for col in range(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - col_offset = _scalar.const(col) - vec = _dsl_pto.vlds(vreg_type, src_ptr, col_offset) - for row in range(rows): - dst_offset = _scalar.const(row * cols + col) - _dsl_pto.vsts(vec, out_ptr, dst_offset, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def row_expand_micro(src_view, out_view, *, dtype, shape, base_addr=0): - rows, cols = _check_row_expand_operands( - src_view, out_view, dtype=dtype, shape=shape, context="TROWEXPAND" - ) - lanes = _micro_lane_count(dtype) - vreg_type = _dsl_pto.VRegType(lanes, dtype) - buf_bytes = rows * cols * _dtype_byte_width(dtype) - - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer( - dtype, shape, space="VEC", valid_shape=[rows, 1] - ).alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - - for row in range(rows): - scalar_offset = _scalar.const(row * cols) - align = _dsl_pto.vldas(_dsl_pto.AlignType(), src_ptr, scalar_offset) - scalar_vec, _, _ = _dsl_pto.vldus( - vreg_type, - _dsl_pto.AlignType(), - _dsl_pto.ptr(dtype, space="VEC"), - src_ptr, - scalar_offset, - align, - ) - broadcast = _dsl_pto.vdup(vreg_type, scalar_vec, position="POS_LOWEST") - for col in range(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - dst_offset = _scalar.const(row * cols + col) - _dsl_pto.vsts(broadcast, out_ptr, dst_offset, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def row_expand_sub_micro( - base_view, expand_view, out_view, *, dtype, shape, base_addr=0 -): - return _row_expand_binary_micro( - base_view, - expand_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - op_name="vsub", - ) - - -def row_expand_mul_micro( - base_view, expand_view, out_view, *, dtype, shape, base_addr=0 -): - return _row_expand_binary_micro( - base_view, - expand_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - op_name="vmul", - ) - - -def 
row_expand_div_micro( - base_view, expand_view, out_view, *, dtype, shape, base_addr=0 -): - return _row_expand_binary_micro( - base_view, - expand_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - op_name="vdiv", - ) - - -def row_sum_micro(src_view, out_view, *, dtype, shape, base_addr=0): - return _row_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vcadd", - combine_op_name="vadd", - init_value=0.0, - ) - - -def row_max_micro(src_view, out_view, *, dtype, shape, base_addr=0): - return _row_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vcmax", - combine_op_name="vmax", - init_value=float("-inf"), - ) - - -def row_min_micro(src_view, out_view, *, dtype, shape, base_addr=0): - return _row_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vcmin", - combine_op_name="vmin", - init_value=float("inf"), - ) - - -def col_sum_micro( - src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT -): - return _col_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vadd", - impl=impl, - ) - - -def col_max_micro( - src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT -): - return _col_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vmax", - impl=impl, - ) - - -def col_min_micro( - src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT -): - return _col_reduce_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - reduce_op_name="vmin", - impl=impl, - ) - - -def mrgsort_micro(src_view, out_view, *, dtype, shape, block_len, base_addr=0): - return _mrgsort_micro( - src_view, - out_view, - dtype=dtype, - shape=shape, - block_len=block_len, - base_addr=base_addr, - ) - - -def sort32_micro(src_view, idx_view, out_view, *, dtype, shape, base_addr=0): - return _sort32_micro( - src_view, - idx_view, - out_view, - dtype=dtype, - shape=shape, - base_addr=base_addr, - ) - - -def _require_static_matrix_shape(shape, *, context): - if len(shape) != 2 or any(not isinstance(dim, int) for dim in shape): - raise ValueError(f"{context} currently requires a static rank-2 integer shape.") - rows, cols = shape - if rows <= 0 or cols <= 0: - raise ValueError(f"{context} requires positive row/column sizes.") - return rows, cols - - -def _check_tbinop_operands(lhs_view, rhs_view, out_view, *, dtype, shape, context): - rows, cols = _require_static_matrix_shape(shape, context=context) - _require_supported_dtype( - dtype, - allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, - message=f"Fix: {context} has invalid data type.", - ) - for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")): - _require_view_shape( - view, - [rows, cols], - context=context, - message=f"Fix: {context} input tile {label} valid shape mismatch with output tile dst shape.", - ) - _require_view_dtype( - view, - dtype, - message=f"Fix: {context} input tile src0, src1 and dst tile data type mismatch.", - ) - return rows, cols - - -def _check_row_expand_operands(src_view, out_view, *, dtype, shape, context): - rows, cols = _require_static_matrix_shape(shape, context=context) - _require_supported_dtype( - dtype, - allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, - message=f"Fix: {context} data type must 
be b8/b16/b32", - ) - _require_view_shape( - src_view, - [rows, 1], - context=context, - message=f"Fix: {context} source valid shape must be [rows, 1].", - ) - _require_view_shape( - out_view, - [rows, cols], - context=context, - message=f"Fix: {context} output valid shape mismatch.", - ) - _require_view_dtype( - src_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - _require_view_dtype( - out_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - return rows, cols - - -def _check_col_expand_operands(src_view, out_view, *, dtype, shape, context): - rows, cols = _require_static_matrix_shape(shape, context=context) - _require_supported_dtype( - dtype, - allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, - message=f"Fix: {context} data type must be b8/b16/b32", - ) - _require_view_shape( - src_view, - [1, cols], - context=context, - message=f"Fix: {context} input valid col must be consistent with output valid col.", - ) - _require_view_shape( - out_view, - [rows, cols], - context=context, - message=f"Fix: {context} output valid shape mismatch.", - ) - _require_view_dtype( - src_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - _require_view_dtype( - out_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - return rows, cols - - -def _check_row_reduce_operands(src_view, out_view, *, dtype, shape, context): - rows, cols = _require_static_matrix_shape(shape, context=context) - _require_supported_dtype( - dtype, - allowed={"f32", "f16", "i32", "i16"}, - message=( - "Row reduction only supports 'half', 'float', 'int32', or 'int16' data types. " - "Fix: Define TileDataIn with DType = half, float, int32, or int16." 
- ), - ) - _require_view_shape( - src_view, - [rows, cols], - context=context, - message="Fix: Ensure src valid shape matches [rows, cols].", - ) - _require_view_shape( - out_view, - [rows, 1], - context=context, - message="Fix: Pass dstValidRow = srcValidRows and use a single-column output tile.", - ) - _require_view_dtype( - src_view, - dtype, - message="Fix: Ensure TileDataOut uses the same DType as TileDataIn.", - ) - _require_view_dtype( - out_view, - dtype, - message="Fix: Ensure TileDataOut uses the same DType as TileDataIn.", - ) - return rows, cols - - -def _check_col_reduce_operands(src_view, out_view, *, dtype, shape, context): - rows, cols = _require_static_matrix_shape(shape, context=context) - _require_supported_dtype( - dtype, - allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, - message=f"Fix: {context} input data type is not supported by this instruction.", - ) - _require_view_shape( - src_view, - [rows, cols], - context=context, - message=f"Fix: {context} input shape mismatch.", - ) - _require_view_shape( - out_view, - [1, cols], - context=context, - message=f"Fix: {context} input valid row must be consistent with the output valid row.", - ) - _require_view_dtype( - src_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - _require_view_dtype( - out_view, - dtype, - message=f"Fix: {context} input data type must be consistent with the output data type.", - ) - return rows, cols - - -def _check_gather_operands( - src_view, indices_view, out_view, *, dtype, index_dtype, shape -): - rows, cols = _require_static_matrix_shape(shape, context="TGATHER") - dtype_token = _require_supported_dtype( - dtype, - allowed={"f32", "f16", "i32", "u32", "i16", "u16"}, - message="Fix: TGATHER Src data type must be int16_t/uint16_t/int32_t/uint32_t/half/float.", - ) - index_token = _require_supported_dtype( - index_dtype, - allowed={"i32", "u32", "i16", "u16"}, - message="Fix: TGATHER expect b16/b32", - ) - if _dtype_byte_width(dtype) != _dtype_byte_width(index_dtype): - raise ValueError( - "Fix: TGATHER micro lowering currently supports same-width source/index pairs only." - ) - for view, expected_shape, label in ( - (src_view, [rows, cols], "src"), - (indices_view, [rows, cols], "indices"), - (out_view, [rows, cols], "dst"), - ): - _require_view_shape( - view, - expected_shape, - context="TGATHER", - message=f"Fix: TGATHER {label} shape mismatch.", - ) - _require_view_dtype( - src_view, - dtype, - message="Fix: TGATHER expect same type size for dst and src", - ) - _require_view_dtype( - out_view, - dtype, - message="Fix: TGATHER expect same type size for dst and src", - ) - _require_view_dtype( - indices_view, - index_dtype, - message="Fix: TGATHER expect b16/b32", - ) - return rows, cols, dtype_token, index_token - - -def _check_mrgsort_operands(src_view, out_view, *, dtype, shape, block_len): - rows, cols = _require_static_matrix_shape(shape, context="TMRGSORT") - _require_supported_dtype( - dtype, - allowed={"f32", "f16"}, - message="TMrgsort: Unsupported data type! 
Supported types is half/float", - ) - if rows != 1: - raise ValueError("TMrgsort: the row of Destination and Source tile must be 1.") - if block_len <= 0 or cols % (block_len * 4) != 0: - raise ValueError("TMrgsort: src columns must be divisible by blockLen * 4.") - _require_view_shape( - src_view, - [rows, cols], - context="TMRGSORT", - message="TMrgsort: source tile shape mismatch.", - ) - _require_view_shape( - out_view, - [rows, cols], - context="TMRGSORT", - message="TMrgsort: destination tile shape mismatch.", - ) - _require_view_dtype( - src_view, - dtype, - message="TMrgsort: Destination and Source tile data types must be the same.", - ) - _require_view_dtype( - out_view, - dtype, - message="TMrgsort: Destination and Source tile data types must be the same.", - ) - return rows, cols - - -def _check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape): - rows, cols = _require_static_matrix_shape(shape, context="TSORT32") - _require_supported_dtype( - dtype, - allowed={"f32", "f16"}, - message="Dst and src must be float or half.", - ) - out_cols = cols * (2 if _dtype_token(dtype) == "f32" else 4) - for view, expected_shape, label in ( - (src_view, [rows, cols], "src"), - (idx_view, [rows, cols], "idx"), - (out_view, [rows, out_cols], "dst"), - ): - _require_view_shape( - view, - expected_shape, - context="TSORT32", - message=f"TSORT32 {label} shape mismatch.", - ) - _require_view_dtype( - src_view, - dtype, - message="Dst and src mube be same.", - ) - _require_view_dtype( - out_view, - dtype, - message="Dst and src mube be same.", - ) - _require_view_dtype( - idx_view, - _dsl_pto.uint32, - message="Idx must be uint32_t.", - ) - if cols % 32 != 0: - raise ValueError( - "TSORT32 micro lowering currently requires column count divisible by 32." 
- ) - return rows, cols, out_cols - - -def _row_expand_binary_micro( - base_view, expand_view, out_view, *, dtype, shape, base_addr, op_name -): - rows, cols = _check_row_expand_operands( - expand_view, - out_view, - dtype=dtype, - shape=shape, - context=f"TROWEXPAND_{op_name[1:].upper()}", - ) - _require_view_shape( - base_view, - [rows, cols], - context=op_name, - message=f"Fix: TROWEXPAND_{op_name[1:].upper()} base input valid shape mismatch with output tile dst shape.", - ) - _require_view_dtype( - base_view, - dtype, - message=f"Fix: TROWEXPAND_{op_name[1:].upper()} input data type must be consistent with the output data type.", - ) - lanes = _micro_lane_count(dtype) - vreg_type = _dsl_pto.VRegType(lanes, dtype) - buf_bytes = rows * cols * _dtype_byte_width(dtype) - - base_addr_value = _const_i64(base_addr) - expand_addr_value = _const_i64(base_addr + buf_bytes) - out_addr_value = _const_i64(base_addr + buf_bytes * 2) - - base_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc( - addr=base_addr_value - ) - expand_tile = _dsl_pto.make_tile_buffer( - dtype, shape, space="VEC", valid_shape=[rows, 1] - ).alloc(addr=expand_addr_value) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc( - addr=out_addr_value - ) - - _dsl_pto.load(base_view, base_tile) - _dsl_pto.load(expand_view, expand_tile) - - base_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), base_addr_value) - expand_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), expand_addr_value) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr_value) - micro_op = getattr(_dsl_pto, op_name) - - for row in range(rows): - scalar_offset = _scalar.const(row * cols) - align = _dsl_pto.vldas(_dsl_pto.AlignType(), expand_ptr, scalar_offset) - scalar_vec, _, _ = _dsl_pto.vldus( - vreg_type, - _dsl_pto.AlignType(), - _dsl_pto.ptr(dtype, space="VEC"), - expand_ptr, - scalar_offset, - align, - ) - broadcast = _dsl_pto.vdup(vreg_type, scalar_vec, position="POS_LOWEST") - for col in range(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - row_offset = _scalar.const(row * cols + col) - base_vec = _dsl_pto.vlds(vreg_type, base_ptr, row_offset) - out_vec = micro_op(vreg_type, base_vec, broadcast, mask) - _dsl_pto.vsts(out_vec, out_ptr, row_offset, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _row_reduce_micro( - src_view, - out_view, - *, - dtype, - shape, - base_addr, - reduce_op_name, - combine_op_name, - init_value, -): - rows, cols = _check_row_reduce_operands( - src_view, out_view, dtype=dtype, shape=shape, context="TROWREDUCE" - ) - width = _dtype_byte_width(dtype) - if width not in {2, 4}: - raise ValueError(f"{reduce_op_name} currently supports only float16/float32.") - - lanes = _micro_lane_count(dtype) - vreg_type = _dsl_pto.VRegType(lanes, dtype) - buf_bytes = rows * cols * width - - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer( - dtype, shape, space="VEC", valid_shape=[rows, 1] - ).alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - reduce_op = getattr(_dsl_pto, reduce_op_name) - combine_op = getattr(_dsl_pto, combine_op_name) - full_mask = _full_mask(dtype) - point_mask = 
_tail_mask(dtype, 1) - init_scalar = _const_float(dtype, init_value) - - for row in range(rows): - accum = _dsl_pto.vbr(vreg_type, init_scalar) - for col in range(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - offset = _scalar.const(row * cols + col) - vec = _dsl_pto.vlds(vreg_type, src_ptr, offset) - reduced = reduce_op(vreg_type, vec, mask) - accum = combine_op(vreg_type, accum, reduced, full_mask) - out_offset = _scalar.const(row * cols) - _dsl_pto.vsts(accum, out_ptr, out_offset, point_mask, dist=_onept_dist(dtype)) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _col_reduce_micro( - src_view, - out_view, - *, - dtype, - shape, - base_addr, - reduce_op_name, - impl, -): - rows, cols = _check_col_reduce_operands( - src_view, out_view, dtype=dtype, shape=shape, context="TCOLREDUCE" - ) - lanes = _micro_lane_count(dtype) - buf_bytes = rows * cols * _dtype_byte_width(dtype) - - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer( - dtype, [1, cols], space="VEC", valid_shape=[1, cols] - ).alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - ptr_type = _dsl_pto.ptr(dtype, space="VEC") - vreg_type = _dsl_pto.VRegType(lanes, dtype) - src_ptr = _dsl_pto.castptr(ptr_type, src_addr) - out_ptr = _dsl_pto.castptr(ptr_type, out_addr) - reduce_op = getattr(_dsl_pto, reduce_op_name) - impl_kind = _normalize_vf_impl_kind(impl) - if const_expr(impl_kind == VF_IMPL_DEFAULT): - impl_kind = VF_IMPL_1D_POST_UPDATE - - if const_expr(impl_kind in {VF_IMPL_1D_NO_POST_UPDATE, VF_IMPL_2D_NO_POST_UPDATE}): - _col_reduce_micro_no_post_update( - src_ptr, - out_ptr, - dtype=dtype, - rows=rows, - cols=cols, - lanes=lanes, - vreg_type=vreg_type, - reduce_op=reduce_op, - ) - elif const_expr(impl_kind in {VF_IMPL_1D_POST_UPDATE, VF_IMPL_2D_POST_UPDATE}): - _col_reduce_micro_post_update( - src_ptr, - out_ptr, - ptr_type=ptr_type, - dtype=dtype, - rows=rows, - cols=cols, - lanes=lanes, - vreg_type=vreg_type, - reduce_op=reduce_op, - ) - else: - raise ValueError(f"Unexpected normalized VF impl kind '{impl_kind}'.") - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _col_reduce_micro_no_post_update( - src_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, reduce_op -): - loop_pairs = (rows - 1) // 2 - remain = (rows - 1) % 2 - for col in range_constexpr(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - accum = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col)) - for pair in range_constexpr(loop_pairs): - row0 = 2 * pair + 1 - row1 = 2 * pair + 2 - src0 = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col + row0 * cols)) - src1 = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col + row1 * cols)) - tmp = reduce_op(vreg_type, src0, src1, mask) - accum = reduce_op(vreg_type, accum, tmp, mask) - if const_expr(remain): - tail_row = 2 * loop_pairs + 1 - src_tail = _dsl_pto.vlds( - vreg_type, src_ptr, _scalar.const(col + tail_row * cols) - ) - accum = reduce_op(vreg_type, accum, src_tail, mask) - _dsl_pto.vsts(accum, out_ptr, _scalar.const(col), mask) - - -def _col_reduce_micro_post_update( - src_ptr, out_ptr, *, ptr_type, dtype, rows, cols, lanes, vreg_type, reduce_op -): - src_cursor = src_ptr - out_cursor = out_ptr - loop_pairs = (rows - 1) // 2 - remain = (rows - 1) % 2 - lane_step = _scalar.const(lanes) - pair_stride = 
_scalar.const(cols * 2) - for col in range_constexpr(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - chunk_base = src_cursor - accum, src_cursor = _dsl_pto.vlds_post( - vreg_type, ptr_type, src_cursor, lane_step - ) - row0_ptr = _dsl_pto.addptr(chunk_base, _scalar.const(cols)) - row1_ptr = _dsl_pto.addptr(chunk_base, _scalar.const(cols * 2)) - for _ in range_constexpr(loop_pairs): - src0, row0_ptr = _dsl_pto.vlds_post( - vreg_type, ptr_type, row0_ptr, pair_stride - ) - src1, row1_ptr = _dsl_pto.vlds_post( - vreg_type, ptr_type, row1_ptr, pair_stride - ) - tmp = reduce_op(vreg_type, src0, src1, mask) - accum = reduce_op(vreg_type, accum, tmp, mask) - if const_expr(remain): - src_tail = _dsl_pto.vlds(vreg_type, row0_ptr, _scalar.const(0)) - accum = reduce_op(vreg_type, accum, src_tail, mask) - out_cursor = _dsl_pto.vsts_post(ptr_type, accum, out_cursor, lane_step, mask) - - -def _gather_micro( - src_view, - indices_view, - out_view, - *, - dtype, - index_dtype, - shape, - base_addr, -): - rows, cols, _, _ = _check_gather_operands( - src_view, - indices_view, - out_view, - dtype=dtype, - index_dtype=index_dtype, - shape=shape, - ) - src_bytes = rows * cols * _dtype_byte_width(dtype) - idx_bytes = rows * cols * _dtype_byte_width(index_dtype) - - src_addr = _const_i64(base_addr) - idx_addr = _const_i64(base_addr + src_bytes) - out_addr = _const_i64(base_addr + src_bytes + idx_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - idx_tile = _dsl_pto.make_tile_buffer(index_dtype, shape, space="VEC").alloc( - addr=idx_addr - ) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - _dsl_pto.load(indices_view, idx_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - idx_ptr = _dsl_pto.castptr(_dsl_pto.ptr(index_dtype, space="VEC"), idx_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - lanes = _micro_lane_count(dtype) - vreg_type = _dsl_pto.VRegType(lanes, dtype) - index_vreg_type = _dsl_pto.VRegType(_micro_lane_count(index_dtype), index_dtype) - - for row in range_constexpr(rows): - row_base = row * cols - for col in range_constexpr(0, cols, lanes): - active = builtins.min(lanes, cols - col) - offset = _scalar.const(row_base + col) - mask = _mask_for_chunk(dtype, active) - idx_vec = _dsl_pto.vlds(index_vreg_type, idx_ptr, offset) - out_vec = _dsl_pto.vgather2( - vreg_type, src_ptr, idx_vec, _scalar.const(active) - ) - _dsl_pto.vsts(out_vec, out_ptr, offset, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _mrgsort_micro(src_view, out_view, *, dtype, shape, block_len, base_addr): - _, cols = _check_mrgsort_operands( - src_view, out_view, dtype=dtype, shape=shape, block_len=block_len - ) - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + cols * _dtype_byte_width(dtype)) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - _dsl_pto.load(src_view, src_tile) - - ptr_type = _dsl_pto.ptr(dtype, space="VEC") - src_ptr = _dsl_pto.castptr(ptr_type, src_addr) - out_ptr = _dsl_pto.castptr(ptr_type, out_addr) - - src1_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len)) - src2_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len * 2)) - src3_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len * 3)) - - 
num_structures = (block_len * _dtype_byte_width(dtype)) >> 3 - count_value = ( - num_structures - | (num_structures << 16) - | (num_structures << 32) - | (num_structures << 48) - ) - repeat_times = cols // (block_len * 4) - config_value = repeat_times | (0b1111 << 8) - - _dsl_pto.vmrgsort4( - out_ptr, - src_ptr, - src1_ptr, - src2_ptr, - src3_ptr, - _const_i64(count_value), - _const_i64(config_value), - ) - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _sort32_micro(src_view, idx_view, out_view, *, dtype, shape, base_addr): - rows, cols, out_cols = _check_sort32_operands( - src_view, idx_view, out_view, dtype=dtype, shape=shape - ) - src_bytes = rows * cols * _dtype_byte_width(dtype) - idx_bytes = rows * cols * 4 - - src_addr = _const_i64(base_addr) - idx_addr = _const_i64(base_addr + src_bytes) - out_addr = _const_i64(base_addr + src_bytes + idx_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, [rows, cols], space="VEC").alloc( - addr=src_addr - ) - idx_tile = _dsl_pto.make_tile_buffer( - _dsl_pto.uint32, [rows, cols], space="VEC" - ).alloc(addr=idx_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, [rows, out_cols], space="VEC").alloc( - addr=out_addr - ) - - _dsl_pto.load(src_view, src_tile) - _dsl_pto.load(idx_view, idx_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - idx_ptr = _dsl_pto.castptr(_dsl_pto.ptr(_dsl_pto.uint32, space="VEC"), idx_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - repeat_times = _scalar.const(cols // 32) - - for row in range_constexpr(rows): - src_row = _dsl_pto.addptr(src_ptr, _scalar.const(row * cols)) - idx_row = _dsl_pto.addptr(idx_ptr, _scalar.const(row * cols)) - out_row = _dsl_pto.addptr(out_ptr, _scalar.const(row * out_cols)) - _dsl_pto.vbitsort(out_row, src_row, idx_row, repeat_times) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _binary_micro( - lhs_view, rhs_view, out_view, *, dtype, shape, lanes, base_addr, op_name, impl -): - rows, cols = _check_tbinop_operands( - lhs_view, - rhs_view, - out_view, - dtype=dtype, - shape=shape, - context=op_name.upper().replace("V", "T", 1), - ) - lanes = _resolve_lanes(dtype, lanes) - element_count = rows * cols - buf_bytes = element_count * _dtype_byte_width(dtype) - lhs_addr = _const_i64(base_addr) - rhs_addr = _const_i64(base_addr + buf_bytes) - out_addr = _const_i64(base_addr + buf_bytes * 2) - - lhs_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=lhs_addr) - rhs_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=rhs_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(lhs_view, lhs_tile) - _dsl_pto.load(rhs_view, rhs_tile) - - ptr_type = _dsl_pto.ptr(dtype, space="VEC") - vreg_type = _dsl_pto.VRegType(lanes, dtype) - lhs_ptr = _dsl_pto.castptr(ptr_type, lhs_addr) - rhs_ptr = _dsl_pto.castptr(ptr_type, rhs_addr) - out_ptr = _dsl_pto.castptr(ptr_type, out_addr) - micro_op = getattr(_dsl_pto, op_name) - impl_kind = _normalize_vf_impl_kind(impl) - is_contiguous = rows == 1 or cols == element_count - if const_expr(impl_kind == VF_IMPL_DEFAULT): - impl_kind = ( - VF_IMPL_1D_POST_UPDATE if is_contiguous else VF_IMPL_2D_NO_POST_UPDATE - ) - - if const_expr(impl_kind == VF_IMPL_1D_NO_POST_UPDATE): - _binary_micro_1d_no_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - dtype=dtype, - lanes=lanes, - element_count=element_count, - vreg_type=vreg_type, - micro_op=micro_op, - ) - elif const_expr(impl_kind == 
VF_IMPL_1D_POST_UPDATE): - _binary_micro_1d_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - ptr_type=ptr_type, - dtype=dtype, - lanes=lanes, - element_count=element_count, - vreg_type=vreg_type, - micro_op=micro_op, - ) - elif const_expr(impl_kind == VF_IMPL_2D_NO_POST_UPDATE): - _binary_micro_2d_no_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - dtype=dtype, - rows=rows, - cols=cols, - lanes=lanes, - vreg_type=vreg_type, - micro_op=micro_op, - ) - elif const_expr(impl_kind == VF_IMPL_2D_POST_UPDATE): - _binary_micro_2d_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - dtype=dtype, - rows=rows, - cols=cols, - lanes=lanes, - vreg_type=vreg_type, - micro_op=micro_op, - ) - else: - raise ValueError(f"Unexpected normalized VF impl kind '{impl_kind}'.") - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _binary_micro_1d_no_post_update( - lhs_ptr, rhs_ptr, out_ptr, *, dtype, lanes, element_count, vreg_type, micro_op -): - for offset in range_constexpr(0, element_count, lanes): - active = builtins.min(lanes, element_count - offset) - mask = _mask_for_chunk(dtype, active) - index = _scalar.const(offset) - lhs_vec = _dsl_pto.vlds(vreg_type, lhs_ptr, index) - rhs_vec = _dsl_pto.vlds(vreg_type, rhs_ptr, index) - out_vec = micro_op(vreg_type, lhs_vec, rhs_vec, mask) - _dsl_pto.vsts(out_vec, out_ptr, index, mask) - - -def _binary_micro_1d_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - *, - ptr_type, - dtype, - lanes, - element_count, - vreg_type, - micro_op, -): - lhs_cursor = lhs_ptr - rhs_cursor = rhs_ptr - out_cursor = out_ptr - lane_step = _scalar.const(lanes) - for offset in range_constexpr(0, element_count, lanes): - active = builtins.min(lanes, element_count - offset) - mask = _mask_for_chunk(dtype, active) - lhs_vec, lhs_cursor = _dsl_pto.vlds_post( - vreg_type, ptr_type, lhs_cursor, lane_step - ) - rhs_vec, rhs_cursor = _dsl_pto.vlds_post( - vreg_type, ptr_type, rhs_cursor, lane_step - ) - out_vec = micro_op(vreg_type, lhs_vec, rhs_vec, mask) - out_cursor = _dsl_pto.vsts_post(ptr_type, out_vec, out_cursor, lane_step, mask) - - -def _binary_micro_2d_no_post_update( - lhs_ptr, rhs_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, micro_op -): - for row in range_constexpr(rows): - row_base = row * cols - for col in range_constexpr(0, cols, lanes): - active = builtins.min(lanes, cols - col) - mask = _mask_for_chunk(dtype, active) - index = _scalar.const(row_base + col) - lhs_vec = _dsl_pto.vlds(vreg_type, lhs_ptr, index) - rhs_vec = _dsl_pto.vlds(vreg_type, rhs_ptr, index) - out_vec = micro_op(vreg_type, lhs_vec, rhs_vec, mask) - _dsl_pto.vsts(out_vec, out_ptr, index, mask) - - -def _binary_micro_2d_post_update( - lhs_ptr, rhs_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, micro_op -): - _binary_micro_2d_no_post_update( - lhs_ptr, - rhs_ptr, - out_ptr, - dtype=dtype, - rows=rows, - cols=cols, - lanes=lanes, - vreg_type=vreg_type, - micro_op=micro_op, - ) - - -def _rsqrt_micro(src_view, out_view, *, dtype, shape, lanes, base_addr): - if any(not isinstance(dim, int) for dim in shape): - raise ValueError( - "micro tile lowering currently requires a static integer shape." 
- ) - - lanes = _resolve_lanes(dtype, lanes) - element_count = 1 - for dim in shape: - element_count *= dim - - buf_bytes = element_count * _dtype_byte_width(dtype) - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - vreg_type = _dsl_pto.VRegType(lanes, dtype) - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - - for offset in range_constexpr(0, element_count, lanes): - active = builtins.min(lanes, element_count - offset) - mask = _mask_for_chunk(dtype, active) - index = _scalar.const(offset) - src_vec = _dsl_pto.vlds(vreg_type, src_ptr, index) - sqrt_vec = _dsl_pto.vsqrt(vreg_type, src_vec, mask) - out_vec = _dsl_pto.vrec(vreg_type, sqrt_vec, mask) - _dsl_pto.vsts(out_vec, out_ptr, index, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def _unary_micro(src_view, out_view, *, dtype, shape, lanes, base_addr, op_name): - if any(not isinstance(dim, int) for dim in shape): - raise ValueError( - "micro tile lowering currently requires a static integer shape." - ) - - lanes = _resolve_lanes(dtype, lanes) - element_count = 1 - for dim in shape: - element_count *= dim - - buf_bytes = element_count * _dtype_byte_width(dtype) - src_addr = _const_i64(base_addr) - out_addr = _const_i64(base_addr + buf_bytes) - - src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr) - out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr) - - _dsl_pto.load(src_view, src_tile) - - src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr) - out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr) - micro_op = getattr(_dsl_pto, op_name) if op_name is not None else None - - for offset in range_constexpr(0, element_count, lanes): - active = builtins.min(lanes, element_count - offset) - mask = _mask_for_chunk(dtype, active) - index = _scalar.const(offset) - src_vec = _dsl_pto.vlds(_dsl_pto.VRegType(lanes, dtype), src_ptr, index) - out_vec = ( - src_vec - if micro_op is None - else micro_op(_dsl_pto.VRegType(lanes, dtype), src_vec, mask) - ) - _dsl_pto.vsts(out_vec, out_ptr, index, mask) - - _dsl_pto.store(out_tile, out_view) - return out_view - - -def adds(src, scalar, out): - _call(_pto.TAddSOp, src, scalar, out) - return out - - -def sub(lhs, rhs, out): - _call(_pto.TSubOp, lhs, rhs, out) - return out - - -def subs(src, scalar, out): - _call(_pto.TSubSOp, src, scalar, out) - return out - - -def mul(lhs, rhs, out): - _call(_pto.TMulOp, lhs, rhs, out) - return out - - -def muls(src, scalar, out): - _call(_pto.TMulSOp, src, scalar, out) - return out - - -def div(lhs, rhs, out): - _call(_pto.TDivOp, lhs, rhs, out) - return out - - -def divs(src, scalar, out): - _call(_pto.TDivSOp, src, scalar, out) - return out - - -def max(lhs, rhs, out): - _call(_pto.TMaxOp, lhs, rhs, out) - return out - - -def maxs(src, scalar, out): - _call(_pto.TMaxSOp, src, scalar, out) - return out - - -def min(lhs, rhs, out): - _call(_pto.TMinOp, lhs, rhs, out) - return out - - -def mins(src, scalar, out): - _call(_pto.TMinSOp, src, scalar, out) - return out - - -def and_(lhs, rhs, out): - _call(_pto.TAndOp, lhs, rhs, out) - return out - - -def or_(lhs, rhs, out): - _call(_pto.TOrOp, lhs, rhs, out) - return 
out - - -def xor(lhs, rhs, out): - _call(_pto.TXorOp, lhs, rhs, out) - return out - - -def shl(lhs, rhs, out): - _call(_pto.TShlOp, lhs, rhs, out) - return out - - -def shls(src, scalar, out): - _call(_pto.TShlSOp, src, scalar, out) - return out - - -def shr(lhs, rhs, out): - _call(_pto.TShrOp, lhs, rhs, out) - return out - - -def shrs(src, scalar, out): - _call(_pto.TShrSOp, src, scalar, out) - return out - - -def compare(src0, src1, out, *, mode): - _call(_pto.TCmpOp, src0, src1, out, cmpMode=_cmp_mode_attr(mode)) - return out - - -def exp(src, out): - _call(_pto.TExpOp, src, out) - return out - - -def log(src, out): - _call(_pto.TLogOp, src, out) - return out - - -def relu(src, out): - _call(_pto.TReluOp, src, out) - return out - - -def abs(src, out): - _call(_pto.TAbsOp, src, out) - return out - - -def sqrt(src, out): - _call(_pto.TSqrtOp, src, out) - return out - - -def rsqrt(src, out): - _call(_pto.TRsqrtOp, src, out) - return out - - -def reciprocal(src, out): - _call(_pto.TRecipOp, src, out) - return out - - -def lrelu(src, slope, out): - _call(_pto.TLReluOp, src, slope, out) - return out - - -def gather(src, out, *, indices=None, mask_pattern=None): - kwargs = {} - if indices is not None: - kwargs["indices"] = indices - if mask_pattern is not None: - kwargs["maskPattern"] = _pto.MaskPatternAttr.get( - getattr(_pto.MaskPattern, mask_pattern) - ) - _call(_pto.TGatherOp, src, out, **kwargs) - return out - - -def scatter(src, indices, out): - _call(_pto.TScatterOp, src, indices, out) - return out - - -def select(mask, src0, src1, tmp, out): - _call(_pto.TSelOp, mask, src0, src1, tmp, out) - return out - - -def concat(src0, src1, out): - _call(_pto.TConcatOp, src0, src1, out) - return out - - -def extract(source, index_row, index_col, out): - _call(_pto.TExtractOp, source, index_row, index_col, out) - return out - - -def insert(source, index_row, index_col, out): - _call(_pto.TInsertOp, source, index_row, index_col, out) - return out - - -def row_sum(src, tmp, dst): - _call(_pto.TRowSumOp, src=src, tmp=tmp, dst=dst) - return dst - - -def row_min(src, tmp, dst): - _call(_pto.TRowMinOp, src=src, tmp=tmp, dst=dst) - return dst - - -def row_max(src, tmp, dst): - _call(_pto.TRowMaxOp, src=src, tmp=tmp, dst=dst) - return dst - - -def col_sum(src, tmp, dst, *, is_binary=True): - _call(_pto.TColSumOp, src=src, tmp=tmp, dst=dst, isBinary=is_binary) - return dst - - -def col_min(src, dst): - _call(_pto.TColMinOp, src=src, dst=dst) - return dst - - -def col_max(src, dst): - _call(_pto.TColMaxOp, src=src, dst=dst) - return dst - - -def row_expand(src, dst): - _call(_pto.TRowExpandOp, src=src, dst=dst) - return dst - - -def row_expand_sub(src0, src1, dst): - _call(_pto.TRowExpandSubOp, src0=src0, src1=src1, dst=dst) - return dst - - -def row_expand_mul(src0, src1, dst): - _call(_pto.TRowExpandMulOp, src0=src0, src1=src1, dst=dst) - return dst - - -def row_expand_div(src0, src1, dst): - _call(_pto.TRowExpandDivOp, src0=src0, src1=src1, dst=dst) - return dst - - -def col_expand(src, dst): - _call(_pto.TColExpandOp, src=src, dst=dst) - return dst - - -def col_expand_mul(src0, src1, dst): - _call(_pto.TColExpandMulOp, src0=src0, src1=src1, dst=dst) - return dst - - -def col_expand_max(src0, src1, dst): - _call(_pto.TColExpandMaxOp, src0=src0, src1=src1, dst=dst) - return dst - - -def col_expand_min(src0, src1, dst): - _call(_pto.TColExpandMinOp, src0=src0, src1=src1, dst=dst) - return dst - - -def trans(src, dst): - _call(_pto.TTransOp, src, dst) - return dst - - -def mrgsort(src, dst, block_len): - 
_call(_pto.TMrgSortOp, srcs=[src], dsts=[dst], blockLen=block_len) - return dst - - -def sort32(src, dst, idx): - _call(_pto.TSort32Op, src, dst, idx) - return dst - - -def matmul(lhs, rhs, out): - _call(_pto.TMatmulOp, None, lhs, rhs, out) - return out - - -def matmul_acc(acc, lhs, rhs, out): - _call(_pto.TMatmulAccOp, None, acc, lhs, rhs, out) - return out - - -def matmul_bias(lhs, rhs, bias, out): - _call(_pto.TMatmulBiasOp, None, lhs, rhs, bias, out) - return out - - -def matmul_mx(lhs, lhs_scale, rhs, rhs_scale, out): - _call(_pto.TMatmulMxOp, None, lhs, lhs_scale, rhs, rhs_scale, out) - return out - - -def matmul_mx_acc(acc, lhs, lhs_scale, rhs, rhs_scale, out): - _call(_pto.TMatmulMxAccOp, None, acc, lhs, lhs_scale, rhs, rhs_scale, out) - return out - - -def matmul_mx_bias(lhs, lhs_scale, rhs, rhs_scale, bias, out): - _call(_pto.TMatmulMxBiasOp, None, lhs, lhs_scale, rhs, rhs_scale, bias, out) - return out - - -def full_mask_b32(): - return _dsl_pto.pset_b32(_dsl_pto.MaskType(), "PAT_ALL") - - -def vload(ptr, offset, *, lanes=64, dtype=None): - dtype = _dsl_pto.float32 if dtype is None else dtype - return _dsl_pto.vlds(_dsl_pto.VRegType(lanes, dtype), ptr, offset) - - -def vstore(vector, ptr, offset, *, mask=None): - if mask is None: - mask = full_mask_b32() - _dsl_pto.vsts(vector, ptr, offset, mask) - return ptr - - -def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): - vec = vload(src_ptr, offset, lanes=lanes, dtype=dtype) - vstore(vec, dst_ptr, offset) - return vec - - +"""Public A5 op surface split into small, opcode-focused implementation files.""" + +from ._common import ( + VF_IMPL_1D_NO_POST_UPDATE, + VF_IMPL_1D_POST_UPDATE, + VF_IMPL_2D_NO_POST_UPDATE, + VF_IMPL_2D_POST_UPDATE, + VF_IMPL_DEFAULT, +) +from .native import ( + adds, + and_, + col_expand_max, + col_expand_min, + col_expand_mul, + col_prod, + compare, + concat, + divs, + extract, + full_mask_b32, + insert, + load_tile, + matmul, + matmul_acc, + matmul_bias, + matmul_mx, + matmul_mx_acc, + matmul_mx_bias, + max, + maxs, + min, + mins, + move_tile, + muls, + row_prod, + scatter, + select, + shl, + shls, + shr, + shrs, + store_tile, + subs, + trans, + vector_copy, + vload, + vstore, + xor, +) +from .tbinary import tadd, tdiv, tmov, tmul, tor_, tsub +from .texpand import ( + tcol_expand, + trow_expand, + trow_expand_div, + trow_expand_mul, + trow_expand_sub, +) +from .treduce import ( + tcol_max, + tcol_min, + tcol_sum, + trow_max, + trow_min, + trow_sum, +) +from .tsort import tgather, tmrgsort, tsort32 +from .tunary import tabs, texp, tlog, trecip, trelu, trsqrt, tsqrt + +# Readable aliases that match the public tile op names. +mov = tmov +add = tadd +sub = tsub +mul = tmul +div = tdiv +or_ = tor_ +gather = tgather +exp = texp +log = tlog +relu = trelu +abs = tabs +sqrt = tsqrt +rsqrt = trsqrt +reciprocal = trecip +row_sum = trow_sum +row_min = trow_min +row_max = trow_max +row_expand = trow_expand +row_expand_sub = trow_expand_sub +row_expand_mul = trow_expand_mul +row_expand_div = trow_expand_div +col_sum = tcol_sum +col_min = tcol_min +col_max = tcol_max +col_expand = tcol_expand +mrgsort = tmrgsort +sort32 = tsort32 + +# A5-style aliases. 
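+# Both spellings lower identically: e.g. (with hypothetical float32 tile
+# views lhs/rhs/out) `TAdd(lhs, rhs, out, dtype=pto.float32, shape=[32, 32])`
+# and `tadd(...)` emit the same pto.vlds/pto.vadd/pto.vsts loop.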
TLoad = load_tile TStore = store_tile -TMov = move_tile -TAdd = add +TMov = tmov +TAdd = tadd TAddS = adds -TSub = sub +TSub = tsub TSubS = subs -TMul = mul +TMul = tmul TMulS = muls -TDiv = div +TDiv = tdiv TDivS = divs TMax = max TMaxS = maxs TMin = min TMinS = mins TAnd = and_ -TOr = or_ +TOr = tor_ TXor = xor TShl = shl TShlS = shls TShr = shr TShrS = shrs TCmp = compare -TExp = exp -TLog = log -TRelu = relu -TAbs = abs -TSqrt = sqrt -TRsqrt = rsqrt -TRecip = reciprocal -TLRelu = lrelu -TGather = gather +TExp = texp +TLog = tlog +TRelu = trelu +TAbs = tabs +TSqrt = tsqrt +TRsqrt = trsqrt +TRecip = trecip +TGather = tgather TScatter = scatter TSel = select TConcat = concat TExtract = extract TInsert = insert -TRowSum = row_sum -TRowMin = row_min -TRowMax = row_max -TColSum = col_sum -TColMin = col_min -TColMax = col_max -TRowExpand = row_expand -TRowExpandSub = row_expand_sub -TRowExpandMul = row_expand_mul -TRowExpandDiv = row_expand_div -TColExpand = col_expand +TRowSum = trow_sum +TRowMin = trow_min +TRowMax = trow_max +TRowExpand = trow_expand +TRowExpandSub = trow_expand_sub +TRowExpandMul = trow_expand_mul +TRowExpandDiv = trow_expand_div +TColSum = tcol_sum +TColMin = tcol_min +TColMax = tcol_max +TColExpand = tcol_expand TColExpandMul = col_expand_mul TColExpandMax = col_expand_max TColExpandMin = col_expand_min +TMrgSort = tmrgsort +TSort32 = tsort32 TTrans = trans -TMrgSort = mrgsort -TSort32 = sort32 TMatmul = matmul TMatmulAcc = matmul_acc TMatmulBias = matmul_bias @@ -2031,7 +157,6 @@ def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): TMatmulMxAcc = matmul_mx_acc TMatmulMxBias = matmul_mx_bias - __all__ = [ "VF_IMPL_DEFAULT", "VF_IMPL_1D_NO_POST_UPDATE", @@ -2057,7 +182,6 @@ def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): "TExtract", "TGather", "TInsert", - "TLRelu", "TLoad", "TLog", "TMatmul", @@ -2099,36 +223,27 @@ def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): "TTrans", "TXor", "add", - "add_micro", - "abs_micro", "adds", "and_", "col_expand", - "col_expand_micro", "col_expand_max", "col_expand_min", "col_expand_mul", "col_max", - "col_max_micro", "col_min", - "col_min_micro", + "col_prod", "col_sum", - "col_sum_micro", "compare", "concat", "div", "divs", "exp", - "exp_micro", "extract", "full_mask_b32", "gather", - "gather_micro", "insert", "load_tile", "log", - "log_micro", - "lrelu", "matmul", "matmul_acc", "matmul_bias", @@ -2140,32 +255,22 @@ def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): "min", "mins", "move_tile", - "mov_micro", + "mov", "mrgsort", - "mrgsort_micro", "mul", "muls", "or_", "reciprocal", - "reciprocal_micro", "relu", - "relu_micro", "row_expand", - "row_expand_div_micro", - "row_expand_micro", - "row_expand_mul_micro", "row_expand_div", - "row_expand_sub_micro", "row_expand_mul", "row_expand_sub", "row_max", - "row_max_micro", "row_min", - "row_min_micro", + "row_prod", "row_sum", - "row_sum_micro", "rsqrt", - "rsqrt_micro", "scatter", "select", "shl", @@ -2173,17 +278,38 @@ def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): "shr", "shrs", "sort32", - "sort32_micro", "sqrt", - "sqrt_micro", "store_tile", "sub", - "sub_micro", "subs", - "div_micro", - "mul_micro", - "or_micro", + "tabs", + "tadd", + "tcol_expand", + "tcol_max", + "tcol_min", + "tcol_sum", + "tdiv", + "texp", + "tgather", + "tlog", + "tmov", + "tmrgsort", + "tmul", + "tor_", "trans", + "trecip", + "trelu", + "trow_expand", + "trow_expand_div", + "trow_expand_mul", + "trow_expand_sub", + 
"trow_max", + "trow_min", + "trow_sum", + "trsqrt", + "tsort32", + "tsqrt", + "tsub", "vector_copy", "vload", "vstore", diff --git a/ptodsl/lib/a5/tbinary.py b/ptodsl/lib/a5/tbinary.py new file mode 100644 index 00000000..6eb9ce6a --- /dev/null +++ b/ptodsl/lib/a5/tbinary.py @@ -0,0 +1,397 @@ +"""Implement tile binary ops with PTO vector micro instructions. + +This file demonstrates how to write tile-style helpers such as `pto.tadd` +directly in terms of `pto.vlds`, `pto.vadd`, and `pto.vsts`. +""" + +import builtins + +from mlir.dialects import pto + +from ._common import ( + VF_IMPL_1D_NO_POST_UPDATE, + VF_IMPL_1D_POST_UPDATE, + VF_IMPL_2D_NO_POST_UPDATE, + VF_IMPL_2D_POST_UPDATE, + VF_IMPL_DEFAULT, + alloc_tile_buffer, + check_tbinop_operands, + const_expr, + const_i64, + dtype_byte_width, + mask_for_chunk, + normalize_vf_impl_kind, + ptr, + raw, + range_constexpr, + resolve_lanes, + s, + store_view, + load_view, + vreg_type, +) + + +def tadd( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vadd", + impl=impl, + ) + + +def tsub( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vsub", + impl=impl, + ) + + +def tmul( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vmul", + impl=impl, + ) + + +def tdiv( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vdiv", + impl=impl, + ) + + +def tor_( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vor", + impl=impl, + ) + + +def tmov(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + rows, cols = check_tbinop_operands( + src_view, + src_view, + out_view, + dtype=dtype, + shape=shape, + context="TMOV", + ) + lanes = resolve_lanes(dtype, lanes) + element_count = rows * cols + buf_bytes = element_count * dtype_byte_width(dtype) + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + ptr_type = ptr(dtype, space="VEC") + src_ptr = pto.castptr(ptr_type, src_addr) + out_ptr = pto.castptr(ptr_type, out_addr) + vector_type = vreg_type(lanes, dtype) + + for offset in range_constexpr(0, element_count, lanes): + active = builtins.min(lanes, element_count - offset) + mask = mask_for_chunk(dtype, active) + index = s.const(offset) + src_vec = pto.vlds(vector_type, src_ptr, raw(index)) + pto.vsts(src_vec, out_ptr, raw(index), mask) + + store_view(out_tile, out_view) + return out_view + + 
+def _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + *, + dtype, + shape, + lanes, + base_addr, + micro_op_name, + impl, +): + rows, cols = check_tbinop_operands( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=shape, + context=micro_op_name.upper().replace("V", "T", 1), + ) + lanes = resolve_lanes(dtype, lanes) + element_count = rows * cols + buf_bytes = element_count * dtype_byte_width(dtype) + lhs_addr = const_i64(base_addr) + rhs_addr = const_i64(base_addr + buf_bytes) + out_addr = const_i64(base_addr + buf_bytes * 2) + + lhs_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=lhs_addr) + rhs_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=rhs_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(lhs_view, lhs_tile) + load_view(rhs_view, rhs_tile) + + ptr_type = ptr(dtype, space="VEC") + vector_type = vreg_type(lanes, dtype) + lhs_ptr = pto.castptr(ptr_type, lhs_addr) + rhs_ptr = pto.castptr(ptr_type, rhs_addr) + out_ptr = pto.castptr(ptr_type, out_addr) + micro_op = getattr(pto, micro_op_name) + + impl_kind = normalize_vf_impl_kind(impl) + is_contiguous = rows == 1 or cols == element_count + if const_expr(impl_kind == VF_IMPL_DEFAULT): + impl_kind = ( + VF_IMPL_1D_POST_UPDATE if is_contiguous else VF_IMPL_2D_NO_POST_UPDATE + ) + + if const_expr(impl_kind == VF_IMPL_1D_NO_POST_UPDATE): + _binary_1d_no_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + dtype=dtype, + lanes=lanes, + element_count=element_count, + vector_type=vector_type, + micro_op=micro_op, + ) + elif const_expr(impl_kind == VF_IMPL_1D_POST_UPDATE): + _binary_1d_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + ptr_type=ptr_type, + dtype=dtype, + lanes=lanes, + element_count=element_count, + vector_type=vector_type, + micro_op=micro_op, + ) + elif const_expr(impl_kind == VF_IMPL_2D_NO_POST_UPDATE): + _binary_2d_no_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + dtype=dtype, + rows=rows, + cols=cols, + lanes=lanes, + vector_type=vector_type, + micro_op=micro_op, + ) + elif const_expr(impl_kind == VF_IMPL_2D_POST_UPDATE): + _binary_2d_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + ptr_type=ptr_type, + dtype=dtype, + rows=rows, + cols=cols, + lanes=lanes, + vector_type=vector_type, + micro_op=micro_op, + ) + else: + raise ValueError(f"Unexpected normalized VF impl kind '{impl_kind}'.") + + store_view(out_tile, out_view) + return out_view + + +def _binary_1d_no_post_update( + lhs_ptr, rhs_ptr, out_ptr, *, dtype, lanes, element_count, vector_type, micro_op +): + for offset in range_constexpr(0, element_count, lanes): + active = builtins.min(lanes, element_count - offset) + mask = mask_for_chunk(dtype, active) + index = s.const(offset) + lhs_vec = pto.vlds(vector_type, lhs_ptr, raw(index)) + rhs_vec = pto.vlds(vector_type, rhs_ptr, raw(index)) + out_vec = micro_op(vector_type, lhs_vec, rhs_vec, mask) + pto.vsts(out_vec, out_ptr, raw(index), mask) + + +def _binary_1d_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + *, + ptr_type, + dtype, + lanes, + element_count, + vector_type, + micro_op, +): + lhs_cursor = lhs_ptr + rhs_cursor = rhs_ptr + out_cursor = out_ptr + lane_step = s.const(lanes) + for offset in range_constexpr(0, element_count, lanes): + active = builtins.min(lanes, element_count - offset) + mask = mask_for_chunk(dtype, active) + lhs_vec, lhs_cursor = pto.vlds_post( + vector_type, ptr_type, lhs_cursor, raw(lane_step) + ) + rhs_vec, rhs_cursor = pto.vlds_post( + vector_type, ptr_type, rhs_cursor, raw(lane_step) + ) + out_vec = micro_op(vector_type, 
lhs_vec, rhs_vec, mask) + out_cursor = pto.vsts_post(ptr_type, out_vec, out_cursor, raw(lane_step), mask) + + +def _binary_2d_no_post_update( + lhs_ptr, rhs_ptr, out_ptr, *, dtype, rows, cols, lanes, vector_type, micro_op +): + for row in range_constexpr(rows): + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + index = s.const(row * cols + col) + lhs_vec = pto.vlds(vector_type, lhs_ptr, raw(index)) + rhs_vec = pto.vlds(vector_type, rhs_ptr, raw(index)) + out_vec = micro_op(vector_type, lhs_vec, rhs_vec, mask) + pto.vsts(out_vec, out_ptr, raw(index), mask) + + +def _binary_2d_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + *, + ptr_type, + dtype, + rows, + cols, + lanes, + vector_type, + micro_op, +): + lane_step = s.const(lanes) + for row in range_constexpr(rows): + row_base = row * cols + lhs_row = pto.addptr(lhs_ptr, raw(s.const(row_base))) + rhs_row = pto.addptr(rhs_ptr, raw(s.const(row_base))) + out_row = pto.addptr(out_ptr, raw(s.const(row_base))) + lhs_cursor = lhs_row + rhs_cursor = rhs_row + out_cursor = out_row + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + lhs_vec, lhs_cursor = pto.vlds_post( + vector_type, ptr_type, lhs_cursor, raw(lane_step) + ) + rhs_vec, rhs_cursor = pto.vlds_post( + vector_type, ptr_type, rhs_cursor, raw(lane_step) + ) + out_vec = micro_op(vector_type, lhs_vec, rhs_vec, mask) + out_cursor = pto.vsts_post( + ptr_type, out_vec, out_cursor, raw(lane_step), mask + ) + + +__all__ = [ + "VF_IMPL_DEFAULT", + "VF_IMPL_1D_NO_POST_UPDATE", + "VF_IMPL_1D_POST_UPDATE", + "VF_IMPL_2D_NO_POST_UPDATE", + "VF_IMPL_2D_POST_UPDATE", + "tadd", + "tdiv", + "tmov", + "tmul", + "tor_", + "tsub", +] diff --git a/ptodsl/lib/a5/texpand.py b/ptodsl/lib/a5/texpand.py new file mode 100644 index 00000000..f2c841dd --- /dev/null +++ b/ptodsl/lib/a5/texpand.py @@ -0,0 +1,214 @@ +"""Implement tile broadcast/expand ops with PTO vector micro instructions.""" + +import builtins + +from mlir.dialects import pto + +from ._common import ( + alloc_tile_buffer, + check_col_expand_operands, + check_row_expand_operands, + const_i64, + dtype_byte_width, + mask_for_chunk, + micro_lane_count, + ptr, + raw, + range_constexpr, + s, + store_view, + load_view, + vreg_type, + require_view_dtype, + require_view_shape, +) + + +def tcol_expand(src_view, out_view, *, dtype, shape, base_addr=0): + rows, cols = check_col_expand_operands( + src_view, out_view, dtype=dtype, shape=shape, context="TCOLEXPAND" + ) + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + buf_bytes = rows * cols * dtype_byte_width(dtype) + + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + src_tile = alloc_tile_buffer( + dtype, shape, space="VEC", valid_shape=[1, cols], addr=src_addr + ) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + col_offset = s.const(col) + vec = pto.vlds(vector_type, src_ptr, raw(col_offset)) + for row in range_constexpr(rows): + dst_offset = s.const(row * cols + col) + pto.vsts(vec, out_ptr, raw(dst_offset), mask) + + store_view(out_tile, out_view) + return out_view + + +def trow_expand(src_view, 
out_view, *, dtype, shape, base_addr=0): + rows, cols = check_row_expand_operands( + src_view, out_view, dtype=dtype, shape=shape, context="TROWEXPAND" + ) + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + buf_bytes = rows * cols * dtype_byte_width(dtype) + + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + src_tile = alloc_tile_buffer( + dtype, shape, space="VEC", valid_shape=[rows, 1], addr=src_addr + ) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + + for row in range_constexpr(rows): + scalar_offset = s.const(row * cols) + align = pto.vldas(pto.AlignType.get(), src_ptr, raw(scalar_offset)) + scalar_vec, _, _ = pto.vldus( + vector_type, + pto.AlignType.get(), + ptr(dtype, space="VEC"), + src_ptr, + raw(scalar_offset), + align, + ) + broadcast = pto.vdup(vector_type, scalar_vec, position="POS_LOWEST") + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + dst_offset = s.const(row * cols + col) + pto.vsts(broadcast, out_ptr, raw(dst_offset), mask) + + store_view(out_tile, out_view) + return out_view + + +def trow_expand_sub(base_view, expand_view, out_view, *, dtype, shape, base_addr=0): + return _trow_expand_binary( + base_view, + expand_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vsub", + ) + + +def trow_expand_mul(base_view, expand_view, out_view, *, dtype, shape, base_addr=0): + return _trow_expand_binary( + base_view, + expand_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vmul", + ) + + +def trow_expand_div(base_view, expand_view, out_view, *, dtype, shape, base_addr=0): + return _trow_expand_binary( + base_view, + expand_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vdiv", + ) + + +def _trow_expand_binary( + base_view, + expand_view, + out_view, + *, + dtype, + shape, + base_addr, + micro_op_name, +): + rows, cols = check_row_expand_operands( + expand_view, + out_view, + dtype=dtype, + shape=shape, + context=f"TROWEXPAND_{micro_op_name[1:].upper()}", + ) + require_view_shape( + base_view, + [rows, cols], + message=f"Fix: TROWEXPAND_{micro_op_name[1:].upper()} base input valid shape mismatch with output tile dst shape.", + ) + require_view_dtype( + base_view, + dtype, + message=f"Fix: TROWEXPAND_{micro_op_name[1:].upper()} input data type must be consistent with the output data type.", + ) + + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + buf_bytes = rows * cols * dtype_byte_width(dtype) + base_addr_value = const_i64(base_addr) + expand_addr_value = const_i64(base_addr + buf_bytes) + out_addr_value = const_i64(base_addr + buf_bytes * 2) + + base_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=base_addr_value) + expand_tile = alloc_tile_buffer( + dtype, shape, space="VEC", valid_shape=[rows, 1], addr=expand_addr_value + ) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr_value) + load_view(base_view, base_tile) + load_view(expand_view, expand_tile) + + base_ptr = pto.castptr(ptr(dtype, space="VEC"), base_addr_value) + expand_ptr = pto.castptr(ptr(dtype, space="VEC"), expand_addr_value) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr_value) + micro_op = getattr(pto, micro_op_name) + + 
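# Per row: vldas/vldus read the expand-column scalar through the aligned
+    # load pair, vdup broadcasts it across the vector lanes, and the binary
+    # micro op is applied lane block by lane block against the base row.
+    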
for row in range_constexpr(rows): + scalar_offset = s.const(row * cols) + align = pto.vldas(pto.AlignType.get(), expand_ptr, raw(scalar_offset)) + scalar_vec, _, _ = pto.vldus( + vector_type, + pto.AlignType.get(), + ptr(dtype, space="VEC"), + expand_ptr, + raw(scalar_offset), + align, + ) + broadcast = pto.vdup(vector_type, scalar_vec, position="POS_LOWEST") + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + row_offset = s.const(row * cols + col) + base_vec = pto.vlds(vector_type, base_ptr, raw(row_offset)) + out_vec = micro_op(vector_type, base_vec, broadcast, mask) + pto.vsts(out_vec, out_ptr, raw(row_offset), mask) + + store_view(out_tile, out_view) + return out_view + + +__all__ = [ + "tcol_expand", + "trow_expand", + "trow_expand_div", + "trow_expand_mul", + "trow_expand_sub", +] diff --git a/ptodsl/lib/a5/tile_micro_coverage.py b/ptodsl/lib/a5/tile_micro_coverage.py index c3b98677..9e249eeb 100644 --- a/ptodsl/lib/a5/tile_micro_coverage.py +++ b/ptodsl/lib/a5/tile_micro_coverage.py @@ -3,72 +3,72 @@ TILE_MICRO_COVERAGE = { "mov": { "status": "implemented", - "helper": "mov_micro", + "helper": "tmov", "note": "UB stage + vlds/vsts copy loop.", }, "add": { "status": "implemented", - "helper": "add_micro", + "helper": "tadd", "note": "UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering.", }, "sub": { "status": "implemented", - "helper": "sub_micro", + "helper": "tsub", "note": "UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering.", }, "div": { "status": "implemented", - "helper": "div_micro", + "helper": "tdiv", "note": "UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering.", }, "mul": { "status": "implemented", - "helper": "mul_micro", + "helper": "tmul", "note": "UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering.", }, "or_": { "status": "implemented", - "helper": "or_micro", + "helper": "tor_", "note": "UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering.", }, "gather": { "status": "partial", - "helper": "gather_micro", + "helper": "tgather", "note": "Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs unsupported vsqz-style micro support.", }, "exp": { "status": "implemented", - "helper": "exp_micro", + "helper": "texp", "note": "UB stage + vlds/vexp/vsts loop.", }, "log": { "status": "implemented", - "helper": "log_micro", + "helper": "tlog", "note": "UB stage + vlds/vln/vsts loop.", }, "relu": { "status": "implemented", - "helper": "relu_micro", + "helper": "trelu", "note": "UB stage + vlds/vrelu/vsts loop.", }, "abs": { "status": "implemented", - "helper": "abs_micro", + "helper": "tabs", "note": "UB stage + vlds/vabs/vsts loop.", }, "sqrt": { "status": "implemented", - "helper": "sqrt_micro", + "helper": "tsqrt", "note": "UB stage + vlds/vsqrt/vsts loop.", }, "rsqrt": { "status": "implemented", - "helper": "rsqrt_micro", - "note": "UB stage + vsqrt/vrec micro sequence.", + "helper": "trsqrt", + "note": "UB stage + vsqrt/vrec sequence.", }, "reciprocal": { "status": "implemented", - "helper": "reciprocal_micro", + "helper": "trecip", "note": "UB stage + vlds/vrec/vsts loop.", }, "matmul": { @@ -93,67 +93,77 @@ }, "row_sum": { "status": "implemented", - "helper": "row_sum_micro", + "helper": "trow_sum", "note": "Static-shape row reduction via vcadd + point-store.", }, "row_min": { "status": "implemented", - "helper": "row_min_micro", + "helper": "trow_min", 
"note": "Static-shape row reduction via vcmin + point-store.", }, "row_max": { "status": "implemented", - "helper": "row_max_micro", + "helper": "trow_max", "note": "Static-shape row reduction via vcmax + point-store.", }, + "row_prod": { + "status": "blocked", + "helper": None, + "note": "No row-product micro lowering is wired yet.", + }, "row_expand": { "status": "implemented", - "helper": "row_expand_micro", + "helper": "trow_expand", "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vsts.", }, "row_expand_sub": { "status": "implemented", - "helper": "row_expand_sub_micro", + "helper": "trow_expand_sub", "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts.", }, "row_expand_div": { "status": "implemented", - "helper": "row_expand_div_micro", + "helper": "trow_expand_div", "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts.", }, "row_expand_mul": { "status": "implemented", - "helper": "row_expand_mul_micro", + "helper": "trow_expand_mul", "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts.", }, "col_sum": { "status": "implemented", - "helper": "col_sum_micro", + "helper": "tcol_sum", "note": "Static-shape TColReduceOps-style column reduction via vadd.", }, "col_min": { "status": "implemented", - "helper": "col_min_micro", + "helper": "tcol_min", "note": "Static-shape TColReduceOps-style column reduction via vmin.", }, "col_max": { "status": "implemented", - "helper": "col_max_micro", + "helper": "tcol_max", "note": "Static-shape TColReduceOps-style column reduction via vmax.", }, + "col_prod": { + "status": "blocked", + "helper": None, + "note": "No column-product micro lowering is wired yet.", + }, "col_expand": { "status": "implemented", - "helper": "col_expand_micro", + "helper": "tcol_expand", "note": "Static-shape canonical broadcast via vlds/vsts replication.", }, "mrgsort": { "status": "implemented", - "helper": "mrgsort_micro", + "helper": "tmrgsort", "note": "Single-list row-major merge sort via vmrgsort4.", }, "sort32": { "status": "implemented", - "helper": "sort32_micro", + "helper": "tsort32", "note": "Static-shape block sort via vbitsort.", }, "subset": { @@ -184,15 +194,13 @@ def coverage_markdown(): f"- Blocked: `{counts.get('blocked', 0)}`", f"- Not applicable: `{counts.get('not_applicable', 0)}`", "", - "| tile op | status | helper | note |", - "| --- | --- | --- | --- |", + "| tile op | helper | note |", + "| --- | --- | --- |", ] for name in tile.__all__: entry = TILE_MICRO_COVERAGE[name] helper = entry["helper"] or "-" - lines.append( - f"| `{name}` | `{entry['status']}` | `{helper}` | {entry['note']} |" - ) + lines.append(f"| `{name}` | `{helper}` | {entry['note']} |") return "\n".join(lines) + "\n" diff --git a/ptodsl/lib/a5/treduce.py b/ptodsl/lib/a5/treduce.py new file mode 100644 index 00000000..3419e0d6 --- /dev/null +++ b/ptodsl/lib/a5/treduce.py @@ -0,0 +1,283 @@ +"""Implement tile reduce ops with PTO vector micro instructions.""" + +import builtins + +from mlir.dialects import pto + +from ._common import ( + VF_IMPL_1D_NO_POST_UPDATE, + VF_IMPL_1D_POST_UPDATE, + VF_IMPL_2D_NO_POST_UPDATE, + VF_IMPL_2D_POST_UPDATE, + VF_IMPL_DEFAULT, + alloc_tile_buffer, + check_col_reduce_operands, + check_row_reduce_operands, + const_expr, + const_float, + const_i64, + dtype_byte_width, + full_mask, + mask_for_chunk, + micro_lane_count, + normalize_vf_impl_kind, + onept_dist, + ptr, + raw, + range_constexpr, + s, + store_view, + tail_mask, + load_view, + vreg_type, +) + + +def trow_sum(src_view, 
out_view, *, dtype, shape, base_addr=0): + return _trow_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + reduce_op_name="vcadd", + combine_op_name="vadd", + init_value=0.0, + ) + + +def trow_max(src_view, out_view, *, dtype, shape, base_addr=0): + return _trow_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + reduce_op_name="vcmax", + combine_op_name="vmax", + init_value=float("-inf"), + ) + + +def trow_min(src_view, out_view, *, dtype, shape, base_addr=0): + return _trow_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + reduce_op_name="vcmin", + combine_op_name="vmin", + init_value=float("inf"), + ) + + +def tcol_sum(src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT): + return _tcol_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vadd", + impl=impl, + ) + + +def tcol_max(src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT): + return _tcol_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vmax", + impl=impl, + ) + + +def tcol_min(src_view, out_view, *, dtype, shape, base_addr=0, impl=VF_IMPL_DEFAULT): + return _tcol_reduce( + src_view, + out_view, + dtype=dtype, + shape=shape, + base_addr=base_addr, + micro_op_name="vmin", + impl=impl, + ) + + +def _trow_reduce( + src_view, + out_view, + *, + dtype, + shape, + base_addr, + reduce_op_name, + combine_op_name, + init_value, +): + rows, cols = check_row_reduce_operands( + src_view, out_view, dtype=dtype, shape=shape, context="TROWREDUCE" + ) + width = dtype_byte_width(dtype) + if width not in {2, 4}: + raise ValueError(f"{reduce_op_name} currently supports only float16/float32.") + + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + buf_bytes = rows * cols * width + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + out_tile = alloc_tile_buffer( + dtype, shape, space="VEC", valid_shape=[rows, 1], addr=out_addr + ) + load_view(src_view, src_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + reduce_op = getattr(pto, reduce_op_name) + combine_op = getattr(pto, combine_op_name) + row_mask = full_mask(dtype) + point_mask = tail_mask(dtype, 1) + init_scalar = const_float(dtype, init_value) + + for row in range_constexpr(rows): + accum = pto.vbr(vector_type, init_scalar) + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = mask_for_chunk(dtype, active) + offset = s.const(row * cols + col) + vec = pto.vlds(vector_type, src_ptr, raw(offset)) + reduced = reduce_op(vector_type, vec, mask) + accum = combine_op(vector_type, accum, reduced, row_mask) + out_offset = s.const(row * cols) + pto.vsts(accum, out_ptr, raw(out_offset), point_mask, dist=onept_dist(dtype)) + + store_view(out_tile, out_view) + return out_view + + +def _tcol_reduce( + src_view, + out_view, + *, + dtype, + shape, + base_addr, + micro_op_name, + impl, +): + rows, cols = check_col_reduce_operands( + src_view, out_view, dtype=dtype, shape=shape, context="TCOLREDUCE" + ) + lanes = micro_lane_count(dtype) + buf_bytes = rows * cols * dtype_byte_width(dtype) + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", 
addr=src_addr)
+    out_tile = alloc_tile_buffer(
+        dtype, [1, cols], space="VEC", valid_shape=[1, cols], addr=out_addr
+    )
+    load_view(src_view, src_tile)
+
+    ptr_type = ptr(dtype, space="VEC")
+    vector_type = vreg_type(lanes, dtype)
+    src_ptr = pto.castptr(ptr_type, src_addr)
+    out_ptr = pto.castptr(ptr_type, out_addr)
+    reduce_op = getattr(pto, micro_op_name)
+    impl_kind = normalize_vf_impl_kind(impl)
+    if const_expr(impl_kind == VF_IMPL_DEFAULT):
+        impl_kind = VF_IMPL_1D_POST_UPDATE
+
+    if const_expr(impl_kind in {VF_IMPL_1D_NO_POST_UPDATE, VF_IMPL_2D_NO_POST_UPDATE}):
+        _tcol_reduce_no_post_update(
+            src_ptr,
+            out_ptr,
+            dtype=dtype,
+            rows=rows,
+            cols=cols,
+            lanes=lanes,
+            vector_type=vector_type,
+            reduce_op=reduce_op,
+        )
+    elif const_expr(impl_kind in {VF_IMPL_1D_POST_UPDATE, VF_IMPL_2D_POST_UPDATE}):
+        _tcol_reduce_post_update(
+            src_ptr,
+            out_ptr,
+            ptr_type=ptr_type,
+            dtype=dtype,
+            rows=rows,
+            cols=cols,
+            lanes=lanes,
+            vector_type=vector_type,
+            reduce_op=reduce_op,
+        )
+    else:
+        raise ValueError(f"Unexpected normalized VF impl kind '{impl_kind}'.")
+
+    store_view(out_tile, out_view)
+    return out_view
+
+
+def _tcol_reduce_no_post_update(
+    src_ptr, out_ptr, *, dtype, rows, cols, lanes, vector_type, reduce_op
+):
+    loop_pairs = (rows - 1) // 2
+    remain = (rows - 1) % 2
+    for col in range_constexpr(0, cols, lanes):
+        active = builtins.min(lanes, cols - col)
+        mask = mask_for_chunk(dtype, active)
+        accum = pto.vlds(vector_type, src_ptr, raw(s.const(col)))
+        for pair in range_constexpr(loop_pairs):
+            row0 = 2 * pair + 1
+            row1 = 2 * pair + 2
+            src0 = pto.vlds(vector_type, src_ptr, raw(s.const(col + row0 * cols)))
+            src1 = pto.vlds(vector_type, src_ptr, raw(s.const(col + row1 * cols)))
+            pair_sum = reduce_op(vector_type, src0, src1, mask)
+            accum = reduce_op(vector_type, accum, pair_sum, mask)
+        if const_expr(remain):
+            src_tail = pto.vlds(
+                vector_type, src_ptr, raw(s.const(col + (rows - 1) * cols))
+            )
+            accum = reduce_op(vector_type, accum, src_tail, mask)
+        pto.vsts(accum, out_ptr, raw(s.const(col)), mask)
+
+
+def _tcol_reduce_post_update(
+    src_ptr, out_ptr, *, ptr_type, dtype, rows, cols, lanes, vector_type, reduce_op
+):
+    lane_step = s.const(lanes)
+    for col in range_constexpr(0, cols, lanes):
+        active = builtins.min(lanes, cols - col)
+        mask = mask_for_chunk(dtype, active)
+        row0_ptr = pto.addptr(src_ptr, raw(s.const(col)))
+        accum, _ = pto.vlds_post(vector_type, ptr_type, row0_ptr, raw(s.const(cols)))
+        row_ptr = pto.addptr(row0_ptr, raw(s.const(cols)))
+        for _ in range_constexpr(rows - 1):
+            src_tail, row_ptr = pto.vlds_post(
+                vector_type, ptr_type, row_ptr, raw(s.const(cols))
+            )
+            accum = reduce_op(vector_type, accum, src_tail, mask)
+        out_cursor = pto.addptr(out_ptr, raw(s.const(col)))
+        pto.vsts_post(ptr_type, accum, out_cursor, raw(lane_step), mask)
+
+
+__all__ = [
+    "VF_IMPL_DEFAULT",
+    "VF_IMPL_1D_NO_POST_UPDATE",
+    "VF_IMPL_1D_POST_UPDATE",
+    "VF_IMPL_2D_NO_POST_UPDATE",
+    "VF_IMPL_2D_POST_UPDATE",
+    "tcol_max",
+    "tcol_min",
+    "tcol_sum",
+    "trow_max",
+    "trow_min",
+    "trow_sum",
+]
diff --git a/ptodsl/lib/a5/tsort.py b/ptodsl/lib/a5/tsort.py
new file mode 100644
index 00000000..72884ded
--- /dev/null
+++ b/ptodsl/lib/a5/tsort.py
@@ -0,0 +1,155 @@
+"""Implement gather/sort tile ops with PTO vector micro instructions."""
+
+import builtins
+
+from mlir.dialects import pto
+
+from ._common import (
+    alloc_tile_buffer,
+    check_gather_operands,
+    check_mrgsort_operands,
+    check_sort32_operands,
+    const_i64,
+    dtype_byte_width,
+    mask_for_chunk,
+    
micro_lane_count, + ptr, + raw, + range_constexpr, + s, + store_view, + uint32_type, + load_view, + vreg_type, +) + + +def tgather( + src_view, + indices_view, + out_view, + *, + dtype, + index_dtype=None, + shape, + base_addr=0, +): + index_dtype = uint32_type() if index_dtype is None else index_dtype + rows, cols = check_gather_operands( + src_view, + indices_view, + out_view, + dtype=dtype, + index_dtype=index_dtype, + shape=shape, + ) + src_bytes = rows * cols * dtype_byte_width(dtype) + idx_bytes = rows * cols * dtype_byte_width(index_dtype) + + src_addr = const_i64(base_addr) + idx_addr = const_i64(base_addr + src_bytes) + out_addr = const_i64(base_addr + src_bytes + idx_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + idx_tile = alloc_tile_buffer(index_dtype, shape, space="VEC", addr=idx_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + load_view(indices_view, idx_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + idx_ptr = pto.castptr(ptr(index_dtype, space="VEC"), idx_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + index_vector_type = vreg_type(micro_lane_count(index_dtype), index_dtype) + + for row in range_constexpr(rows): + row_base = row * cols + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + offset = s.const(row_base + col) + mask = mask_for_chunk(dtype, active) + idx_vec = pto.vlds(index_vector_type, idx_ptr, raw(offset)) + out_vec = pto.vgather2(vector_type, src_ptr, idx_vec, raw(s.const(active))) + pto.vsts(out_vec, out_ptr, raw(offset), mask) + + store_view(out_tile, out_view) + return out_view + + +def tmrgsort(src_view, out_view, *, dtype, shape, block_len, base_addr=0): + _, cols = check_mrgsort_operands( + src_view, out_view, dtype=dtype, shape=shape, block_len=block_len + ) + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + cols * dtype_byte_width(dtype)) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + ptr_type = ptr(dtype, space="VEC") + src_ptr = pto.castptr(ptr_type, src_addr) + out_ptr = pto.castptr(ptr_type, out_addr) + + src1_ptr = pto.addptr(src_ptr, raw(s.const(block_len))) + src2_ptr = pto.addptr(src_ptr, raw(s.const(block_len * 2))) + src3_ptr = pto.addptr(src_ptr, raw(s.const(block_len * 3))) + num_structures = (block_len * dtype_byte_width(dtype)) >> 3 + count_value = ( + num_structures + | (num_structures << 16) + | (num_structures << 32) + | (num_structures << 48) + ) + repeat_times = cols // (block_len * 4) + config_value = repeat_times | (0b1111 << 8) + + pto.vmrgsort4( + out_ptr, + src_ptr, + src1_ptr, + src2_ptr, + src3_ptr, + const_i64(count_value), + const_i64(config_value), + ) + store_view(out_tile, out_view) + return out_view + + +def tsort32(src_view, idx_view, out_view, *, dtype, shape, base_addr=0): + rows, cols, out_cols = check_sort32_operands( + src_view, idx_view, out_view, dtype=dtype, shape=shape + ) + src_bytes = rows * cols * dtype_byte_width(dtype) + idx_bytes = rows * cols * 4 + + src_addr = const_i64(base_addr) + idx_addr = const_i64(base_addr + src_bytes) + out_addr = const_i64(base_addr + src_bytes + idx_bytes) + + src_tile = alloc_tile_buffer(dtype, [rows, cols], space="VEC", addr=src_addr) + idx_tile = alloc_tile_buffer( + 
uint32_type(), [rows, cols], space="VEC", addr=idx_addr + ) + out_tile = alloc_tile_buffer(dtype, [rows, out_cols], space="VEC", addr=out_addr) + load_view(src_view, src_tile) + load_view(idx_view, idx_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + idx_ptr = pto.castptr(ptr(uint32_type(), space="VEC"), idx_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + repeat_times = s.const(cols // 32) + + for row in range_constexpr(rows): + src_row = pto.addptr(src_ptr, raw(s.const(row * cols))) + idx_row = pto.addptr(idx_ptr, raw(s.const(row * cols))) + out_row = pto.addptr(out_ptr, raw(s.const(row * out_cols))) + pto.vbitsort(out_row, src_row, idx_row, raw(repeat_times)) + + store_view(out_tile, out_view) + return out_view + + +__all__ = ["tgather", "tmrgsort", "tsort32"] diff --git a/ptodsl/lib/a5/tunary.py b/ptodsl/lib/a5/tunary.py new file mode 100644 index 00000000..3eee318a --- /dev/null +++ b/ptodsl/lib/a5/tunary.py @@ -0,0 +1,190 @@ +"""Implement tile unary ops with PTO vector micro instructions. + +This file demonstrates how to write tile-style helpers such as `pto.texp` +directly in terms of `pto.vlds`, a unary vector opcode, and `pto.vsts`. +""" + +import builtins + +from mlir.dialects import pto + +from ._common import ( + alloc_tile_buffer, + check_tbinop_operands, + const_i64, + dtype_byte_width, + mask_for_chunk, + ptr, + raw, + range_constexpr, + resolve_lanes, + s, + store_view, + load_view, + vreg_type, +) + + +def texp(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vexp", + context="TEXP", + ) + + +def tlog(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vln", + context="TLOG", + ) + + +def trelu(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vrelu", + context="TRELU", + ) + + +def tabs(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vabs", + context="TABS", + ) + + +def tsqrt(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vsqrt", + context="TSQRT", + ) + + +def trecip(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + return _unary_tile_vop( + src_view, + out_view, + dtype=dtype, + shape=shape, + lanes=lanes, + base_addr=base_addr, + micro_op_name="vrec", + context="TRECIP", + ) + + +def trsqrt(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0): + rows, cols = check_tbinop_operands( + src_view, + src_view, + out_view, + dtype=dtype, + shape=shape, + context="TRSQRT", + ) + lanes = resolve_lanes(dtype, lanes) + element_count = rows * cols + buf_bytes = element_count * dtype_byte_width(dtype) + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + ptr_type = ptr(dtype, 
space="VEC") + vector_type = vreg_type(lanes, dtype) + src_ptr = pto.castptr(ptr_type, src_addr) + out_ptr = pto.castptr(ptr_type, out_addr) + + for offset in range_constexpr(0, element_count, lanes): + active = builtins.min(lanes, element_count - offset) + mask = mask_for_chunk(dtype, active) + index = s.const(offset) + src_vec = pto.vlds(vector_type, src_ptr, raw(index)) + sqrt_vec = pto.vsqrt(vector_type, src_vec, mask) + out_vec = pto.vrec(vector_type, sqrt_vec, mask) + pto.vsts(out_vec, out_ptr, raw(index), mask) + + store_view(out_tile, out_view) + return out_view + + +def _unary_tile_vop( + src_view, + out_view, + *, + dtype, + shape, + lanes, + base_addr, + micro_op_name, + context, +): + rows, cols = check_tbinop_operands( + src_view, + src_view, + out_view, + dtype=dtype, + shape=shape, + context=context, + ) + lanes = resolve_lanes(dtype, lanes) + element_count = rows * cols + buf_bytes = element_count * dtype_byte_width(dtype) + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=src_addr) + out_tile = alloc_tile_buffer(dtype, shape, space="VEC", addr=out_addr) + load_view(src_view, src_tile) + + ptr_type = ptr(dtype, space="VEC") + vector_type = vreg_type(lanes, dtype) + src_ptr = pto.castptr(ptr_type, src_addr) + out_ptr = pto.castptr(ptr_type, out_addr) + micro_op = getattr(pto, micro_op_name) + + for offset in range_constexpr(0, element_count, lanes): + active = builtins.min(lanes, element_count - offset) + mask = mask_for_chunk(dtype, active) + index = s.const(offset) + src_vec = pto.vlds(vector_type, src_ptr, raw(index)) + out_vec = micro_op(vector_type, src_vec, mask) + pto.vsts(out_vec, out_ptr, raw(index), mask) + + store_view(out_tile, out_view) + return out_view + + +__all__ = ["tabs", "texp", "tlog", "trecip", "trelu", "trsqrt", "tsqrt"] diff --git a/tests/regression/test_a5_lib_regression.py b/tests/regression/test_a5_lib_regression.py index d9b23c7d..981c5eae 100644 --- a/tests/regression/test_a5_lib_regression.py +++ b/tests/regression/test_a5_lib_regression.py @@ -1,11 +1,52 @@ import pytest -from mlir.ir import IndexType +from mlir.ir import IndexType, IntegerType -from ptodsl import pto, to_ir_module +import ptodsl.language as pto +from ptodsl import to_ir_module from ptodsl.lib import a5 from scripts.generate_a5_pto import emit_kernels +def _index(value): + return pto.const(value) if isinstance(value, int) else value + + +def _row_major_strides(shape): + strides = [None] * len(shape) + stride = pto.const(1) + for index in range(len(shape) - 1, -1, -1): + strides[index] = stride + dim = shape[index] + stride = stride * _index(dim) + return strides + + +def _make_tensor(ptr_value, *, shape, dtype): + return pto.as_tensor( + pto.TensorType(rank=len(shape), dtype=dtype), + ptr=ptr_value, + shape=[_index(dim) for dim in shape], + strides=_row_major_strides(shape), + ) + + +def _slice_tensor(source, *, offsets, sizes, dtype): + return pto.slice_view( + pto.SubTensorType(shape=sizes, dtype=dtype), + source=source, + offsets=[_index(offset) for offset in offsets], + sizes=[_index(size) for size in sizes], + ) + + +def test_a5_split_modules_are_publicly_exposed(): + assert a5.tbinary.tadd is a5.tadd + assert a5.tunary.trsqrt is a5.trsqrt + assert a5.texpand.trow_expand is a5.trow_expand + assert a5.treduce.trow_sum is a5.trow_sum + assert a5.tsort.tgather is a5.tgather + + def test_a5_elementwise_add_kernel_emits_tile_flow(): text = str(a5.build_elementwise_add()) @@ -39,94 
+80,115 @@ def test_a5_templated_elementwise_add_specializes_constexpr_impl(): assert "pto.tadd" not in text -def test_a5_micro_vector_copy_emits_micro_ops(): - text = str(a5.build_micro_vector_copy()) +def test_a5_vector_copy_emits_vector_opcodes(): + text = str(a5.build_vector_copy()) - assert "func.func @a5_micro_vector_copy" in text + assert "func.func @a5_vector_copy" in text assert "pto.pset_b32" in text assert "pto.vlds" in text assert "pto.vsts" in text -def test_a5_col_expand_micro_emits_broadcast_micro_ops(): +def test_a5_tcol_expand_emits_broadcast_micro_ops(): def meta_data(): return { - "ptr_t": pto.ptr(pto.float32), + "ptr_t": pto.PtrType(pto.float32), "index_t": IndexType.get(), } @to_ir_module(meta_data=meta_data) - def a5_col_expand_micro(src: "ptr_t", dst: "ptr_t") -> None: - src_view = pto.make_tensor(src, shape=[1, 32], dtype=pto.float32) - dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32) + def a5_tcol_expand(src: "ptr_t", dst: "ptr_t") -> None: + src_view = _make_tensor(src, shape=[1, 32], dtype=pto.float32) + dst_view = _make_tensor(dst, shape=[32, 32], dtype=pto.float32) with pto.vector_section(): - a5.col_expand_micro( - src_view.slice([0, 0], [1, 32]), - dst_view.slice([0, 0], [32, 32]), + a5.tcol_expand( + _slice_tensor( + src_view, offsets=[0, 0], sizes=[1, 32], dtype=pto.float32 + ), + _slice_tensor( + dst_view, + offsets=[0, 0], + sizes=[32, 32], + dtype=pto.float32, + ), dtype=pto.float32, shape=[32, 32], ) - text = str(a5_col_expand_micro) + text = str(a5_tcol_expand) - assert "func.func @a5_col_expand_micro" in text + assert "func.func @a5_tcol_expand" in text assert "pto.vlds" in text assert "pto.vsts" in text assert "pto.tcolexpand" not in text -def test_a5_gather_micro_emits_indexed_gather_micro_ops(): +def test_a5_tgather_emits_indexed_gather_opcodes(): + def uint32(): + return IntegerType.get_unsigned(32) + def meta_data(): return { - "ptr_src": pto.ptr(pto.float32), - "ptr_idx": pto.ptr(pto.uint32), + "ptr_src": pto.PtrType(pto.float32), + "ptr_idx": pto.PtrType(uint32()), } @to_ir_module(meta_data=meta_data) - def a5_gather_micro(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None: - src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32) - idx_view = pto.make_tensor(idx, shape=[1, 64], dtype=pto.uint32) - dst_view = pto.make_tensor(dst, shape=[1, 64], dtype=pto.float32) + def a5_tgather(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None: + src_view = _make_tensor(src, shape=[1, 64], dtype=pto.float32) + idx_view = _make_tensor(idx, shape=[1, 64], dtype=uint32()) + dst_view = _make_tensor(dst, shape=[1, 64], dtype=pto.float32) with pto.vector_section(): - a5.gather_micro( - src_view.slice([0, 0], [1, 64]), - idx_view.slice([0, 0], [1, 64]), - dst_view.slice([0, 0], [1, 64]), + a5.tgather( + _slice_tensor( + src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32 + ), + _slice_tensor(idx_view, offsets=[0, 0], sizes=[1, 64], dtype=uint32()), + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32 + ), dtype=pto.float32, - index_dtype=pto.uint32, + index_dtype=uint32(), shape=[1, 64], ) - text = str(a5_gather_micro) + text = str(a5_tgather) - assert "func.func @a5_gather_micro" in text + assert "func.func @a5_tgather" in text assert "pto.vgather2" in text assert "pto.vsts" in text assert "pto.tgather" not in text -def test_a5_row_expand_micro_emits_broadcast_micro_ops(): +def test_a5_trow_expand_emits_broadcast_micro_ops(): def meta_data(): return { - "ptr_t": pto.ptr(pto.float32), + "ptr_t": 
pto.PtrType(pto.float32), "index_t": IndexType.get(), } @to_ir_module(meta_data=meta_data) - def a5_row_expand_micro(src: "ptr_t", dst: "ptr_t") -> None: - src_view = pto.make_tensor(src, shape=[32, 1], dtype=pto.float32) - dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32) + def a5_trow_expand(src: "ptr_t", dst: "ptr_t") -> None: + src_view = _make_tensor(src, shape=[32, 1], dtype=pto.float32) + dst_view = _make_tensor(dst, shape=[32, 32], dtype=pto.float32) with pto.vector_section(): - a5.row_expand_micro( - src_view.slice([0, 0], [32, 1]), - dst_view.slice([0, 0], [32, 32]), + a5.trow_expand( + _slice_tensor( + src_view, offsets=[0, 0], sizes=[32, 1], dtype=pto.float32 + ), + _slice_tensor( + dst_view, + offsets=[0, 0], + sizes=[32, 32], + dtype=pto.float32, + ), dtype=pto.float32, shape=[32, 32], ) - text = str(a5_row_expand_micro) + text = str(a5_trow_expand) - assert "func.func @a5_row_expand_micro" in text + assert "func.func @a5_trow_expand" in text assert "pto.vldas" in text assert "pto.vldus" in text assert "pto.vdup" in text @@ -134,30 +196,36 @@ def a5_row_expand_micro(src: "ptr_t", dst: "ptr_t") -> None: assert "pto.trowexpand" not in text -def test_a5_row_expand_mul_micro_emits_broadcast_compute_micro_ops(): +def test_a5_trow_expand_mul_emits_broadcast_compute_micro_ops(): def meta_data(): return { - "ptr_t": pto.ptr(pto.float32), + "ptr_t": pto.PtrType(pto.float32), "index_t": IndexType.get(), } @to_ir_module(meta_data=meta_data) - def a5_row_expand_mul_micro(base: "ptr_t", scale: "ptr_t", dst: "ptr_t") -> None: - base_view = pto.make_tensor(base, shape=[32, 32], dtype=pto.float32) - scale_view = pto.make_tensor(scale, shape=[32, 1], dtype=pto.float32) - dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32) + def a5_trow_expand_mul(base: "ptr_t", scale: "ptr_t", dst: "ptr_t") -> None: + base_view = _make_tensor(base, shape=[32, 32], dtype=pto.float32) + scale_view = _make_tensor(scale, shape=[32, 1], dtype=pto.float32) + dst_view = _make_tensor(dst, shape=[32, 32], dtype=pto.float32) with pto.vector_section(): - a5.row_expand_mul_micro( - base_view.slice([0, 0], [32, 32]), - scale_view.slice([0, 0], [32, 1]), - dst_view.slice([0, 0], [32, 32]), + a5.trow_expand_mul( + _slice_tensor( + base_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.float32 + ), + _slice_tensor( + scale_view, offsets=[0, 0], sizes=[32, 1], dtype=pto.float32 + ), + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.float32 + ), dtype=pto.float32, shape=[32, 32], ) - text = str(a5_row_expand_mul_micro) + text = str(a5_trow_expand_mul) - assert "func.func @a5_row_expand_mul_micro" in text + assert "func.func @a5_trow_expand_mul" in text assert "pto.vldas" in text assert "pto.vldus" in text assert "pto.vdup" in text @@ -166,27 +234,31 @@ def a5_row_expand_mul_micro(base: "ptr_t", scale: "ptr_t", dst: "ptr_t") -> None assert "pto.trowexpandmul" not in text -def test_a5_rsqrt_micro_emits_vsqrt_then_vrec(): +def test_a5_trsqrt_emits_vsqrt_then_vrec(): def meta_data(): return { - "ptr_t": pto.ptr(pto.float32), + "ptr_t": pto.PtrType(pto.float32), } @to_ir_module(meta_data=meta_data) - def a5_rsqrt_micro(src: "ptr_t", dst: "ptr_t") -> None: - src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32) - dst_view = pto.make_tensor(dst, shape=[1, 64], dtype=pto.float32) + def a5_trsqrt(src: "ptr_t", dst: "ptr_t") -> None: + src_view = _make_tensor(src, shape=[1, 64], dtype=pto.float32) + dst_view = _make_tensor(dst, shape=[1, 64], dtype=pto.float32) with 
pto.vector_section():
-            a5.rsqrt_micro(
-                src_view.slice([0, 0], [1, 64]),
-                dst_view.slice([0, 0], [1, 64]),
+            a5.trsqrt(
+                _slice_tensor(
+                    src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32
+                ),
+                _slice_tensor(
+                    dst_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32
+                ),
                 dtype=pto.float32,
                 shape=[1, 64],
             )
 
-    text = str(a5_rsqrt_micro)
+    text = str(a5_trsqrt)
 
-    assert "func.func @a5_rsqrt_micro" in text
+    assert "func.func @a5_trsqrt" in text
     assert "pto.vsqrt" in text
     assert "pto.vrec" in text
     assert "pto.trsqrt" not in text
@@ -195,35 +267,39 @@ def a5_rsqrt_micro(src: "ptr_t", dst: "ptr_t") -> None:
 @pytest.mark.parametrize(
     ("helper_name", "reduce_op", "combine_op", "tile_op"),
     [
-        ("row_sum_micro", "pto.vcadd", "pto.vadd", "pto.trowsum"),
-        ("row_max_micro", "pto.vcmax", "pto.vmax", "pto.trowmax"),
-        ("row_min_micro", "pto.vcmin", "pto.vmin", "pto.trowmin"),
+        ("trow_sum", "pto.vcadd", "pto.vadd", "pto.trowsum"),
+        ("trow_max", "pto.vcmax", "pto.vmax", "pto.trowmax"),
+        ("trow_min", "pto.vcmin", "pto.vmin", "pto.trowmin"),
     ],
 )
-def test_a5_row_reduce_micro_emits_reduction_micro_ops(
+def test_a5_trow_reduce_emits_reduction_micro_ops(
     helper_name, reduce_op, combine_op, tile_op
 ):
     def meta_data():
         return {
-            "ptr_t": pto.ptr(pto.float32),
+            "ptr_t": pto.PtrType(pto.float32),
             "index_t": IndexType.get(),
         }
 
     helper = getattr(a5, helper_name)
 
     @to_ir_module(meta_data=meta_data)
-    def a5_row_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
-        src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
-        dst_view = pto.make_tensor(dst, shape=[32, 1], dtype=pto.float32)
+    def a5_trow_reduce(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = _make_tensor(src, shape=[32, 32], dtype=pto.float32)
+        dst_view = _make_tensor(dst, shape=[32, 1], dtype=pto.float32)
         with pto.vector_section():
             helper(
-                src_view.slice([0, 0], [32, 32]),
-                dst_view.slice([0, 0], [32, 1]),
+                _slice_tensor(
+                    src_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.float32
+                ),
+                _slice_tensor(
+                    dst_view, offsets=[0, 0], sizes=[32, 1], dtype=pto.float32
+                ),
                 dtype=pto.float32,
                 shape=[32, 32],
             )
 
-    text = str(a5_row_reduce_micro)
+    text = str(a5_trow_reduce)
 
     assert reduce_op in text
     assert combine_op in text
@@ -234,35 +310,37 @@ def a5_row_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
 @pytest.mark.parametrize(
     ("helper_name", "reduce_op", "tile_op", "impl"),
     [
-        ("col_sum_micro", "pto.vadd", "pto.tcolsum", a5.VF_IMPL_1D_POST_UPDATE),
-        ("col_max_micro", "pto.vmax", "pto.tcolmax", a5.VF_IMPL_1D_NO_POST_UPDATE),
-        ("col_min_micro", "pto.vmin", "pto.tcolmin", a5.VF_IMPL_1D_POST_UPDATE),
+        ("tcol_sum", "pto.vadd", "pto.tcolsum", a5.VF_IMPL_1D_POST_UPDATE),
+        ("tcol_max", "pto.vmax", "pto.tcolmax", a5.VF_IMPL_1D_NO_POST_UPDATE),
+        ("tcol_min", "pto.vmin", "pto.tcolmin", a5.VF_IMPL_1D_POST_UPDATE),
     ],
 )
-def test_a5_col_reduce_micro_emits_template_lowering(
-    helper_name, reduce_op, tile_op, impl
-):
+def test_a5_tcol_reduce_emits_template_lowering(helper_name, reduce_op, tile_op, impl):
     def meta_data():
         return {
-            "ptr_t": pto.ptr(pto.float32),
+            "ptr_t": pto.PtrType(pto.float32),
         }
 
     helper = getattr(a5, helper_name)
 
     @to_ir_module(meta_data=meta_data)
-    def a5_col_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
-        src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
-        dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.float32)
+    def a5_tcol_reduce(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = _make_tensor(src, shape=[32, 32], dtype=pto.float32)
+        dst_view = _make_tensor(dst, shape=[1, 32], dtype=pto.float32)
         with pto.vector_section():
             helper(
-                src_view.slice([0, 0], [32, 32]),
-                dst_view.slice([0, 0], [1, 32]),
+                _slice_tensor(
+                    src_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.float32
+                ),
+                _slice_tensor(
+                    dst_view, offsets=[0, 0], sizes=[1, 32], dtype=pto.float32
+                ),
                 dtype=pto.float32,
                 shape=[32, 32],
                 impl=impl,
             )
 
-    text = str(a5_col_reduce_micro)
+    text = str(a5_tcol_reduce)
 
     assert reduce_op in text
     assert tile_op not in text
@@ -271,54 +349,65 @@ def a5_col_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
     assert "pto.vsts_post" in text
 
 
-def test_a5_sort32_micro_emits_vbitsort():
+def test_a5_tsort32_emits_vbitsort():
+    def uint32():
+        return IntegerType.get_unsigned(32)
+
     def meta_data():
         return {
-            "ptr_src": pto.ptr(pto.float32),
-            "ptr_idx": pto.ptr(pto.uint32),
+            "ptr_src": pto.PtrType(pto.float32),
+            "ptr_idx": pto.PtrType(uint32()),
         }
 
     @to_ir_module(meta_data=meta_data)
-    def a5_sort32_micro(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None:
-        src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32)
-        idx_view = pto.make_tensor(idx, shape=[1, 64], dtype=pto.uint32)
-        dst_view = pto.make_tensor(dst, shape=[1, 128], dtype=pto.float32)
+    def a5_tsort32(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None:
+        src_view = _make_tensor(src, shape=[1, 64], dtype=pto.float32)
+        idx_view = _make_tensor(idx, shape=[1, 64], dtype=uint32())
+        dst_view = _make_tensor(dst, shape=[1, 128], dtype=pto.float32)
         with pto.vector_section():
-            a5.sort32_micro(
-                src_view.slice([0, 0], [1, 64]),
-                idx_view.slice([0, 0], [1, 64]),
-                dst_view.slice([0, 0], [1, 128]),
+            a5.tsort32(
+                _slice_tensor(
+                    src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32
+                ),
+                _slice_tensor(idx_view, offsets=[0, 0], sizes=[1, 64], dtype=uint32()),
+                _slice_tensor(
+                    dst_view, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32
+                ),
                 dtype=pto.float32,
                 shape=[1, 64],
             )
 
-    text = str(a5_sort32_micro)
+    text = str(a5_tsort32)
 
-    assert "func.func @a5_sort32_micro" in text
+    assert "func.func @a5_tsort32" in text
     assert "pto.vbitsort" in text
     assert "pto.tsort32" not in text
 
 
-def test_a5_mrgsort_micro_emits_vmrgsort4():
+def test_a5_tmrgsort_emits_vmrgsort4():
     def meta_data():
-        return {"ptr_t": pto.ptr(pto.float32)}
+        return {"ptr_t": pto.PtrType(pto.float32)}
 
     @to_ir_module(meta_data=meta_data)
-    def a5_mrgsort_micro(src: "ptr_t", dst: "ptr_t") -> None:
-        src_view = pto.make_tensor(src, shape=[1, 256], dtype=pto.float32)
-        dst_view = pto.make_tensor(dst, shape=[1, 256], dtype=pto.float32)
+    def a5_tmrgsort(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = _make_tensor(src, shape=[1, 256], dtype=pto.float32)
+        dst_view = _make_tensor(dst, shape=[1, 256], dtype=pto.float32)
         with pto.vector_section():
-            a5.mrgsort_micro(
-                src_view.slice([0, 0], [1, 256]),
-                dst_view.slice([0, 0], [1, 256]),
+            a5.tmrgsort(
+                _slice_tensor(
+                    src_view, offsets=[0, 0], sizes=[1, 256], dtype=pto.float32
+                ),
+                _slice_tensor(
+                    dst_view, offsets=[0, 0], sizes=[1, 256], dtype=pto.float32
+                ),
                 dtype=pto.float32,
                 shape=[1, 256],
                 block_len=64,
            )
 
-    text = str(a5_mrgsort_micro)
+    text = str(a5_tmrgsort)
 
-    assert "func.func @a5_mrgsort_micro" in text
+    assert "func.func @a5_tmrgsort" in text
     assert "pto.vmrgsort4" in text
     assert "pto.tmrgsort" not in text
@@ -330,7 +419,7 @@ def test_a5_generation_script_emits_pto_files(tmp_path):
     assert generated_names == [
         "a5_cube_matmul.pto",
         "a5_elementwise_add.pto",
-        "a5_micro_vector_copy.pto",
+        "a5_vector_copy.pto",
     ]
 
     for path in generated:
@@ -338,9 +427,9 @@ def test_a5_generation_script_emits_pto_files(tmp_path):
         assert "func.func @" in text
 
 
-def test_a5_add_micro_rejects_view_dtype_mismatch():
+def test_a5_tadd_rejects_view_dtype_mismatch():
     def meta_data():
-        return {"ptr_t": pto.ptr(pto.float16)}
+        return {"ptr_t": pto.PtrType(pto.float16)}
 
     with pytest.raises(
         ValueError, match="TADD input tile src0, src1 and dst tile data type mismatch"
@@ -348,22 +437,28 @@ def meta_data():
 
         @to_ir_module(meta_data=meta_data)
         def invalid_add(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t") -> None:
-            lhs = pto.make_tensor(src0, shape=[32, 32], dtype=pto.float16)
-            rhs = pto.make_tensor(src1, shape=[32, 32], dtype=pto.float16)
-            out = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float16)
+            lhs = _make_tensor(src0, shape=[32, 32], dtype=pto.float16)
+            rhs = _make_tensor(src1, shape=[32, 32], dtype=pto.float16)
+            out = _make_tensor(dst, shape=[32, 32], dtype=pto.float16)
             with pto.vector_section():
-                a5.add_micro(
-                    lhs.slice([0, 0], [32, 32]),
-                    rhs.slice([0, 0], [32, 32]),
-                    out.slice([0, 0], [32, 32]),
+                a5.tadd(
+                    _slice_tensor(
+                        lhs, offsets=[0, 0], sizes=[32, 32], dtype=pto.float16
+                    ),
+                    _slice_tensor(
+                        rhs, offsets=[0, 0], sizes=[32, 32], dtype=pto.float16
+                    ),
+                    _slice_tensor(
+                        out, offsets=[0, 0], sizes=[32, 32], dtype=pto.float16
+                    ),
                     dtype=pto.float32,
                     shape=[32, 32],
                 )
 
 
-def test_a5_row_expand_micro_rejects_non_column_source():
+def test_a5_trow_expand_rejects_non_column_source():
     def meta_data():
-        return {"ptr_t": pto.ptr(pto.float32)}
+        return {"ptr_t": pto.PtrType(pto.float32)}
 
     with pytest.raises(
         ValueError, match="TROWEXPAND source valid shape must be \\[rows, 1\\]"
@@ -371,50 +466,65 @@ def meta_data():
 
         @to_ir_module(meta_data=meta_data)
        def invalid_row_expand(src: "ptr_t", dst: "ptr_t") -> None:
-            src_view = pto.make_tensor(src, shape=[1, 32], dtype=pto.float32)
-            dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32)
+            src_view = _make_tensor(src, shape=[1, 32], dtype=pto.float32)
+            dst_view = _make_tensor(dst, shape=[32, 32], dtype=pto.float32)
             with pto.vector_section():
-                a5.row_expand_micro(
-                    src_view.slice([0, 0], [1, 32]),
-                    dst_view.slice([0, 0], [32, 32]),
+                a5.trow_expand(
+                    _slice_tensor(
+                        src_view, offsets=[0, 0], sizes=[1, 32], dtype=pto.float32
+                    ),
+                    _slice_tensor(
+                        dst_view,
+                        offsets=[0, 0],
+                        sizes=[32, 32],
+                        dtype=pto.float32,
+                    ),
                    dtype=pto.float32,
                     shape=[32, 32],
                 )
 
 
-def test_a5_row_reduce_micro_rejects_non_single_column_output():
+def test_a5_trow_sum_rejects_non_single_column_output():
     def meta_data():
-        return {"ptr_t": pto.ptr(pto.float32)}
+        return {"ptr_t": pto.PtrType(pto.float32)}
 
     with pytest.raises(ValueError, match="use a single-column output tile"):
 
         @to_ir_module(meta_data=meta_data)
         def invalid_row_reduce(src: "ptr_t", dst: "ptr_t") -> None:
-            src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
-            dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.float32)
+            src_view = _make_tensor(src, shape=[32, 32], dtype=pto.float32)
+            dst_view = _make_tensor(dst, shape=[1, 32], dtype=pto.float32)
             with pto.vector_section():
-                a5.row_sum_micro(
-                    src_view.slice([0, 0], [32, 32]),
-                    dst_view.slice([0, 0], [1, 32]),
+                a5.trow_sum(
+                    _slice_tensor(
+                        src_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.float32
+                    ),
+                    _slice_tensor(
+                        dst_view, offsets=[0, 0], sizes=[1, 32], dtype=pto.float32
+                    ),
                     dtype=pto.float32,
                     shape=[32, 32],
                 )
 
 
-def test_a5_col_reduce_micro_rejects_unsupported_dtype():
+def test_a5_tcol_sum_rejects_unsupported_dtype():
     def meta_data():
-        return {"ptr_t": pto.ptr(pto.bool)}
+        return {"ptr_t": pto.PtrType(pto.bool)}
 
     with pytest.raises(ValueError, match="TCOLREDUCE input data type is not supported"):
 
         @to_ir_module(meta_data=meta_data)
         def invalid_col_reduce(src: "ptr_t", dst: "ptr_t") -> None:
-            src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.bool)
-            dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.bool)
+            src_view = _make_tensor(src, shape=[32, 32], dtype=pto.bool)
+            dst_view = _make_tensor(dst, shape=[1, 32], dtype=pto.bool)
             with pto.vector_section():
-                a5.col_sum_micro(
-                    src_view.slice([0, 0], [32, 32]),
-                    dst_view.slice([0, 0], [1, 32]),
+                a5.tcol_sum(
+                    _slice_tensor(
+                        src_view, offsets=[0, 0], sizes=[32, 32], dtype=pto.bool
+                    ),
+                    _slice_tensor(
+                        dst_view, offsets=[0, 0], sizes=[1, 32], dtype=pto.bool
+                    ),
                     dtype=pto.bool,
                     shape=[32, 32],
                 )