Merged
4 changes: 3 additions & 1 deletion ptodsl/api/scalar.py
@@ -3,7 +3,7 @@


 def _unwrap(value):
-    if isinstance(value, Value):
+    if isinstance(value, Value) or hasattr(value, "raw"):
         return value.raw
     return value

@@ -79,6 +79,8 @@ def __getattr__(self, item):
 def wrap_value(value):
     if isinstance(value, Value):
         return value
+    if hasattr(value, "raw"):
+        return Value(value.raw)
     return Value(value)


19 changes: 14 additions & 5 deletions ptodsl/language.py
@@ -8,11 +8,18 @@


 def _unwrap(value):
-    if isinstance(value, Value):
+    if isinstance(value, Value) or hasattr(value, "raw"):
         return value.raw
     return value


+def _unwrap_index(value):
+    value = _unwrap(value)
+    if isinstance(value, int):
+        return arith.ConstantOp(IndexType.get(), value).result
+    return value
+
+
 class Value:
     # TODO: generalize to more comprehensive wrappers like https://github.com/makslevental/mlir-python-extras/blob/0.0.8.2/mlir/extras/dialects/ext/arith.py
     def __init__(self, raw):
@@ -83,6 +90,8 @@ def __getattr__(self, item):
 def wrap_value(value):
     if isinstance(value, Value):
         return value
+    if hasattr(value, "raw"):
+        return Value(value.raw)
     return Value(value)


@@ -322,16 +331,16 @@ def index_cast(value, index_type=IndexType):


 def as_tensor(tensor_type, *, ptr, shape, strides):
-    shape_vals = [_unwrap(v) for v in shape]
-    stride_vals = [_unwrap(v) for v in strides]
+    shape_vals = [_unwrap_index(v) for v in shape]
+    stride_vals = [_unwrap_index(v) for v in strides]
     return pto.MakeTensorViewOp(
         tensor_type, _unwrap(ptr), shape_vals, stride_vals
     ).result


 def slice_view(subtensor_type, *, source, offsets, sizes):
-    offset_vals = [_unwrap(v) for v in offsets]
-    size_vals = [_unwrap(v) for v in sizes]
+    offset_vals = [_unwrap_index(v) for v in offsets]
+    size_vals = [_unwrap_index(v) for v in sizes]
     return pto.PartitionViewOp(
         subtensor_type, source, offsets=offset_vals, sizes=size_vals
     ).result
24 changes: 17 additions & 7 deletions ptodsl/lib/a5/README.md
@@ -1,18 +1,28 @@
 # A5 Library Layer

-This directory contains a first PTODSL library-style translation layer for the
-`pto-isa/include/pto/npu/a5` surface.
+This directory contains a PTODSL library-style translation layer for the
+`pto-isa/include/pto/npu/a5` surface, organized around the PTO tile opcode that
+each file is re-expressing with PTO micro instructions.

-The scope of this pass is:
+The scope of this layout is:

-- Pythonic wrappers over PTO tile ops and selected micro instructions
-- A5-flavored compatibility aliases such as `TLoad`, `TAdd`, `TMatmul`, and `TStore`
-- Translated builder kernels that emit `.pto` through PTODSL
+- Small, readable files that show how a tile helper is written from PTO micro
+  opcodes such as `pto.vlds`, `pto.vadd`, and `pto.vsts`
+- A5-flavored aliases such as `TLoad`, `TAdd`, `TMatmul`, and `TStore`
+- Example builder kernels that emit `.pto` through PTODSL
 - A checked-in generation flow for reproducible `.pto` artifacts

 Entry points:

-- [`ops.py`](./ops.py): reusable A5-style helpers built on PTODSL and PTO dialect ops
+- [`tbinary.py`](./tbinary.py): tile binary helpers such as `tadd`, `tsub`, `tmul`,
+  `tdiv`, and `tor_`, written with PTO vector micro ops
+- [`tunary.py`](./tunary.py): tile unary helpers such as `texp`, `tlog`, `trelu`,
+  `tsqrt`, `trsqrt`, and `trecip`
+- [`texpand.py`](./texpand.py): row and column broadcast helpers
+- [`treduce.py`](./treduce.py): row and column reduction helpers
+- [`tsort.py`](./tsort.py): gather and sort helpers
+- [`native.py`](./native.py): helpers that still map directly to tile/cube ops
+- [`ops.py`](./ops.py): the public A5 surface that re-exports the split helpers
 - [`kernels.py`](./kernels.py): translated example kernels
 - [`generated`](./generated): emitted `.pto` artifacts from `scripts/generate_a5_pto.py`
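
The README describes `ops.py` as a public surface that re-exports helpers from the split modules and also offers A5-flavored aliases. The diff does not show how `ops.py` does this, so the following is only an illustrative sketch of one way such a facade could be assembled. The in-memory modules, the placeholder helper bodies, and the `TAdd = tadd` alias mapping are all assumptions made for the example:

```python
import types

# Hypothetical stand-ins for the split helper modules named in the README.
tbinary = types.ModuleType("tbinary")
tbinary.tadd = lambda a, b: ("tadd", a, b)  # placeholder body, not the real lowering
tbinary.__all__ = ["tadd"]

tunary = types.ModuleType("tunary")
tunary.texp = lambda a: ("texp", a)  # placeholder body, not the real lowering
tunary.__all__ = ["texp"]

# A facade module that collects every exported helper in one namespace.
ops = types.ModuleType("ops")
ops.__all__ = []
for module in (tbinary, tunary):
    for name in module.__all__:
        setattr(ops, name, getattr(module, name))
        ops.__all__.append(name)

# An A5-flavored alias layered on top (assumed mapping, for illustration only).
ops.TAdd = ops.tadd
ops.__all__.append("TAdd")

result = ops.TAdd("x", "y")
```

In a real package the same effect is usually obtained with `from .tbinary import *` style imports inside `ops.py`; the loop above just makes the mechanics visible in a single runnable file.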

74 changes: 38 additions & 36 deletions ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md
@@ -1,43 +1,45 @@
 # Tile Micro Coverage

-- Total public tile ops: `32`
+- Total public tile ops: `34`
 - Implemented: `26`
 - Partial: `1`
 - Pending: `0`
-- Blocked: `4`
+- Blocked: `6`
 - Not applicable: `1`

-| tile op | status | helper | note |
-| --- | --- | --- | --- |
-| `mov` | `implemented` | `mov_micro` | UB stage + vlds/vsts copy loop. |
-| `add` | `implemented` | `add_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering. |
-| `sub` | `implemented` | `sub_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering. |
-| `div` | `implemented` | `div_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering. |
-| `mul` | `implemented` | `mul_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering. |
-| `or_` | `implemented` | `or_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering. |
-| `gather` | `partial` | `gather_micro` | Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs unsupported vsqz-style micro support. |
-| `exp` | `implemented` | `exp_micro` | UB stage + vlds/vexp/vsts loop. |
-| `log` | `implemented` | `log_micro` | UB stage + vlds/vln/vsts loop. |
-| `relu` | `implemented` | `relu_micro` | UB stage + vlds/vrelu/vsts loop. |
-| `abs` | `implemented` | `abs_micro` | UB stage + vlds/vabs/vsts loop. |
-| `sqrt` | `implemented` | `sqrt_micro` | UB stage + vlds/vsqrt/vsts loop. |
-| `rsqrt` | `implemented` | `rsqrt_micro` | UB stage + vsqrt/vrec micro sequence. |
-| `reciprocal` | `implemented` | `reciprocal_micro` | UB stage + vlds/vrec/vsts loop. |
-| `matmul` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
-| `matmul_bias` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
-| `matmul_acc` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
-| `extract` | `blocked` | `-` | Layout/L0 extraction op, not a vector-micro compute rewrite. |
-| `row_sum` | `implemented` | `row_sum_micro` | Static-shape row reduction via vcadd + point-store. |
-| `row_min` | `implemented` | `row_min_micro` | Static-shape row reduction via vcmin + point-store. |
-| `row_max` | `implemented` | `row_max_micro` | Static-shape row reduction via vcmax + point-store. |
-| `row_expand` | `implemented` | `row_expand_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. |
-| `row_expand_sub` | `implemented` | `row_expand_sub_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. |
-| `row_expand_div` | `implemented` | `row_expand_div_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. |
-| `row_expand_mul` | `implemented` | `row_expand_mul_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. |
-| `col_sum` | `implemented` | `col_sum_micro` | Static-shape TColReduceOps-style column reduction via vadd. |
-| `col_min` | `implemented` | `col_min_micro` | Static-shape TColReduceOps-style column reduction via vmin. |
-| `col_max` | `implemented` | `col_max_micro` | Static-shape TColReduceOps-style column reduction via vmax. |
-| `col_expand` | `implemented` | `col_expand_micro` | Static-shape canonical broadcast via vlds/vsts replication. |
-| `mrgsort` | `implemented` | `mrgsort_micro` | Single-list row-major merge sort via vmrgsort4. |
-| `sort32` | `implemented` | `sort32_micro` | Static-shape block sort via vbitsort. |
-| `subset` | `not_applicable` | `-` | View helper only, not a tile compute op. |
+| tile op | helper | note |
+| --- | --- | --- |
+| `mov` | `tmov` | UB stage + vlds/vsts copy loop. |
+| `add` | `tadd` | UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering. |
+| `sub` | `tsub` | UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering. |
+| `div` | `tdiv` | UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering. |
+| `mul` | `tmul` | UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering. |
+| `or_` | `tor_` | UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering. |
+| `gather` | `tgather` | Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs unsupported vsqz-style micro support. |
+| `exp` | `texp` | UB stage + vlds/vexp/vsts loop. |
+| `log` | `tlog` | UB stage + vlds/vln/vsts loop. |
+| `relu` | `trelu` | UB stage + vlds/vrelu/vsts loop. |
+| `abs` | `tabs` | UB stage + vlds/vabs/vsts loop. |
+| `sqrt` | `tsqrt` | UB stage + vlds/vsqrt/vsts loop. |
+| `rsqrt` | `trsqrt` | UB stage + vsqrt/vrec sequence. |
+| `reciprocal` | `trecip` | UB stage + vlds/vrec/vsts loop. |
+| `matmul` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
+| `matmul_bias` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
+| `matmul_acc` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. |
+| `extract` | `-` | Layout/L0 extraction op, not a vector-micro compute rewrite. |
+| `row_sum` | `trow_sum` | Static-shape row reduction via vcadd + point-store. |
+| `row_min` | `trow_min` | Static-shape row reduction via vcmin + point-store. |
+| `row_max` | `trow_max` | Static-shape row reduction via vcmax + point-store. |
+| `row_prod` | `-` | No row-product micro lowering is wired yet. |
+| `row_expand` | `trow_expand` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. |
+| `row_expand_sub` | `trow_expand_sub` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. |
+| `row_expand_div` | `trow_expand_div` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. |
+| `row_expand_mul` | `trow_expand_mul` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. |
+| `col_sum` | `tcol_sum` | Static-shape TColReduceOps-style column reduction via vadd. |
+| `col_min` | `tcol_min` | Static-shape TColReduceOps-style column reduction via vmin. |
+| `col_max` | `tcol_max` | Static-shape TColReduceOps-style column reduction via vmax. |
+| `col_prod` | `-` | No column-product micro lowering is wired yet. |
+| `col_expand` | `tcol_expand` | Static-shape canonical broadcast via vlds/vsts replication. |
+| `mrgsort` | `tmrgsort` | Single-list row-major merge sort via vmrgsort4. |
+| `sort32` | `tsort32` | Static-shape block sort via vbitsort. |
+| `subset` | `-` | View helper only, not a tile compute op. |
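
The summary bullets at the top of the checklist should add up to the declared total both before and after this change. A small sanity check of that arithmetic, using only the counts shown in the diff; the dictionary layout and variable names are illustrative and are not taken from `tile_micro_coverage`:

```python
# Status counts as listed in the checklist header, before and after the change.
before = {"implemented": 26, "partial": 1, "pending": 0, "blocked": 4, "not_applicable": 1}
after = {"implemented": 26, "partial": 1, "pending": 0, "blocked": 6, "not_applicable": 1}

declared_before = 32
declared_after = 34

total_before = sum(before.values())
total_after = sum(after.values())

# The two added ops (`row_prod`, `col_prod`) account for the change in `blocked`.
added_blocked = after["blocked"] - before["blocked"]
```

Both totals match the declared values, and the increase of two in `blocked` lines up with the two new rows whose helper column is `-`.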
12 changes: 9 additions & 3 deletions ptodsl/lib/a5/__init__.py
@@ -1,11 +1,11 @@
from . import ops
from . import native, ops, tbinary, texpand, treduce, tsort, tunary
from .kernels import (
KERNEL_BUILDERS,
build_cube_matmul,
build_elementwise_add,
build_micro_vector_copy,
build_mxfp8_matmul,
build_templated_elementwise_add,
build_vector_copy,
)
from .ops import *
from .tile_micro_coverage import (
@@ -19,9 +19,15 @@
"TILE_MICRO_COVERAGE",
"build_cube_matmul",
"build_elementwise_add",
"build_micro_vector_copy",
"build_mxfp8_matmul",
"build_templated_elementwise_add",
"build_vector_copy",
"coverage_markdown",
"coverage_summary",
"native",
"tbinary",
"texpand",
"treduce",
"tsort",
"tunary",
]