Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions ptodsl/lib/a5/A5_HEADER_COVERAGE.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# A5 Header Coverage

- Total A5 headers tracked: `116`
- Implemented: `49`
- Implemented: `57`
- Partial: `2`
- Native only: `11`
- Pending: `35`
- Blocked/meta: `19`
- Pending: `29`
- Blocked/meta: `17`

| header | status | helper | note |
| --- | --- | --- | --- |
Expand Down Expand Up @@ -35,7 +35,7 @@
| `TColExpandSub` | `implemented` | `tcol_expand_sub` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TColMax` | `implemented` | `tcol_max` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TColMin` | `implemented` | `tcol_min` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TColProd` | `blocked` | `-` | No column-product micro lowering is wired yet. |
| `TColProd` | `implemented` | `tcol_prod` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TColReduceIdx` | `pending` | `-` | Indexed column reduction is not implemented yet. |
| `TColReduceOps` | `implemented` | `treduce._tcol_reduce` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TColSum` | `implemented` | `tcol_sum` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
Expand All @@ -44,13 +44,13 @@
| `TDeQuant` | `pending` | `-` | Quantization/dequantization path is not implemented yet. |
| `TDiv` | `implemented` | `tdiv` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TDivS` | `implemented` | `tdivs` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TExpandS` | `pending` | `-` | Scalar expand helper is not implemented yet. |
| `TExpandS` | `implemented` | `texpands` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TExtract` | `native` | `extract` | Still forwarded to the native PTO tile builder. |
| `TFMod` | `pending` | `-` | Fmod lowering is not implemented yet. |
| `TFModS` | `pending` | `-` | Scalar fmod lowering is not implemented yet. |
| `TFillPad` | `pending` | `-` | Pad/fill helper is not implemented yet. |
| `TGather` | `partial` | `tgather` | Indexed gather is implemented via vgather2; mask-pattern gather still needs missing vsqz-style micro support. |
| `TGatherB` | `pending` | `-` | GatherB lowering is not implemented yet, even though vgatherb exists in the micro surface. |
| `TGatherB` | `implemented` | `tgatherb` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TGetScaleAddr` | `pending` | `-` | Scale-address helper is not represented in the PTODSL A5 layer. |
| `THistogram` | `pending` | `-` | Histogram lowering is not implemented yet. |
| `TImg2col` | `blocked` | `-` | Hardware layout/state programming path, not a straightforward vector-micro rewrite target. |
Expand All @@ -76,7 +76,7 @@
| `TPartMul` | `pending` | `-` | Part-op lowering is not implemented yet. |
| `TPop` | `blocked` | `-` | Runtime buffer stack/state helper, not a direct vector tile rewrite target. |
| `TPrefetch` | `blocked` | `-` | Prefetch/runtime helper, not a direct vector tile rewrite target. |
| `TPrelu` | `pending` | `-` | PReLU lowering is not implemented yet. |
| `TPrelu` | `implemented` | `tprelu` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TPrint` | `native` | `native print` | Still forwarded to the native PTO tile builder. |
| `TPush` | `blocked` | `-` | Runtime buffer stack/state helper, not a direct vector tile rewrite target. |
| `TQuant` | `pending` | `-` | Quantization path is not implemented yet. |
Expand All @@ -93,13 +93,13 @@
| `TRowExpandMin` | `implemented` | `trow_expand_min` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TRowExpandMul` | `implemented` | `trow_expand_mul` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TRowExpandSub` | `implemented` | `trow_expand_sub` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TRowProd` | `blocked` | `-` | No row-product micro lowering is wired yet. |
| `TRowProd` | `implemented` | `trow_prod` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TRowReduce` | `implemented` | `treduce._trow_reduce` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TRowReduceIdx` | `pending` | `-` | Indexed row reduction is not implemented yet. |
| `TRsqrt` | `implemented` | `trsqrt` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TScatter` | `pending` | `-` | Scatter lowering is not implemented yet, even though vscatter exists in the micro surface. |
| `TSel` | `pending` | `-` | Packed-mask select lowering is not implemented yet. |
| `TSels` | `pending` | `-` | Scalar/mask select lowering is not implemented yet. |
| `TScatter` | `implemented` | `tscatter` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TSel` | `implemented` | `tsel` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TSels` | `implemented` | `tsels` | Rewritten with PTO micro instructions in the PTODSL A5 layer. |
| `TSetFmatrix` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. |
| `TSetImg2colPadding` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. |
| `TSetImg2colRpt` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. |
Expand Down
8 changes: 4 additions & 4 deletions ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Tile Micro Coverage

- Total public tile ops: `34`
- Implemented: `26`
- Implemented: `28`
- Partial: `1`
- Pending: `0`
- Blocked: `6`
- Blocked: `4`
- Not applicable: `1`

| tile op | helper | note |
Expand All @@ -30,15 +30,15 @@
| `row_sum` | `trow_sum` | Static-shape row reduction via vcadd + point-store. |
| `row_min` | `trow_min` | Static-shape row reduction via vcmin + point-store. |
| `row_max` | `trow_max` | Static-shape row reduction via vcmax + point-store. |
| `row_prod` | `-` | No row-product micro lowering is wired yet. |
| `row_prod` | `trow_prod` | Static-shape row reduction via vmul + vintlv tree reduction + point-store. |
| `row_expand` | `trow_expand` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. |
| `row_expand_sub` | `trow_expand_sub` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. |
| `row_expand_div` | `trow_expand_div` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. |
| `row_expand_mul` | `trow_expand_mul` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. |
| `col_sum` | `tcol_sum` | Static-shape TColReduceOps-style column reduction via vadd. |
| `col_min` | `tcol_min` | Static-shape TColReduceOps-style column reduction via vmin. |
| `col_max` | `tcol_max` | Static-shape TColReduceOps-style column reduction via vmax. |
| `col_prod` | `-` | No column-product micro lowering is wired yet. |
| `col_prod` | `tcol_prod` | Static-shape TColReduceOps-style column reduction via vmul. |
| `col_expand` | `tcol_expand` | Static-shape canonical broadcast via vlds/vsts replication. |
| `mrgsort` | `tmrgsort` | Single-list row-major merge sort via vmrgsort4. |
| `sort32` | `tsort32` | Static-shape block sort via vbitsort. |
Expand Down
15 changes: 14 additions & 1 deletion ptodsl/lib/a5/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from . import native, ops, tbinary, texpand, treduce, tscalar, tsort, tunary
from . import (
native,
ops,
tbinary,
texpand,
tindex,
tselect,
treduce,
tscalar,
tsort,
tunary,
)
from .a5_header_coverage import A5_HEADER_COVERAGE, a5_header_coverage_markdown
from .kernels import (
HIVM_LLVM_KERNELS,
Expand Down Expand Up @@ -41,6 +52,8 @@
"native",
"tbinary",
"texpand",
"tindex",
"tselect",
"treduce",
"tscalar",
"tsort",
Expand Down
158 changes: 152 additions & 6 deletions ptodsl/lib/a5/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def uint32_type():
return IntegerType.get_unsigned(32)


def uint16_type():
return IntegerType.get_unsigned(16)


def const_i64(value):
i64 = IntegerType.get_signless(64)
return arith.ConstantOp(i64, IntegerAttr.get(i64, value)).result
Expand Down Expand Up @@ -505,11 +509,14 @@ def check_tscalar_operands(src_view, out_view, *, dtype, shape, context, allowed
return rows, cols


def check_tbinop_operands(lhs_view, rhs_view, out_view, *, dtype, shape, context):
def check_tbinop_operands(
lhs_view, rhs_view, out_view, *, dtype, shape, context, allowed=None
):
rows, cols = require_static_matrix_shape(shape, context=context)
require_supported_dtype(
dtype,
allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"},
allowed=allowed
or {"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"},
message=f"Fix: {context} has invalid data type.",
)
for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")):
Expand Down Expand Up @@ -586,11 +593,13 @@ def check_col_expand_operands(src_view, out_view, *, dtype, shape, context):
return rows, cols


def check_row_reduce_operands(src_view, out_view, *, dtype, shape, context):
def check_row_reduce_operands(
src_view, out_view, *, dtype, shape, context, allowed=None
):
rows, cols = require_static_matrix_shape(shape, context=context)
require_supported_dtype(
dtype,
allowed={"f32", "f16"},
allowed=allowed or {"f32", "f16"},
message=f"Fix: {context} input data type is not supported.",
)
require_view_shape(
Expand All @@ -616,11 +625,13 @@ def check_row_reduce_operands(src_view, out_view, *, dtype, shape, context):
return rows, cols


def check_col_reduce_operands(src_view, out_view, *, dtype, shape, context):
def check_col_reduce_operands(
src_view, out_view, *, dtype, shape, context, allowed=None
):
rows, cols = require_static_matrix_shape(shape, context=context)
require_supported_dtype(
dtype,
allowed={"f32", "f16"},
allowed=allowed or {"f32", "f16"},
message=f"Fix: {context} input data type is not supported.",
)
require_view_shape(
Expand Down Expand Up @@ -678,6 +689,136 @@ def check_gather_operands(
return rows, cols


def check_gatherb_operands(
src_view, indices_view, out_view, *, dtype, index_dtype, shape
):
rows, cols = require_static_matrix_shape(shape, context="TGATHERB")
require_supported_dtype(
dtype,
allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"},
message="Fix: TGATHERB source data type is not supported.",
)
require_supported_dtype(
index_dtype,
allowed={"u32"},
message="Fix: TGATHERB index data type must be uint32.",
)
require_view_dtype(
src_view,
dtype,
message="Fix: TGATHERB source data type mismatch.",
)
for view, label, view_dtype in (
(indices_view, "indices", index_dtype),
(out_view, "dst", dtype),
):
require_view_shape(
view,
[rows, cols],
message=f"Fix: TGATHERB {label} valid shape mismatch.",
)
require_view_dtype(
view,
view_dtype,
message=f"Fix: TGATHERB {label} data type mismatch.",
)
return rows, cols


def check_scatter_operands(
src_view, indices_view, out_view, *, dtype, index_dtype, shape
):
rows, cols = require_static_matrix_shape(shape, context="TSCATTER")
dtype_token_value = require_supported_dtype(
dtype,
allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"},
message="Fix: TSCATTER source data type is not supported.",
)
index_token_value = require_supported_dtype(
index_dtype,
allowed={"u32", "i32", "u16", "i16"},
message="Fix: TSCATTER index data type is not supported.",
)
dtype_width = dtype_byte_width(dtype)
index_width = dtype_byte_width(index_dtype)
if not (
(dtype_width == 4 and index_width == 4)
or (dtype_width == 2 and index_width == 2)
or (dtype_width == 1 and index_width == 2)
):
raise ValueError("Fix: TSCATTER invalid data type of idx.")
for view, label, view_dtype in (
(src_view, "src", dtype),
(indices_view, "idx", index_dtype),
(out_view, "dst", dtype),
):
require_view_shape(
view,
[rows, cols],
message=f"Fix: TSCATTER input tile {label} valid shape mismatch.",
)
require_view_dtype(
view,
view_dtype,
message=f"Fix: TSCATTER {label} data type mismatch.",
)
return rows, cols, dtype_token_value, index_token_value


def check_tsel_operands(mask_view, lhs_view, rhs_view, out_view, *, dtype, shape):
rows, cols = require_static_matrix_shape(shape, context="TSEL")
require_supported_dtype(
dtype,
allowed={"f32"},
message="Fix: TSEL only support 32-bit float data tiles.",
)
require_view_shape(
mask_view,
[rows, cols],
message="Fix: TSEL requires matching source, mask, and destination valid region.",
)
mask_token = extract_tensor_dtype_token(mask_view)
if mask_token not in {"i8", "u8"}:
raise ValueError("Fix: TSEL currently requires i8 or u8 mask tiles.")
for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")):
require_view_shape(
view,
[rows, cols],
message=f"Fix: TSEL input tile {label} valid shape mismatch.",
)
require_view_dtype(
view,
dtype,
message="Fix: TSEL only support same data type between dst, src0, and src1.",
)
return rows, cols


def check_tsels_operands(mask_view, src_view, out_view, *, dtype, shape):
rows, cols = require_static_matrix_shape(shape, context="TSELS")
require_supported_dtype(
dtype,
allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"},
message="TSELS: Invalid data type",
)
for view, label, view_dtype in (
(mask_view, "mask", dtype),
(src_view, "src", dtype),
(out_view, "dst", dtype),
):
require_view_shape(
view,
[rows, cols],
message=f"Fix: TSELS {label} valid shape mismatch.",
)
require_view_dtype(
view,
view_dtype,
message="TileType of dst and src must be the same.",
)
return rows, cols


def check_mrgsort_operands(src_view, out_view, *, dtype, shape, block_len):
rows, cols = require_static_matrix_shape(shape, context="TMRGSORT")
if rows != 1:
Expand Down Expand Up @@ -747,12 +888,16 @@ def check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape):
"check_col_expand_operands",
"check_col_reduce_operands",
"check_gather_operands",
"check_gatherb_operands",
"check_mrgsort_operands",
"check_row_expand_operands",
"check_row_reduce_operands",
"check_scatter_operands",
"check_sort32_operands",
"check_tscalar_operands",
"check_tbinop_operands",
"check_tsel_operands",
"check_tsels_operands",
"const_expr",
"const_float",
"const_scalar",
Expand Down Expand Up @@ -783,5 +928,6 @@ def check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape):
"store_view",
"tail_mask",
"uint32_type",
"uint16_type",
"vreg_type",
]
Loading
Loading