From e99a0945956e03ac2c93ef4b98ce07db2a6a9337 Mon Sep 17 00:00:00 2001 From: RuoyuZhou Date: Wed, 1 Apr 2026 10:30:19 +0800 Subject: [PATCH 1/2] Extend A5 PTODSL micro coverage --- ptodsl/lib/a5/A5_HEADER_COVERAGE.md | 22 +- ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md | 8 +- ptodsl/lib/a5/__init__.py | 4 +- ptodsl/lib/a5/_common.py | 157 +++++++- ptodsl/lib/a5/a5_header_coverage.py | 16 +- .../tile_ops/TILE_OP_GENERATION_INDEX.md | 4 +- ptodsl/lib/a5/generated/tile_ops/abs.pto | 1 + ptodsl/lib/a5/generated/tile_ops/add.pto | 1 + .../lib/a5/generated/tile_ops/col_expand.pto | 1 + ptodsl/lib/a5/generated/tile_ops/col_max.pto | 1 + ptodsl/lib/a5/generated/tile_ops/col_min.pto | 1 + ptodsl/lib/a5/generated/tile_ops/col_prod.pto | 94 +++++ ptodsl/lib/a5/generated/tile_ops/col_sum.pto | 1 + ptodsl/lib/a5/generated/tile_ops/div.pto | 1 + ptodsl/lib/a5/generated/tile_ops/exp.pto | 1 + ptodsl/lib/a5/generated/tile_ops/gather.pto | 1 + ptodsl/lib/a5/generated/tile_ops/log.pto | 1 + ptodsl/lib/a5/generated/tile_ops/mov.pto | 1 + ptodsl/lib/a5/generated/tile_ops/mrgsort.pto | 1 + ptodsl/lib/a5/generated/tile_ops/mul.pto | 1 + ptodsl/lib/a5/generated/tile_ops/or_.pto | 1 + .../lib/a5/generated/tile_ops/reciprocal.pto | 1 + ptodsl/lib/a5/generated/tile_ops/relu.pto | 1 + .../lib/a5/generated/tile_ops/row_expand.pto | 1 + .../a5/generated/tile_ops/row_expand_div.pto | 1 + .../a5/generated/tile_ops/row_expand_mul.pto | 1 + .../a5/generated/tile_ops/row_expand_sub.pto | 1 + ptodsl/lib/a5/generated/tile_ops/row_max.pto | 1 + ptodsl/lib/a5/generated/tile_ops/row_min.pto | 1 + ptodsl/lib/a5/generated/tile_ops/row_prod.pto | 207 ++++++++++ ptodsl/lib/a5/generated/tile_ops/row_sum.pto | 1 + ptodsl/lib/a5/generated/tile_ops/rsqrt.pto | 1 + ptodsl/lib/a5/generated/tile_ops/sort32.pto | 1 + ptodsl/lib/a5/generated/tile_ops/sqrt.pto | 1 + ptodsl/lib/a5/generated/tile_ops/sub.pto | 1 + ptodsl/lib/a5/ops.py | 48 ++- ptodsl/lib/a5/tbinary.py | 46 +++ ptodsl/lib/a5/tile_micro_coverage.py | 12 +- ptodsl/lib/a5/tile_op_kernels.py | 12 + ptodsl/lib/a5/tindex.py | 356 ++++++++++++++++++ ptodsl/lib/a5/treduce.py | 136 ++++++- ptodsl/lib/a5/tscalar.py | 57 +++ ptodsl/lib/a5/tselect.py | 262 +++++++++++++ ptodsl/lib/a5/tsort.py | 104 +---- tests/regression/test_a5_lib_regression.py | 210 ++++++++++- 45 files changed, 1625 insertions(+), 157 deletions(-) create mode 100644 ptodsl/lib/a5/generated/tile_ops/col_prod.pto create mode 100644 ptodsl/lib/a5/generated/tile_ops/row_prod.pto create mode 100644 ptodsl/lib/a5/tindex.py create mode 100644 ptodsl/lib/a5/tselect.py diff --git a/ptodsl/lib/a5/A5_HEADER_COVERAGE.md b/ptodsl/lib/a5/A5_HEADER_COVERAGE.md index b864afea..1bbf39cd 100644 --- a/ptodsl/lib/a5/A5_HEADER_COVERAGE.md +++ b/ptodsl/lib/a5/A5_HEADER_COVERAGE.md @@ -1,11 +1,11 @@ # A5 Header Coverage - Total A5 headers tracked: `116` -- Implemented: `49` +- Implemented: `57` - Partial: `2` - Native only: `11` -- Pending: `35` -- Blocked/meta: `19` +- Pending: `29` +- Blocked/meta: `17` | header | status | helper | note | | --- | --- | --- | --- | @@ -35,7 +35,7 @@ | `TColExpandSub` | `implemented` | `tcol_expand_sub` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TColMax` | `implemented` | `tcol_max` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TColMin` | `implemented` | `tcol_min` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | -| `TColProd` | `blocked` | `-` | No column-product micro lowering is wired yet. | +| `TColProd` | `implemented` | `tcol_prod` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TColReduceIdx` | `pending` | `-` | Indexed column reduction is not implemented yet. | | `TColReduceOps` | `implemented` | `treduce._tcol_reduce` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TColSum` | `implemented` | `tcol_sum` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | @@ -44,13 +44,13 @@ | `TDeQuant` | `pending` | `-` | Quantization/dequantization path is not implemented yet. | | `TDiv` | `implemented` | `tdiv` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TDivS` | `implemented` | `tdivs` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | -| `TExpandS` | `pending` | `-` | Scalar expand helper is not implemented yet. | +| `TExpandS` | `implemented` | `texpands` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TExtract` | `native` | `extract` | Still forwarded to the native PTO tile builder. | | `TFMod` | `pending` | `-` | Fmod lowering is not implemented yet. | | `TFModS` | `pending` | `-` | Scalar fmod lowering is not implemented yet. | | `TFillPad` | `pending` | `-` | Pad/fill helper is not implemented yet. | | `TGather` | `partial` | `tgather` | Indexed gather is implemented via vgather2; mask-pattern gather still needs missing vsqz-style micro support. | -| `TGatherB` | `pending` | `-` | GatherB lowering is not implemented yet, even though vgatherb exists in the micro surface. | +| `TGatherB` | `implemented` | `tgatherb` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TGetScaleAddr` | `pending` | `-` | Scale-address helper is not represented in the PTODSL A5 layer. | | `THistogram` | `pending` | `-` | Histogram lowering is not implemented yet. | | `TImg2col` | `blocked` | `-` | Hardware layout/state programming path, not a straightforward vector-micro rewrite target. | @@ -76,7 +76,7 @@ | `TPartMul` | `pending` | `-` | Part-op lowering is not implemented yet. | | `TPop` | `blocked` | `-` | Runtime buffer stack/state helper, not a direct vector tile rewrite target. | | `TPrefetch` | `blocked` | `-` | Prefetch/runtime helper, not a direct vector tile rewrite target. | -| `TPrelu` | `pending` | `-` | PReLU lowering is not implemented yet. | +| `TPrelu` | `implemented` | `tprelu` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TPrint` | `native` | `native print` | Still forwarded to the native PTO tile builder. | | `TPush` | `blocked` | `-` | Runtime buffer stack/state helper, not a direct vector tile rewrite target. | | `TQuant` | `pending` | `-` | Quantization path is not implemented yet. | @@ -93,13 +93,13 @@ | `TRowExpandMin` | `implemented` | `trow_expand_min` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TRowExpandMul` | `implemented` | `trow_expand_mul` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TRowExpandSub` | `implemented` | `trow_expand_sub` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | -| `TRowProd` | `blocked` | `-` | No row-product micro lowering is wired yet. | +| `TRowProd` | `implemented` | `trow_prod` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TRowReduce` | `implemented` | `treduce._trow_reduce` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TRowReduceIdx` | `pending` | `-` | Indexed row reduction is not implemented yet. | | `TRsqrt` | `implemented` | `trsqrt` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | -| `TScatter` | `pending` | `-` | Scatter lowering is not implemented yet, even though vscatter exists in the micro surface. | -| `TSel` | `pending` | `-` | Packed-mask select lowering is not implemented yet. | -| `TSels` | `pending` | `-` | Scalar/mask select lowering is not implemented yet. | +| `TScatter` | `implemented` | `tscatter` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | +| `TSel` | `implemented` | `tsel` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | +| `TSels` | `implemented` | `tsels` | Rewritten with PTO micro instructions in the PTODSL A5 layer. | | `TSetFmatrix` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. | | `TSetImg2colPadding` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. | | `TSetImg2colRpt` | `blocked` | `-` | Hardware state setup header, not a straightforward vector-micro rewrite target. | diff --git a/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md index cd7cc994..1fbd3ee6 100644 --- a/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md +++ b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md @@ -1,10 +1,10 @@ # Tile Micro Coverage - Total public tile ops: `34` -- Implemented: `26` +- Implemented: `28` - Partial: `1` - Pending: `0` -- Blocked: `6` +- Blocked: `4` - Not applicable: `1` | tile op | helper | note | @@ -30,7 +30,7 @@ | `row_sum` | `trow_sum` | Static-shape row reduction via vcadd + point-store. | | `row_min` | `trow_min` | Static-shape row reduction via vcmin + point-store. | | `row_max` | `trow_max` | Static-shape row reduction via vcmax + point-store. | -| `row_prod` | `-` | No row-product micro lowering is wired yet. | +| `row_prod` | `trow_prod` | Static-shape row reduction via vmul + vintlv tree reduction + point-store. | | `row_expand` | `trow_expand` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. | | `row_expand_sub` | `trow_expand_sub` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. | | `row_expand_div` | `trow_expand_div` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. | @@ -38,7 +38,7 @@ | `col_sum` | `tcol_sum` | Static-shape TColReduceOps-style column reduction via vadd. | | `col_min` | `tcol_min` | Static-shape TColReduceOps-style column reduction via vmin. | | `col_max` | `tcol_max` | Static-shape TColReduceOps-style column reduction via vmax. | -| `col_prod` | `-` | No column-product micro lowering is wired yet. | +| `col_prod` | `tcol_prod` | Static-shape TColReduceOps-style column reduction via vmul. | | `col_expand` | `tcol_expand` | Static-shape canonical broadcast via vlds/vsts replication. | | `mrgsort` | `tmrgsort` | Single-list row-major merge sort via vmrgsort4. | | `sort32` | `tsort32` | Static-shape block sort via vbitsort. | diff --git a/ptodsl/lib/a5/__init__.py b/ptodsl/lib/a5/__init__.py index e3fab8a3..31baafd6 100644 --- a/ptodsl/lib/a5/__init__.py +++ b/ptodsl/lib/a5/__init__.py @@ -1,4 +1,4 @@ -from . import native, ops, tbinary, texpand, treduce, tscalar, tsort, tunary +from . import native, ops, tbinary, texpand, tindex, tselect, treduce, tscalar, tsort, tunary from .a5_header_coverage import A5_HEADER_COVERAGE, a5_header_coverage_markdown from .kernels import ( HIVM_LLVM_KERNELS, @@ -41,6 +41,8 @@ "native", "tbinary", "texpand", + "tindex", + "tselect", "treduce", "tscalar", "tsort", diff --git a/ptodsl/lib/a5/_common.py b/ptodsl/lib/a5/_common.py index de80c57b..f78684fb 100644 --- a/ptodsl/lib/a5/_common.py +++ b/ptodsl/lib/a5/_common.py @@ -77,6 +77,10 @@ def uint32_type(): return IntegerType.get_unsigned(32) +def uint16_type(): + return IntegerType.get_unsigned(16) + + def const_i64(value): i64 = IntegerType.get_signless(64) return arith.ConstantOp(i64, IntegerAttr.get(i64, value)).result @@ -505,11 +509,13 @@ def check_tscalar_operands(src_view, out_view, *, dtype, shape, context, allowed return rows, cols -def check_tbinop_operands(lhs_view, rhs_view, out_view, *, dtype, shape, context): +def check_tbinop_operands( + lhs_view, rhs_view, out_view, *, dtype, shape, context, allowed=None +): rows, cols = require_static_matrix_shape(shape, context=context) require_supported_dtype( dtype, - allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + allowed=allowed or {"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, message=f"Fix: {context} has invalid data type.", ) for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")): @@ -586,11 +592,13 @@ def check_col_expand_operands(src_view, out_view, *, dtype, shape, context): return rows, cols -def check_row_reduce_operands(src_view, out_view, *, dtype, shape, context): +def check_row_reduce_operands( + src_view, out_view, *, dtype, shape, context, allowed=None +): rows, cols = require_static_matrix_shape(shape, context=context) require_supported_dtype( dtype, - allowed={"f32", "f16"}, + allowed=allowed or {"f32", "f16"}, message=f"Fix: {context} input data type is not supported.", ) require_view_shape( @@ -616,11 +624,13 @@ def check_row_reduce_operands(src_view, out_view, *, dtype, shape, context): return rows, cols -def check_col_reduce_operands(src_view, out_view, *, dtype, shape, context): +def check_col_reduce_operands( + src_view, out_view, *, dtype, shape, context, allowed=None +): rows, cols = require_static_matrix_shape(shape, context=context) require_supported_dtype( dtype, - allowed={"f32", "f16"}, + allowed=allowed or {"f32", "f16"}, message=f"Fix: {context} input data type is not supported.", ) require_view_shape( @@ -678,6 +688,136 @@ def check_gather_operands( return rows, cols +def check_gatherb_operands( + src_view, indices_view, out_view, *, dtype, index_dtype, shape +): + rows, cols = require_static_matrix_shape(shape, context="TGATHERB") + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message="Fix: TGATHERB source data type is not supported.", + ) + require_supported_dtype( + index_dtype, + allowed={"u32"}, + message="Fix: TGATHERB index data type must be uint32.", + ) + require_view_dtype( + src_view, + dtype, + message="Fix: TGATHERB source data type mismatch.", + ) + for view, label, view_dtype in ( + (indices_view, "indices", index_dtype), + (out_view, "dst", dtype), + ): + require_view_shape( + view, + [rows, cols], + message=f"Fix: TGATHERB {label} valid shape mismatch.", + ) + require_view_dtype( + view, + view_dtype, + message=f"Fix: TGATHERB {label} data type mismatch.", + ) + return rows, cols + + +def check_scatter_operands( + src_view, indices_view, out_view, *, dtype, index_dtype, shape +): + rows, cols = require_static_matrix_shape(shape, context="TSCATTER") + dtype_token_value = require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message="Fix: TSCATTER source data type is not supported.", + ) + index_token_value = require_supported_dtype( + index_dtype, + allowed={"u32", "i32", "u16", "i16"}, + message="Fix: TSCATTER index data type is not supported.", + ) + dtype_width = dtype_byte_width(dtype) + index_width = dtype_byte_width(index_dtype) + if not ( + (dtype_width == 4 and index_width == 4) + or (dtype_width == 2 and index_width == 2) + or (dtype_width == 1 and index_width == 2) + ): + raise ValueError("Fix: TSCATTER invalid data type of idx.") + for view, label, view_dtype in ( + (src_view, "src", dtype), + (indices_view, "idx", index_dtype), + (out_view, "dst", dtype), + ): + require_view_shape( + view, + [rows, cols], + message=f"Fix: TSCATTER input tile {label} valid shape mismatch.", + ) + require_view_dtype( + view, + view_dtype, + message=f"Fix: TSCATTER {label} data type mismatch.", + ) + return rows, cols, dtype_token_value, index_token_value + + +def check_tsel_operands(mask_view, lhs_view, rhs_view, out_view, *, dtype, shape): + rows, cols = require_static_matrix_shape(shape, context="TSEL") + require_supported_dtype( + dtype, + allowed={"f32"}, + message="Fix: TSEL only support 32-bit float data tiles.", + ) + require_view_shape( + mask_view, + [rows, cols], + message="Fix: TSEL requires matching source, mask, and destination valid region.", + ) + mask_token = extract_tensor_dtype_token(mask_view) + if mask_token not in {"i8", "u8"}: + raise ValueError("Fix: TSEL currently requires i8 or u8 mask tiles.") + for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")): + require_view_shape( + view, + [rows, cols], + message=f"Fix: TSEL input tile {label} valid shape mismatch.", + ) + require_view_dtype( + view, + dtype, + message="Fix: TSEL only support same data type between dst, src0, and src1.", + ) + return rows, cols + + +def check_tsels_operands(mask_view, src_view, out_view, *, dtype, shape): + rows, cols = require_static_matrix_shape(shape, context="TSELS") + require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message="TSELS: Invalid data type", + ) + for view, label, view_dtype in ( + (mask_view, "mask", dtype), + (src_view, "src", dtype), + (out_view, "dst", dtype), + ): + require_view_shape( + view, + [rows, cols], + message=f"Fix: TSELS {label} valid shape mismatch.", + ) + require_view_dtype( + view, + view_dtype, + message="TileType of dst and src must be the same.", + ) + return rows, cols + + def check_mrgsort_operands(src_view, out_view, *, dtype, shape, block_len): rows, cols = require_static_matrix_shape(shape, context="TMRGSORT") if rows != 1: @@ -747,12 +887,16 @@ def check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape): "check_col_expand_operands", "check_col_reduce_operands", "check_gather_operands", + "check_gatherb_operands", "check_mrgsort_operands", "check_row_expand_operands", "check_row_reduce_operands", + "check_scatter_operands", "check_sort32_operands", "check_tscalar_operands", "check_tbinop_operands", + "check_tsel_operands", + "check_tsels_operands", "const_expr", "const_float", "const_scalar", @@ -783,5 +927,6 @@ def check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape): "store_view", "tail_mask", "uint32_type", + "uint16_type", "vreg_type", ] diff --git a/ptodsl/lib/a5/a5_header_coverage.py b/ptodsl/lib/a5/a5_header_coverage.py index c20570b3..5533e769 100644 --- a/ptodsl/lib/a5/a5_header_coverage.py +++ b/ptodsl/lib/a5/a5_header_coverage.py @@ -135,10 +135,13 @@ "TColExpandSub": "tcol_expand_sub", "TColMax": "tcol_max", "TColMin": "tcol_min", + "TColProd": "tcol_prod", "TColReduceOps": "treduce._tcol_reduce", "TColSum": "tcol_sum", "TDiv": "tdiv", "TDivS": "tdivs", + "TExpandS": "texpands", + "TGatherB": "tgatherb", "TLRelu": "tlrelu", "TMax": "tmax", "TMaxs": "tmaxs", @@ -150,6 +153,7 @@ "TMulS": "tmuls", "TOr": "tor_", "TOrS": "tors", + "TPrelu": "tprelu", "TRowExpand": "trow_expand", "TRowExpandAdd": "trow_expand_add", "TRowExpandDiv": "trow_expand_div", @@ -157,8 +161,12 @@ "TRowExpandMin": "trow_expand_min", "TRowExpandMul": "trow_expand_mul", "TRowExpandSub": "trow_expand_sub", + "TRowProd": "trow_prod", "TRowReduce": "treduce._trow_reduce", "TRsqrt": "trsqrt", + "TScatter": "tscatter", + "TSel": "tsel", + "TSels": "tsels", "TShl": "tshl", "TShlS": "tshls", "TShr": "tshr", @@ -214,13 +222,11 @@ _BLOCKED_HEADERS = { "TAlias": "C++ helper/meta header, not a tile micro-instruction kernel surface.", "TAssign": "C++ helper/meta header, not a tile micro-instruction kernel surface.", - "TColProd": "No column-product micro lowering is wired yet.", "TImg2col": "Hardware layout/state programming path, not a straightforward vector-micro rewrite target.", "TMatmul": "Cube/L0 path is not a pure vector-micro rewrite target.", "TPop": "Runtime buffer stack/state helper, not a direct vector tile rewrite target.", "TPrefetch": "Prefetch/runtime helper, not a direct vector tile rewrite target.", "TPush": "Runtime buffer stack/state helper, not a direct vector tile rewrite target.", - "TRowProd": "No row-product micro lowering is wired yet.", "TSetFmatrix": "Hardware state setup header, not a straightforward vector-micro rewrite target.", "TSetImg2colPadding": "Hardware state setup header, not a straightforward vector-micro rewrite target.", "TSetImg2colRpt": "Hardware state setup header, not a straightforward vector-micro rewrite target.", @@ -243,11 +249,9 @@ "TColReduceIdx": "Indexed column reduction is not implemented yet.", "TCvt": "Tile conversion helper is not implemented in the A5 micro layer yet.", "TDeQuant": "Quantization/dequantization path is not implemented yet.", - "TExpandS": "Scalar expand helper is not implemented yet.", "TFMod": "Fmod lowering is not implemented yet.", "TFModS": "Scalar fmod lowering is not implemented yet.", "TFillPad": "Pad/fill helper is not implemented yet.", - "TGatherB": "GatherB lowering is not implemented yet, even though vgatherb exists in the micro surface.", "TGetScaleAddr": "Scale-address helper is not represented in the PTODSL A5 layer.", "THistogram": "Histogram lowering is not implemented yet.", "TPack": "Pack lowering is not implemented yet.", @@ -256,7 +260,6 @@ "TPartMax": "Part-op lowering is not implemented yet.", "TPartMin": "Part-op lowering is not implemented yet.", "TPartMul": "Part-op lowering is not implemented yet.", - "TPrelu": "PReLU lowering is not implemented yet.", "TQuant": "Quantization path is not implemented yet.", "TRandom": "Random-number helper is not implemented yet.", "TRem": "Remainder lowering is not implemented yet.", @@ -264,9 +267,6 @@ "TRowExpandBinOp": "Generic row-broadcast binary frontend is not exposed yet.", "TRowExpandExpdif": "Specialized exp-diff row-broadcast lowering is not implemented yet.", "TRowReduceIdx": "Indexed row reduction is not implemented yet.", - "TScatter": "Scatter lowering is not implemented yet, even though vscatter exists in the micro surface.", - "TSel": "Packed-mask select lowering is not implemented yet.", - "TSels": "Scalar/mask select lowering is not implemented yet.", "TSync": "Synchronization helper is not represented in the A5 library layer yet.", "TTri": "Triangular helper is not implemented yet.", } diff --git a/ptodsl/lib/a5/generated/tile_ops/TILE_OP_GENERATION_INDEX.md b/ptodsl/lib/a5/generated/tile_ops/TILE_OP_GENERATION_INDEX.md index 263e3a05..932e0b37 100644 --- a/ptodsl/lib/a5/generated/tile_ops/TILE_OP_GENERATION_INDEX.md +++ b/ptodsl/lib/a5/generated/tile_ops/TILE_OP_GENERATION_INDEX.md @@ -23,7 +23,7 @@ | `row_sum` | `generated` | `tile_ops/row_sum.pto` | Static-shape row reduction via vcadd + point-store. | | `row_min` | `generated` | `tile_ops/row_min.pto` | Static-shape row reduction via vcmin + point-store. | | `row_max` | `generated` | `tile_ops/row_max.pto` | Static-shape row reduction via vcmax + point-store. | -| `row_prod` | `blocked` | - | No row-product micro lowering is wired yet. | +| `row_prod` | `generated` | `tile_ops/row_prod.pto` | Static-shape row reduction via vmul + vintlv tree reduction + point-store. | | `row_expand` | `generated` | `tile_ops/row_expand.pto` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. | | `row_expand_sub` | `generated` | `tile_ops/row_expand_sub.pto` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. | | `row_expand_div` | `generated` | `tile_ops/row_expand_div.pto` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. | @@ -31,7 +31,7 @@ | `col_sum` | `generated` | `tile_ops/col_sum.pto` | Static-shape TColReduceOps-style column reduction via vadd. | | `col_min` | `generated` | `tile_ops/col_min.pto` | Static-shape TColReduceOps-style column reduction via vmin. | | `col_max` | `generated` | `tile_ops/col_max.pto` | Static-shape TColReduceOps-style column reduction via vmax. | -| `col_prod` | `blocked` | - | No column-product micro lowering is wired yet. | +| `col_prod` | `generated` | `tile_ops/col_prod.pto` | Static-shape TColReduceOps-style column reduction via vmul. | | `col_expand` | `generated` | `tile_ops/col_expand.pto` | Static-shape canonical broadcast via vlds/vsts replication. | | `mrgsort` | `generated` | `tile_ops/mrgsort.pto` | Single-list row-major merge sort via vmrgsort4. | | `sort32` | `generated` | `tile_ops/sort32.pto` | Static-shape block sort via vbitsort. | diff --git a/ptodsl/lib/a5/generated/tile_ops/abs.pto b/ptodsl/lib/a5/generated/tile_ops/abs.pto index ef021091..aea13f8e 100644 --- a/ptodsl/lib/a5/generated/tile_ops/abs.pto +++ b/ptodsl/lib/a5/generated/tile_ops/abs.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/add.pto b/ptodsl/lib/a5/generated/tile_ops/add.pto index 5b4d9507..b824e9fa 100644 --- a/ptodsl/lib/a5/generated/tile_ops/add.pto +++ b/ptodsl/lib/a5/generated/tile_ops/add.pto @@ -88,3 +88,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/col_expand.pto b/ptodsl/lib/a5/generated/tile_ops/col_expand.pto index 7b987b15..346efa3c 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_expand.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_expand.pto @@ -52,3 +52,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/col_max.pto b/ptodsl/lib/a5/generated/tile_ops/col_max.pto index 55459b11..773749a7 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_max.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_max.pto @@ -91,3 +91,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/col_min.pto b/ptodsl/lib/a5/generated/tile_ops/col_min.pto index 665e0a4f..a393f251 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_min.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_min.pto @@ -91,3 +91,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/col_prod.pto b/ptodsl/lib/a5/generated/tile_ops/col_prod.pto new file mode 100644 index 00000000..752f6983 --- /dev/null +++ b/ptodsl/lib/a5/generated/tile_ops/col_prod.pto @@ -0,0 +1,94 @@ +module { + func.func @tile_op_col_prod(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64_0 = arith.constant 64 : index + %0 = pto.make_tensor_view %arg0, shape = [%c8, %c64_0], strides = [%c64, %c1] : !pto.tensor_view + %c64_1 = arith.constant 64 : index + %c1_2 = arith.constant 1 : index + %c1_3 = arith.constant 1 : index + %c64_4 = arith.constant 64 : index + %1 = pto.make_tensor_view %arg1, shape = [%c1_3, %c64_4], strides = [%c64_1, %c1_2] : !pto.tensor_view + pto.section.vector { + %c0 = arith.constant 0 : index + %c0_5 = arith.constant 0 : index + %c8_6 = arith.constant 8 : index + %c64_7 = arith.constant 64 : index + %2 = pto.partition_view %0, offsets = [%c0, %c0_5], sizes = [%c8_6, %c64_7] : !pto.tensor_view -> !pto.partition_tensor_view<8x64xf32> + %c0_8 = arith.constant 0 : index + %c0_9 = arith.constant 0 : index + %c1_10 = arith.constant 1 : index + %c64_11 = arith.constant 64 : index + %3 = pto.partition_view %1, offsets = [%c0_8, %c0_9], sizes = [%c1_10, %c64_11] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xf32> + %c0_i64 = arith.constant 0 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %4 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %5 = pto.alloc_tile addr = %c2048_i64 : !pto.tile_buf + pto.tload ins(%2 : !pto.partition_tensor_view<8x64xf32>) outs(%4 : !pto.tile_buf) + %6 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %7 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %cst = arith.constant 1.000000e+00 : f32 + %8 = pto.vbr %cst : f32 -> !pto.vreg<64xf32> + %9 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c64_12 = arith.constant 64 : index + %10 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c0_13 = arith.constant 0 : index + %11 = pto.addptr %6, %c0_13 : -> + %c64_14 = arith.constant 64 : index + %result, %updated_source = pto.vlds_post %11[%c64_14] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %12 = pto.pset_b32 "PAT_ALL" : !pto.mask + %13 = pto.vsel %result, %8, %12 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %14 = pto.vsel %13, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %15 = pto.vmul %8, %14, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_15 = arith.constant 64 : index + %result_16, %updated_source_17 = pto.vlds_post %updated_source[%c64_15] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %16 = pto.pset_b32 "PAT_ALL" : !pto.mask + %17 = pto.vsel %result_16, %8, %16 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %18 = pto.vsel %17, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %19 = pto.vmul %15, %18, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_18 = arith.constant 64 : index + %result_19, %updated_source_20 = pto.vlds_post %updated_source_17[%c64_18] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %20 = pto.pset_b32 "PAT_ALL" : !pto.mask + %21 = pto.vsel %result_19, %8, %20 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %22 = pto.vsel %21, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %23 = pto.vmul %19, %22, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_21 = arith.constant 64 : index + %result_22, %updated_source_23 = pto.vlds_post %updated_source_20[%c64_21] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %24 = pto.pset_b32 "PAT_ALL" : !pto.mask + %25 = pto.vsel %result_22, %8, %24 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %26 = pto.vsel %25, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %27 = pto.vmul %23, %26, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_24 = arith.constant 64 : index + %result_25, %updated_source_26 = pto.vlds_post %updated_source_23[%c64_24] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %28 = pto.pset_b32 "PAT_ALL" : !pto.mask + %29 = pto.vsel %result_25, %8, %28 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %30 = pto.vsel %29, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %31 = pto.vmul %27, %30, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_27 = arith.constant 64 : index + %result_28, %updated_source_29 = pto.vlds_post %updated_source_26[%c64_27] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %33 = pto.vsel %result_28, %8, %32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %34 = pto.vsel %33, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %35 = pto.vmul %31, %34, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_30 = arith.constant 64 : index + %result_31, %updated_source_32 = pto.vlds_post %updated_source_29[%c64_30] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %36 = pto.pset_b32 "PAT_ALL" : !pto.mask + %37 = pto.vsel %result_31, %8, %36 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %38 = pto.vsel %37, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %39 = pto.vmul %35, %38, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_33 = arith.constant 64 : index + %result_34, %updated_source_35 = pto.vlds_post %updated_source_32[%c64_33] : !pto.ptr -> !pto.vreg<64xf32>, !pto.ptr + %40 = pto.pset_b32 "PAT_ALL" : !pto.mask + %41 = pto.vsel %result_34, %8, %40 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %42 = pto.vsel %41, %8, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %43 = pto.vmul %39, %42, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c0_36 = arith.constant 0 : index + %44 = pto.addptr %7, %c0_36 : -> + %45 = pto.vsts_post %43, %44[%c64_12], %10 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -> !pto.ptr + pto.tstore ins(%5 : !pto.tile_buf) outs(%3 : !pto.partition_tensor_view<1x64xf32>) + } + return + } +} + diff --git a/ptodsl/lib/a5/generated/tile_ops/col_sum.pto b/ptodsl/lib/a5/generated/tile_ops/col_sum.pto index 78b09c22..4142875f 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_sum.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_sum.pto @@ -91,3 +91,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/div.pto b/ptodsl/lib/a5/generated/tile_ops/div.pto index ac2fb1e4..ede78981 100644 --- a/ptodsl/lib/a5/generated/tile_ops/div.pto +++ b/ptodsl/lib/a5/generated/tile_ops/div.pto @@ -88,3 +88,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/exp.pto b/ptodsl/lib/a5/generated/tile_ops/exp.pto index 0f47cc9a..58911188 100644 --- a/ptodsl/lib/a5/generated/tile_ops/exp.pto +++ b/ptodsl/lib/a5/generated/tile_ops/exp.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/gather.pto b/ptodsl/lib/a5/generated/tile_ops/gather.pto index 01045b3a..2c226fe7 100644 --- a/ptodsl/lib/a5/generated/tile_ops/gather.pto +++ b/ptodsl/lib/a5/generated/tile_ops/gather.pto @@ -53,3 +53,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/log.pto b/ptodsl/lib/a5/generated/tile_ops/log.pto index 5b4f62bd..08d0143c 100644 --- a/ptodsl/lib/a5/generated/tile_ops/log.pto +++ b/ptodsl/lib/a5/generated/tile_ops/log.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/mov.pto b/ptodsl/lib/a5/generated/tile_ops/mov.pto index ade28826..737af953 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mov.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mov.pto @@ -65,3 +65,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto b/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto index 0f0afe8e..bdb0a9d9 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto @@ -42,3 +42,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/mul.pto b/ptodsl/lib/a5/generated/tile_ops/mul.pto index 86af874f..926adc5c 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mul.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mul.pto @@ -88,3 +88,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/or_.pto b/ptodsl/lib/a5/generated/tile_ops/or_.pto index 66add0fe..ea25374b 100644 --- a/ptodsl/lib/a5/generated/tile_ops/or_.pto +++ b/ptodsl/lib/a5/generated/tile_ops/or_.pto @@ -88,3 +88,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto b/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto index 3ed78ddc..d5ae3f9e 100644 --- a/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto +++ b/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/relu.pto b/ptodsl/lib/a5/generated/tile_ops/relu.pto index 001aa127..703535ef 100644 --- a/ptodsl/lib/a5/generated/tile_ops/relu.pto +++ b/ptodsl/lib/a5/generated/tile_ops/relu.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand.pto index e3fd1be3..6254d70a 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand.pto @@ -89,3 +89,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto index 7f1574e2..e1d37373 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto @@ -119,3 +119,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto index 8f9631d5..34fc0119 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto @@ -119,3 +119,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto index 910a0c80..9922961c 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto @@ -119,3 +119,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_max.pto b/ptodsl/lib/a5/generated/tile_ops/row_max.pto index 8bc4bafd..c01e53dc 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_max.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_max.pto @@ -116,3 +116,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_min.pto b/ptodsl/lib/a5/generated/tile_ops/row_min.pto index 394d2a95..bde5c99e 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_min.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_min.pto @@ -116,3 +116,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_prod.pto b/ptodsl/lib/a5/generated/tile_ops/row_prod.pto new file mode 100644 index 00000000..87a02975 --- /dev/null +++ b/ptodsl/lib/a5/generated/tile_ops/row_prod.pto @@ -0,0 +1,207 @@ +module { + func.func @tile_op_row_prod(%arg0: !pto.ptr, %arg1: !pto.ptr) { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64_0 = arith.constant 64 : index + %0 = pto.make_tensor_view %arg0, shape = [%c8, %c64_0], strides = [%c64, %c1] : !pto.tensor_view + %c1_1 = arith.constant 1 : index + %c1_2 = arith.constant 1 : index + %c8_3 = arith.constant 8 : index + %c1_4 = arith.constant 1 : index + %1 = pto.make_tensor_view %arg1, shape = [%c8_3, %c1_4], strides = [%c1_1, %c1_2] : !pto.tensor_view + pto.section.vector { + %c0 = arith.constant 0 : index + %c0_5 = arith.constant 0 : index + %c8_6 = arith.constant 8 : index + %c64_7 = arith.constant 64 : index + %2 = pto.partition_view %0, offsets = [%c0, %c0_5], sizes = [%c8_6, %c64_7] : !pto.tensor_view -> !pto.partition_tensor_view<8x64xf32> + %c0_8 = arith.constant 0 : index + %c0_9 = arith.constant 0 : index + %c8_10 = arith.constant 8 : index + %c1_11 = arith.constant 1 : index + %3 = pto.partition_view %1, offsets = [%c0_8, %c0_9], sizes = [%c8_10, %c1_11] : !pto.tensor_view -> !pto.partition_tensor_view<8x1xf32> + %c0_i64 = arith.constant 0 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %4 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %5 = pto.alloc_tile addr = %c2048_i64 : !pto.tile_buf + pto.tload ins(%2 : !pto.partition_tensor_view<8x64xf32>) outs(%4 : !pto.tile_buf) + %6 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %7 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %8 = pto.pset_b32 "PAT_ALL" : !pto.mask + %cst = arith.constant 1.000000e+00 : f32 + %9 = pto.vbr %cst : f32 -> !pto.vreg<64xf32> + %10 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c0_12 = arith.constant 0 : index + %11 = pto.vlds %6[%c0_12] : !pto.ptr -> !pto.vreg<64xf32> + %12 = pto.vsel %11, %9, %10 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %13 = pto.vmul %9, %12, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low, %high = pto.vintlv %13, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %14 = pto.vmul %low, %high, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_13, %high_14 = pto.vintlv %14, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %15 = pto.vmul %low_13, %high_14, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_15, %high_16 = pto.vintlv %15, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %16 = pto.vmul %low_15, %high_16, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_17, %high_18 = pto.vintlv %16, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %17 = pto.vmul %low_17, %high_18, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_19, %high_20 = pto.vintlv %17, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %18 = pto.vmul %low_19, %high_20, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_21, %high_22 = pto.vintlv %18, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %19 = pto.vmul %low_21, %high_22, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c0_23 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %mask, %scalar_out = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + pto.vsts %19, %7[%c0_23], %mask {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %20 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c64_24 = arith.constant 64 : index + %21 = pto.vlds %6[%c64_24] : !pto.ptr -> !pto.vreg<64xf32> + %22 = pto.vsel %21, %9, %20 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %23 = pto.vmul %9, %22, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_25, %high_26 = pto.vintlv %23, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %24 = pto.vmul %low_25, %high_26, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_27, %high_28 = pto.vintlv %24, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %25 = pto.vmul %low_27, %high_28, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_29, %high_30 = pto.vintlv %25, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %26 = pto.vmul %low_29, %high_30, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_31, %high_32 = pto.vintlv %26, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %27 = pto.vmul %low_31, %high_32, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_33, %high_34 = pto.vintlv %27, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %28 = pto.vmul %low_33, %high_34, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_35, %high_36 = pto.vintlv %28, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %29 = pto.vmul %low_35, %high_36, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c64_37 = arith.constant 64 : index + %c1_i32_38 = arith.constant 1 : i32 + %mask_39, %scalar_out_40 = pto.plt_b32 %c1_i32_38 : i32 -> !pto.mask, i32 + pto.vsts %29, %7[%c64_37], %mask_39 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %30 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c128 = arith.constant 128 : index + %31 = pto.vlds %6[%c128] : !pto.ptr -> !pto.vreg<64xf32> + %32 = pto.vsel %31, %9, %30 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %33 = pto.vmul %9, %32, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_41, %high_42 = pto.vintlv %33, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %34 = pto.vmul %low_41, %high_42, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_43, %high_44 = pto.vintlv %34, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %35 = pto.vmul %low_43, %high_44, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_45, %high_46 = pto.vintlv %35, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %36 = pto.vmul %low_45, %high_46, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_47, %high_48 = pto.vintlv %36, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %37 = pto.vmul %low_47, %high_48, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_49, %high_50 = pto.vintlv %37, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %38 = pto.vmul %low_49, %high_50, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_51, %high_52 = pto.vintlv %38, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %39 = pto.vmul %low_51, %high_52, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c128_53 = arith.constant 128 : index + %c1_i32_54 = arith.constant 1 : i32 + %mask_55, %scalar_out_56 = pto.plt_b32 %c1_i32_54 : i32 -> !pto.mask, i32 + pto.vsts %39, %7[%c128_53], %mask_55 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %40 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c192 = arith.constant 192 : index + %41 = pto.vlds %6[%c192] : !pto.ptr -> !pto.vreg<64xf32> + %42 = pto.vsel %41, %9, %40 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %43 = pto.vmul %9, %42, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_57, %high_58 = pto.vintlv %43, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %44 = pto.vmul %low_57, %high_58, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_59, %high_60 = pto.vintlv %44, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %45 = pto.vmul %low_59, %high_60, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_61, %high_62 = pto.vintlv %45, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %46 = pto.vmul %low_61, %high_62, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_63, %high_64 = pto.vintlv %46, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %47 = pto.vmul %low_63, %high_64, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_65, %high_66 = pto.vintlv %47, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %48 = pto.vmul %low_65, %high_66, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_67, %high_68 = pto.vintlv %48, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %49 = pto.vmul %low_67, %high_68, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c192_69 = arith.constant 192 : index + %c1_i32_70 = arith.constant 1 : i32 + %mask_71, %scalar_out_72 = pto.plt_b32 %c1_i32_70 : i32 -> !pto.mask, i32 + pto.vsts %49, %7[%c192_69], %mask_71 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %50 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c256 = arith.constant 256 : index + %51 = pto.vlds %6[%c256] : !pto.ptr -> !pto.vreg<64xf32> + %52 = pto.vsel %51, %9, %50 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %53 = pto.vmul %9, %52, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_73, %high_74 = pto.vintlv %53, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %54 = pto.vmul %low_73, %high_74, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_75, %high_76 = pto.vintlv %54, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %55 = pto.vmul %low_75, %high_76, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_77, %high_78 = pto.vintlv %55, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %56 = pto.vmul %low_77, %high_78, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_79, %high_80 = pto.vintlv %56, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %57 = pto.vmul %low_79, %high_80, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_81, %high_82 = pto.vintlv %57, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %58 = pto.vmul %low_81, %high_82, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_83, %high_84 = pto.vintlv %58, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %59 = pto.vmul %low_83, %high_84, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c256_85 = arith.constant 256 : index + %c1_i32_86 = arith.constant 1 : i32 + %mask_87, %scalar_out_88 = pto.plt_b32 %c1_i32_86 : i32 -> !pto.mask, i32 + pto.vsts %59, %7[%c256_85], %mask_87 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %60 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c320 = arith.constant 320 : index + %61 = pto.vlds %6[%c320] : !pto.ptr -> !pto.vreg<64xf32> + %62 = pto.vsel %61, %9, %60 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %63 = pto.vmul %9, %62, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_89, %high_90 = pto.vintlv %63, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %64 = pto.vmul %low_89, %high_90, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_91, %high_92 = pto.vintlv %64, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %65 = pto.vmul %low_91, %high_92, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_93, %high_94 = pto.vintlv %65, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %66 = pto.vmul %low_93, %high_94, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_95, %high_96 = pto.vintlv %66, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %67 = pto.vmul %low_95, %high_96, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_97, %high_98 = pto.vintlv %67, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %68 = pto.vmul %low_97, %high_98, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_99, %high_100 = pto.vintlv %68, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %69 = pto.vmul %low_99, %high_100, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c320_101 = arith.constant 320 : index + %c1_i32_102 = arith.constant 1 : i32 + %mask_103, %scalar_out_104 = pto.plt_b32 %c1_i32_102 : i32 -> !pto.mask, i32 + pto.vsts %69, %7[%c320_101], %mask_103 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %70 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c384 = arith.constant 384 : index + %71 = pto.vlds %6[%c384] : !pto.ptr -> !pto.vreg<64xf32> + %72 = pto.vsel %71, %9, %70 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %73 = pto.vmul %9, %72, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_105, %high_106 = pto.vintlv %73, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %74 = pto.vmul %low_105, %high_106, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_107, %high_108 = pto.vintlv %74, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %75 = pto.vmul %low_107, %high_108, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_109, %high_110 = pto.vintlv %75, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %76 = pto.vmul %low_109, %high_110, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_111, %high_112 = pto.vintlv %76, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %77 = pto.vmul %low_111, %high_112, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_113, %high_114 = pto.vintlv %77, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %78 = pto.vmul %low_113, %high_114, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_115, %high_116 = pto.vintlv %78, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %79 = pto.vmul %low_115, %high_116, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c384_117 = arith.constant 384 : index + %c1_i32_118 = arith.constant 1 : i32 + %mask_119, %scalar_out_120 = pto.plt_b32 %c1_i32_118 : i32 -> !pto.mask, i32 + pto.vsts %79, %7[%c384_117], %mask_119 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + %80 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c448 = arith.constant 448 : index + %81 = pto.vlds %6[%c448] : !pto.ptr -> !pto.vreg<64xf32> + %82 = pto.vsel %81, %9, %80 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %83 = pto.vmul %9, %82, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_121, %high_122 = pto.vintlv %83, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %84 = pto.vmul %low_121, %high_122, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_123, %high_124 = pto.vintlv %84, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %85 = pto.vmul %low_123, %high_124, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_125, %high_126 = pto.vintlv %85, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %86 = pto.vmul %low_125, %high_126, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_127, %high_128 = pto.vintlv %86, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %87 = pto.vmul %low_127, %high_128, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_129, %high_130 = pto.vintlv %87, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %88 = pto.vmul %low_129, %high_130, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %low_131, %high_132 = pto.vintlv %88, %9 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %89 = pto.vmul %low_131, %high_132, %8 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %c448_133 = arith.constant 448 : index + %c1_i32_134 = arith.constant 1 : i32 + %mask_135, %scalar_out_136 = pto.plt_b32 %c1_i32_134 : i32 -> !pto.mask, i32 + pto.vsts %89, %7[%c448_133], %mask_135 {dist = "ONEPT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.tstore ins(%5 : !pto.tile_buf) outs(%3 : !pto.partition_tensor_view<8x1xf32>) + } + return + } +} + diff --git a/ptodsl/lib/a5/generated/tile_ops/row_sum.pto b/ptodsl/lib/a5/generated/tile_ops/row_sum.pto index 61c20c50..c369a883 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_sum.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_sum.pto @@ -116,3 +116,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto b/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto index 17752171..92da13ae 100644 --- a/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto +++ b/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto @@ -81,3 +81,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/sort32.pto b/ptodsl/lib/a5/generated/tile_ops/sort32.pto index c739860e..48e38209 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sort32.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sort32.pto @@ -55,3 +55,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/sqrt.pto b/ptodsl/lib/a5/generated/tile_ops/sqrt.pto index 4de655e0..ffc2b89b 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sqrt.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sqrt.pto @@ -73,3 +73,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/generated/tile_ops/sub.pto b/ptodsl/lib/a5/generated/tile_ops/sub.pto index e81c72ae..1caefd1a 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sub.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sub.pto @@ -88,3 +88,4 @@ module { return } } + diff --git a/ptodsl/lib/a5/ops.py b/ptodsl/lib/a5/ops.py index 1270b99c..256c3435 100644 --- a/ptodsl/lib/a5/ops.py +++ b/ptodsl/lib/a5/ops.py @@ -8,7 +8,6 @@ VF_IMPL_DEFAULT, ) from .native import ( - col_prod, compare, concat, extract, @@ -22,15 +21,14 @@ matmul_mx_acc, matmul_mx_bias, move_tile, - row_prod, - scatter, - select, store_tile, trans, vector_copy, vload, vstore, ) +from .tindex import tgather, tgatherb, tscatter +from .tselect import tsel, tsels from .tbinary import ( tand, tadd, @@ -40,6 +38,7 @@ tmov, tmul, tor_, + tprelu, tshl, tshr, tsub, @@ -64,13 +63,16 @@ from .treduce import ( tcol_max, tcol_min, + tcol_prod, tcol_sum, trow_max, trow_min, + trow_prod, trow_sum, ) from .tscalar import ( taxpy, + texpands, tadds, tands, tdivs, @@ -84,7 +86,7 @@ tsubs, txors, ) -from .tsort import tgather, tmrgsort, tsort32 +from .tsort import tmrgsort, tsort32 from .tunary import tabs, texp, tlog, trecip, trelu, trsqrt, tsqrt # A5-style aliases. @@ -125,15 +127,26 @@ TRsqrt = trsqrt TRecip = trecip TAxpy = taxpy +TExpandS = texpands TGather = tgather -TScatter = scatter -TSel = select +TGatherB = tgatherb +TScatter = tscatter +gatherb = tgatherb +scatter = tscatter +TSel = tsel +TSelS = tsels +TSels = tsels +select = tsel +selects = tsels +TPrelu = tprelu TConcat = concat TExtract = extract TInsert = insert TRowSum = trow_sum TRowMin = trow_min TRowMax = trow_max +TRowProd = trow_prod +row_prod = trow_prod TRowExpand = trow_expand TRowExpandAdd = trow_expand_add TRowExpandSub = trow_expand_sub @@ -144,6 +157,8 @@ TColSum = tcol_sum TColMin = tcol_min TColMax = tcol_max +TColProd = tcol_prod +col_prod = tcol_prod TColExpand = tcol_expand TColExpandAdd = tcol_expand_add TColExpandSub = tcol_expand_sub @@ -173,6 +188,7 @@ "TAnd", "TAndS", "TAxpy", + "TExpandS", "TColExpand", "TColExpandAdd", "TColExpandDiv", @@ -182,6 +198,7 @@ "TColExpandSub", "TColMax", "TColMin", + "TColProd", "TColSum", "TConcat", "TCmp", @@ -190,6 +207,7 @@ "TExp", "TExtract", "TGather", + "TGatherB", "TInsert", "TLRelu", "TLoad", @@ -212,6 +230,7 @@ "TMulS", "TOr", "TOrS", + "TPrelu", "TRecip", "TRelu", "TRowExpand", @@ -223,10 +242,13 @@ "TRowExpandSub", "TRowMax", "TRowMin", + "TRowProd", "TRowSum", "TRsqrt", "TScatter", "TSel", + "TSelS", + "TSels", "TShl", "TShlS", "TShr", @@ -239,11 +261,12 @@ "TTrans", "TXor", "TXorS", - "col_prod", "compare", + "col_prod", "concat", "extract", "full_mask_b32", + "gatherb", "insert", "load_tile", "matmul", @@ -256,6 +279,7 @@ "row_prod", "scatter", "select", + "selects", "store_tile", "tabs", "tadd", @@ -263,6 +287,7 @@ "tand", "tands", "taxpy", + "texpands", "tcol_expand", "tcol_expand_add", "tcol_expand_div", @@ -272,11 +297,13 @@ "tcol_expand_sub", "tcol_max", "tcol_min", + "tcol_prod", "tcol_sum", "tdiv", "tdivs", "texp", "tgather", + "tgatherb", "tlrelu", "tlog", "tmax", @@ -289,6 +316,7 @@ "tmuls", "tor_", "tors", + "tprelu", "trans", "trecip", "trelu", @@ -301,7 +329,11 @@ "trow_expand_sub", "trow_max", "trow_min", + "trow_prod", "trow_sum", + "tscatter", + "tsel", + "tsels", "trsqrt", "tshl", "tshls", diff --git a/ptodsl/lib/a5/tbinary.py b/ptodsl/lib/a5/tbinary.py index 133ed9e1..57e2ec25 100644 --- a/ptodsl/lib/a5/tbinary.py +++ b/ptodsl/lib/a5/tbinary.py @@ -18,6 +18,7 @@ const_i64, dtype_byte_width, flat_active_lanes, + mask_type, mask_for_chunk, matrix_active_lanes, normalize_vf_impl_kind, @@ -29,6 +30,7 @@ s, store_view, load_view, + const_scalar, vreg_type, ) @@ -467,6 +469,47 @@ def tmov( return out_view +def tprelu( + lhs_view, + rhs_view, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + lanes=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + zero_scalar = const_scalar(dtype, 0) + + def emit_prelu(vector_type, lhs_vec, rhs_vec, mask): + neg_vec = pto.vmul(vector_type, lhs_vec, rhs_vec, mask) + cmp_mask = pto.vcmps(mask_type(), lhs_vec, zero_scalar, mask, "gt") + return pto.vsel(vector_type, lhs_vec, neg_vec, cmp_mask) + + return _binary_tile_vop( + lhs_view, + rhs_view, + out_view, + dtype=dtype, + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + lanes=lanes, + base_addr=base_addr, + context="TPRELU", + micro_op=emit_prelu, + impl=impl, + allowed_dtypes={"f32", "f16"}, + ) + + def _binary_tile_vop( lhs_view, rhs_view, @@ -483,6 +526,7 @@ def _binary_tile_vop( context, micro_op, impl, + allowed_dtypes=None, ): rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( tile_shape=tile_shape, @@ -499,6 +543,7 @@ def _binary_tile_vop( dtype=dtype, shape=[rows, cols], context=context, + allowed=allowed_dtypes, ) lanes = resolve_lanes(dtype, lanes) element_count = rows * cols @@ -745,6 +790,7 @@ def _binary_2d_post_update( "tmov", "tmul", "tor_", + "tprelu", "tshl", "tshr", "tsub", diff --git a/ptodsl/lib/a5/tile_micro_coverage.py b/ptodsl/lib/a5/tile_micro_coverage.py index 9e249eeb..54d684c1 100644 --- a/ptodsl/lib/a5/tile_micro_coverage.py +++ b/ptodsl/lib/a5/tile_micro_coverage.py @@ -107,9 +107,9 @@ "note": "Static-shape row reduction via vcmax + point-store.", }, "row_prod": { - "status": "blocked", - "helper": None, - "note": "No row-product micro lowering is wired yet.", + "status": "implemented", + "helper": "trow_prod", + "note": "Static-shape row reduction via vmul + vintlv tree reduction + point-store.", }, "row_expand": { "status": "implemented", @@ -147,9 +147,9 @@ "note": "Static-shape TColReduceOps-style column reduction via vmax.", }, "col_prod": { - "status": "blocked", - "helper": None, - "note": "No column-product micro lowering is wired yet.", + "status": "implemented", + "helper": "tcol_prod", + "note": "Static-shape TColReduceOps-style column reduction via vmul.", }, "col_expand": { "status": "implemented", diff --git a/ptodsl/lib/a5/tile_op_kernels.py b/ptodsl/lib/a5/tile_op_kernels.py index d4da63e4..9885643e 100644 --- a/ptodsl/lib/a5/tile_op_kernels.py +++ b/ptodsl/lib/a5/tile_op_kernels.py @@ -410,6 +410,12 @@ def kernel(src: "src_ptr_t", idx: "idx_ptr_t", dst: "dst_ptr_t") -> None: ), "expected_tokens": ["pto.vcmax"], }, + "row_prod": { + "builder": lambda: _row_reduce_kernel( + "tile_op_row_prod", ops.trow_prod, dtype_name="float32" + ), + "expected_tokens": ["pto.vmul", "pto.vintlv"], + }, "row_expand": { "builder": lambda: _row_expand_kernel( "tile_op_row_expand", ops.trow_expand, dtype_name="float32" @@ -452,6 +458,12 @@ def kernel(src: "src_ptr_t", idx: "idx_ptr_t", dst: "dst_ptr_t") -> None: ), "expected_tokens": ["pto.vmax"], }, + "col_prod": { + "builder": lambda: _col_reduce_kernel( + "tile_op_col_prod", ops.tcol_prod, dtype_name="float32" + ), + "expected_tokens": ["pto.vmul"], + }, "col_expand": { "builder": lambda: _col_expand_kernel( "tile_op_col_expand", ops.tcol_expand, dtype_name="float32" diff --git a/ptodsl/lib/a5/tindex.py b/ptodsl/lib/a5/tindex.py new file mode 100644 index 00000000..0fcf0032 --- /dev/null +++ b/ptodsl/lib/a5/tindex.py @@ -0,0 +1,356 @@ +"""Implement indexed tile ops with PTO vector micro instructions. + +This file demonstrates how to rewrite gather/scatter-style tile helpers +directly in terms of PTO micro instructions such as `pto.vgather2`, +`pto.vgatherb`, and `pto.vscatter`. +""" + +from mlir.dialects import arith, pto +from mlir.ir import IndexType + +from ._common import ( + alloc_tile_buffer, + check_gather_operands, + check_gatherb_operands, + check_scatter_operands, + const_expr, + const_i64, + const_scalar, + dtype_byte_width, + full_mask, + mask_for_chunk, + matrix_active_lanes, + micro_lane_count, + ptr, + raw, + range_constexpr, + resolve_tile_spec, + s, + store_view, + load_view, + uint16_type, + uint32_type, + vreg_type, +) + + +def _active_lanes_value(active_lanes): + if isinstance(active_lanes, int): + return raw(s.const(active_lanes)) + return arith.IndexCastOp(IndexType.get(), raw(active_lanes)).result + + +def _zero_tile_buffer(out_ptr, *, dtype, rows, cols): + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + zero_vec = pto.vbr(vector_type, const_scalar(dtype, 0)) + for row in range_constexpr(rows): + for col in range_constexpr(0, cols, lanes): + count = min(cols - col, lanes) + offset = s.const(row * cols + col) + mask = mask_for_chunk(dtype, count) + pto.vsts(zero_vec, out_ptr, raw(offset), mask) + + +def tgather( + src_view, + indices_view, + out_view, + *, + dtype, + index_dtype=None, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + index_dtype = uint32_type() if index_dtype is None else index_dtype + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TGATHER", + ) + rows, cols = check_gather_operands( + src_view, + indices_view, + out_view, + dtype=dtype, + index_dtype=index_dtype, + shape=[rows, cols], + ) + src_bytes = rows * cols * dtype_byte_width(dtype) + idx_bytes = rows * cols * dtype_byte_width(index_dtype) + + src_addr = const_i64(base_addr) + idx_addr = const_i64(base_addr + src_bytes) + out_addr = const_i64(base_addr + src_bytes + idx_bytes) + + src_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=src_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + idx_tile = alloc_tile_buffer( + index_dtype, + [rows, cols], + space="VEC", + addr=idx_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + load_view(src_view, src_tile) + load_view(indices_view, idx_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + idx_ptr = pto.castptr(ptr(index_dtype, space="VEC"), idx_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + index_vector_type = vreg_type(micro_lane_count(index_dtype), index_dtype) + + for row in range_constexpr(rows): + row_base = row * cols + for col in range_constexpr(0, cols, lanes): + active = matrix_active_lanes(valid_row, valid_col, row, col, lanes) + offset = s.const(row_base + col) + mask = mask_for_chunk(dtype, active) + idx_vec = pto.vlds(index_vector_type, idx_ptr, raw(offset)) + out_vec = pto.vgather2( + vector_type, + src_ptr, + idx_vec, + _active_lanes_value(active), + ) + pto.vsts(out_vec, out_ptr, raw(offset), mask) + + store_view(out_tile, out_view) + return out_view + + +def tgatherb( + src_view, + indices_view, + out_view, + *, + dtype, + index_dtype=None, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + index_dtype = uint32_type() if index_dtype is None else index_dtype + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TGATHERB", + ) + rows, cols = check_gatherb_operands( + src_view, + indices_view, + out_view, + dtype=dtype, + index_dtype=index_dtype, + shape=[rows, cols], + ) + src_bytes = rows * cols * dtype_byte_width(dtype) + idx_bytes = rows * cols * dtype_byte_width(index_dtype) + + src_addr = const_i64(base_addr) + idx_addr = const_i64(base_addr + src_bytes) + out_addr = const_i64(base_addr + src_bytes + idx_bytes) + + src_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=src_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + idx_tile = alloc_tile_buffer( + index_dtype, + [rows, cols], + space="VEC", + addr=idx_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + load_view(src_view, src_tile) + load_view(indices_view, idx_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + idx_ptr = pto.castptr(ptr(index_dtype, space="VEC"), idx_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + lanes = micro_lane_count(dtype) + vector_type = vreg_type(lanes, dtype) + offset_vector_type = vreg_type(micro_lane_count(index_dtype), index_dtype) + static_repeat_times = (cols + lanes - 1) // lanes + + if const_expr(static_repeat_times > rows): + for row in range_constexpr(rows): + row_base = row * cols + for col in range_constexpr(0, cols, lanes): + active = matrix_active_lanes(valid_row, valid_col, row, col, lanes) + offset = s.const(row_base + col) + mask = mask_for_chunk(dtype, active) + idx_vec = pto.vlds(offset_vector_type, idx_ptr, raw(offset)) + out_vec = pto.vgatherb( + vector_type, + src_ptr, + idx_vec, + _active_lanes_value(active), + ) + pto.vsts(out_vec, out_ptr, raw(offset), mask) + else: + for col in range_constexpr(0, cols, lanes): + for row in range_constexpr(rows): + active = matrix_active_lanes(valid_row, valid_col, row, col, lanes) + offset = s.const(row * cols + col) + mask = mask_for_chunk(dtype, active) + idx_vec = pto.vlds(offset_vector_type, idx_ptr, raw(offset)) + out_vec = pto.vgatherb( + vector_type, + src_ptr, + idx_vec, + _active_lanes_value(active), + ) + pto.vsts(out_vec, out_ptr, raw(offset), mask) + + store_view(out_tile, out_view) + return out_view + + +def tscatter( + src_view, + indices_view, + out_view, + *, + dtype, + index_dtype=None, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + if index_dtype is None: + index_dtype = uint32_type() if dtype_byte_width(dtype) == 4 else uint16_type() + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TSCATTER", + ) + rows, cols, _, _ = check_scatter_operands( + src_view, + indices_view, + out_view, + dtype=dtype, + index_dtype=index_dtype, + shape=[rows, cols], + ) + src_bytes = rows * cols * dtype_byte_width(dtype) + idx_bytes = rows * cols * dtype_byte_width(index_dtype) + + src_addr = const_i64(base_addr) + idx_addr = const_i64(base_addr + src_bytes) + out_addr = const_i64(base_addr + src_bytes + idx_bytes) + + src_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=src_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + idx_tile = alloc_tile_buffer( + index_dtype, + [rows, cols], + space="VEC", + addr=idx_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + load_view(src_view, src_tile) + load_view(indices_view, idx_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + idx_ptr = pto.castptr(ptr(index_dtype, space="VEC"), idx_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + _zero_tile_buffer(out_ptr, dtype=dtype, rows=rows, cols=cols) + + batch = micro_lane_count(index_dtype) + value_vector_type = vreg_type(batch, dtype) + index_vector_type = vreg_type(batch, index_dtype) + load_dist = "UNPK_B8" if dtype_byte_width(dtype) == 1 else None + + for row in range_constexpr(rows): + row_base = row * cols + for col in range_constexpr(0, cols, batch): + active = matrix_active_lanes(valid_row, valid_col, row, col, batch) + offset = s.const(row_base + col) + idx_vec = pto.vlds(index_vector_type, idx_ptr, raw(offset)) + src_vec = pto.vlds( + value_vector_type, + src_ptr, + raw(offset), + dist=load_dist, + ) + pto.vscatter(src_vec, out_ptr, idx_vec, _active_lanes_value(active)) + + store_view(out_tile, out_view) + return out_view + + +__all__ = ["tgather", "tgatherb", "tscatter"] diff --git a/ptodsl/lib/a5/treduce.py b/ptodsl/lib/a5/treduce.py index b800a71c..607106b2 100644 --- a/ptodsl/lib/a5/treduce.py +++ b/ptodsl/lib/a5/treduce.py @@ -18,8 +18,8 @@ check_col_reduce_operands, check_row_reduce_operands, const_expr, - const_float, const_i64, + const_scalar, dtype_byte_width, full_mask, mask_for_chunk, @@ -125,6 +125,95 @@ def trow_min( ) +def trow_prod( + src_view, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TROWPROD", + ) + rows, cols = check_row_reduce_operands( + src_view, + out_view, + dtype=dtype, + shape=[rows, cols], + context="TROWPROD", + allowed={"f32", "f16", "i32", "i16"}, + ) + lanes = micro_lane_count(dtype) + width = dtype_byte_width(dtype) + vector_type = vreg_type(lanes, dtype) + loop_count = 0 + active_lanes = lanes + while active_lanes > 1: + active_lanes //= 2 + loop_count += 1 + + buf_bytes = rows * cols * width + src_addr = const_i64(base_addr) + out_addr = const_i64(base_addr + buf_bytes) + + src_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=src_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + valid_shape=[type_valid_shape[0], 1], + addr=out_addr, + valid_row=valid_row, + valid_col=1, + ) + load_view(src_view, src_tile) + + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + vector_mask = full_mask(dtype) + one_scalar = const_scalar(dtype, 1) + one_vec = pto.vbr(vector_type, one_scalar) + + for row in range_constexpr(rows): + accum = one_vec + for col in range_constexpr(0, cols, lanes): + active = matrix_active_lanes(valid_row, valid_col, row, col, lanes) + mask = mask_for_chunk(dtype, active) + offset = s.const(row * cols + col) + vec = pto.vlds(vector_type, src_ptr, raw(offset)) + masked_vec = pto.vsel(vector_type, vec, one_vec, mask) + accum = pto.vmul(vector_type, accum, masked_vec, vector_mask) + + for _ in range_constexpr(loop_count): + low, high = pto.vintlv(vector_type, vector_type, accum, one_vec) + accum = pto.vmul(vector_type, low, high, vector_mask) + + out_offset = s.const(row * cols) + store_mask = mask_for_chunk(dtype, matrix_active_lanes(valid_row, 1, row, 0, 1)) + pto.vsts(accum, out_ptr, raw(out_offset), store_mask, dist=onept_dist(dtype)) + + store_view(out_tile, out_view) + return out_view + + def tcol_sum( src_view, out_view, @@ -155,6 +244,37 @@ def tcol_sum( ) +def tcol_prod( + src_view, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, + impl=VF_IMPL_DEFAULT, +): + return _tcol_reduce( + src_view, + out_view, + dtype=dtype, + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + base_addr=base_addr, + context="TCOLPROD", + reduce_op=pto.vmul, + init_value=1, + impl=impl, + allowed_dtypes={"f32", "f16", "bf16", "i32", "u32", "i16", "u16"}, + ) + + def tcol_max( src_view, out_view, @@ -278,7 +398,7 @@ def _trow_reduce( src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) vector_mask = full_mask(dtype) - init_scalar = const_float(dtype, init_value) + init_scalar = const_scalar(dtype, init_value) neutral_vec = pto.vbr(vector_type, init_scalar) for row in range_constexpr(rows): @@ -314,6 +434,7 @@ def _tcol_reduce( reduce_op, init_value, impl, + allowed_dtypes=None, ): validation_context = "TCOLREDUCE" rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( @@ -325,7 +446,12 @@ def _tcol_reduce( context=validation_context, ) rows, cols = check_col_reduce_operands( - src_view, out_view, dtype=dtype, shape=[rows, cols], context=validation_context + src_view, + out_view, + dtype=dtype, + shape=[rows, cols], + context=validation_context, + allowed=allowed_dtypes, ) lanes = micro_lane_count(dtype) buf_bytes = rows * cols * dtype_byte_width(dtype) @@ -357,7 +483,7 @@ def _tcol_reduce( src_ptr = pto.castptr(ptr_type, src_addr) out_ptr = pto.castptr(ptr_type, out_addr) impl_kind = normalize_vf_impl_kind(impl) - init_scalar = const_float(dtype, init_value) + init_scalar = const_scalar(dtype, init_value) neutral_vec = pto.vbr(vector_type, init_scalar) vector_mask = full_mask(dtype) if const_expr(impl_kind == VF_IMPL_DEFAULT): @@ -478,8 +604,10 @@ def _tcol_reduce_post_update( "VF_IMPL_2D_POST_UPDATE", "tcol_max", "tcol_min", + "tcol_prod", "tcol_sum", "trow_max", "trow_min", + "trow_prod", "trow_sum", ] diff --git a/ptodsl/lib/a5/tscalar.py b/ptodsl/lib/a5/tscalar.py index b62122b5..d5dfa694 100644 --- a/ptodsl/lib/a5/tscalar.py +++ b/ptodsl/lib/a5/tscalar.py @@ -465,6 +465,62 @@ def taxpy( ) +def texpands( + scalar, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + lanes=None, + base_addr=0, +): + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TEXPANDS", + ) + rows, cols = check_tscalar_operands( + out_view, + out_view, + dtype=dtype, + shape=[rows, cols], + context="TEXPANDS", + ) + lanes = resolve_lanes(dtype, lanes) + element_count = rows * cols + out_addr = const_i64(base_addr) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=type_valid_shape, + valid_row=valid_row, + valid_col=valid_col, + ) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + vector_type = vreg_type(lanes, dtype) + scalar_value = raw(scalar) + if not hasattr(scalar_value, "type"): + scalar_value = const_scalar(dtype, scalar) + fill_vec = pto.vbr(vector_type, scalar_value) + + for offset in range_constexpr(0, element_count, lanes): + active = flat_active_lanes(valid_row, valid_col, offset, lanes) + mask = mask_for_chunk(dtype, active) + pto.vsts(fill_vec, out_ptr, raw(s.const(offset)), mask) + + store_view(out_tile, out_view) + return out_view + + def _scalar_tile_vop( src_view, scalar, @@ -783,6 +839,7 @@ def _scalar_2d_no_post_update( "VF_IMPL_2D_NO_POST_UPDATE", "VF_IMPL_2D_POST_UPDATE", "taxpy", + "texpands", "tadds", "tands", "tdivs", diff --git a/ptodsl/lib/a5/tselect.py b/ptodsl/lib/a5/tselect.py new file mode 100644 index 00000000..1d469eaf --- /dev/null +++ b/ptodsl/lib/a5/tselect.py @@ -0,0 +1,262 @@ +"""Implement tile select ops with PTO predicate and vector micro instructions.""" + +from mlir.dialects import pto +from ptodsl import language as dsl + +from ._common import ( + alloc_tile_buffer, + check_tsel_operands, + check_tsels_operands, + const_i64, + const_scalar, + dtype_byte_width, + full_mask, + mask_for_chunk, + mask_type, + ptr, + raw, + range_constexpr, + resolve_tile_spec, + extract_tensor_dtype_token, + s, + store_view, + load_view, + vreg_type, +) + + +def tsel( + mask_view, + lhs_view, + rhs_view, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TSEL", + ) + rows, cols = check_tsel_operands( + mask_view, + lhs_view, + rhs_view, + out_view, + dtype=dtype, + shape=[rows, cols], + ) + if ( + not isinstance(valid_row, int) + or not isinstance(valid_col, int) + or type_valid_shape != [valid_row, valid_col] + ): + raise ValueError("TSEL lowering currently requires static valid shape.") + + lanes = 64 + elem_bytes = dtype_byte_width(dtype) + data_bytes = rows * cols * elem_bytes + mask_bytes = rows * cols + + mask_addr = const_i64(base_addr) + lhs_addr = const_i64(base_addr + mask_bytes) + rhs_addr = const_i64(base_addr + mask_bytes + data_bytes) + out_addr = const_i64(base_addr + mask_bytes + data_bytes * 2) + mask_token = extract_tensor_dtype_token(mask_view) + mask_dtype = dsl.uint8 if mask_token == "u8" else dsl.int8 + + mask_tile = alloc_tile_buffer( + mask_dtype, + [rows, cols], + space="VEC", + addr=mask_addr, + valid_shape=[valid_row, valid_col], + ) + lhs_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=lhs_addr, + valid_shape=[valid_row, valid_col], + ) + rhs_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=rhs_addr, + valid_shape=[valid_row, valid_col], + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=[valid_row, valid_col], + ) + load_view(mask_view, mask_tile) + load_view(lhs_view, lhs_tile) + load_view(rhs_view, rhs_tile) + + mask_ptr = pto.castptr(ptr(mask_dtype, space="VEC"), mask_addr) + lhs_ptr = pto.castptr(ptr(dtype, space="VEC"), lhs_addr) + rhs_ptr = pto.castptr(ptr(dtype, space="VEC"), rhs_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + vector_type = vreg_type(lanes, dtype) + full_mask16 = pto.pset_b16(mask_type(), "PAT_ALL") + + repeat_times = (valid_col + lanes - 1) // lanes + paired_repeat_times = repeat_times // 2 + remain_repeat = repeat_times % 2 + repeat_idx_base = paired_repeat_times * 2 + + for row in range_constexpr(valid_row): + row_base = row * cols + mask_base = row * cols + for j in range_constexpr(paired_repeat_times): + repeat_idx = j * 2 + col_offset0 = repeat_idx * lanes + col_offset1 = col_offset0 + lanes + mask_offset = s.const(mask_base + repeat_idx * 8) + count0 = min(lanes, valid_col - col_offset0) + count1 = min(lanes, valid_col - col_offset1) + + raw_mask = pto.plds(mask_type(), mask_ptr, raw(mask_offset), dist="US") + low_mask, high_mask = pto.pintlv_b16( + mask_type(), mask_type(), raw_mask, full_mask16 + ) + + data_offset0 = s.const(row_base + col_offset0) + lhs0 = pto.vlds(vector_type, lhs_ptr, raw(data_offset0)) + rhs0 = pto.vlds(vector_type, rhs_ptr, raw(data_offset0)) + out0 = pto.vsel(vector_type, lhs0, rhs0, low_mask) + pto.vsts(out0, out_ptr, raw(data_offset0), mask_for_chunk(dtype, count0)) + + data_offset1 = s.const(row_base + col_offset1) + lhs1 = pto.vlds(vector_type, lhs_ptr, raw(data_offset1)) + rhs1 = pto.vlds(vector_type, rhs_ptr, raw(data_offset1)) + out1 = pto.vsel(vector_type, lhs1, rhs1, high_mask) + pto.vsts(out1, out_ptr, raw(data_offset1), mask_for_chunk(dtype, count1)) + + for j in range_constexpr(remain_repeat): + repeat_idx = repeat_idx_base + j + col_offset = repeat_idx * lanes + count = max(0, valid_col - col_offset) + mask_offset = s.const(mask_base + repeat_idx * 8) + raw_mask = pto.plds(mask_type(), mask_ptr, raw(mask_offset), dist="US") + unpacked_mask = pto.punpack(mask_type(), raw_mask, "LOWER") + data_offset = s.const(row_base + col_offset) + lhs = pto.vlds(vector_type, lhs_ptr, raw(data_offset)) + rhs = pto.vlds(vector_type, rhs_ptr, raw(data_offset)) + out = pto.vsel(vector_type, lhs, rhs, unpacked_mask) + pto.vsts(out, out_ptr, raw(data_offset), mask_for_chunk(dtype, count)) + + store_view(out_tile, out_view) + return out_view + + +def tsels( + mask_view, + src_view, + scalar, + out_view, + *, + dtype, + tile_shape=None, + shape=None, + valid_row=None, + valid_col=None, + valid_shape=None, + base_addr=0, +): + rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( + tile_shape=tile_shape, + shape=shape, + valid_row=valid_row, + valid_col=valid_col, + valid_shape=valid_shape, + context="TSELS", + ) + rows, cols = check_tsels_operands( + mask_view, + src_view, + out_view, + dtype=dtype, + shape=[rows, cols], + ) + if ( + not isinstance(valid_row, int) + or not isinstance(valid_col, int) + or type_valid_shape != [valid_row, valid_col] + ): + raise ValueError("TSELS lowering currently requires static valid shape.") + + lanes = 256 // dtype_byte_width(dtype) + total_elements = valid_row * valid_col + if total_elements % lanes != 0: + raise ValueError( + "TSELS lowering currently requires total valid elements divisible by vector width." + ) + + elem_bytes = dtype_byte_width(dtype) + buf_bytes = rows * cols * elem_bytes + mask_addr = const_i64(base_addr) + src_addr = const_i64(base_addr + buf_bytes) + out_addr = const_i64(base_addr + buf_bytes * 2) + + mask_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=mask_addr, + valid_shape=[valid_row, valid_col], + ) + src_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=src_addr, + valid_shape=[valid_row, valid_col], + ) + out_tile = alloc_tile_buffer( + dtype, + [rows, cols], + space="VEC", + addr=out_addr, + valid_shape=[valid_row, valid_col], + ) + load_view(mask_view, mask_tile) + load_view(src_view, src_tile) + + mask_ptr = pto.castptr(ptr(dtype, space="VEC"), mask_addr) + src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) + out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) + vector_type = vreg_type(lanes, dtype) + scalar_value = raw(scalar) + if not hasattr(scalar_value, "type"): + scalar_value = const_scalar(dtype, scalar) + scalar_vec = pto.vdup(vector_type, scalar_value, position="POS_LOWEST") + all_pred = full_mask(dtype) + zero = const_scalar(dtype, 0) + + for offset in range_constexpr(0, total_elements, lanes): + index = s.const(offset) + mask_vec = pto.vlds(vector_type, mask_ptr, raw(index)) + src_vec = pto.vlds(vector_type, src_ptr, raw(index)) + select_mask = pto.vcmps(mask_type(), mask_vec, zero, all_pred, "ne") + out_vec = pto.vsel(vector_type, src_vec, scalar_vec, select_mask) + pto.vsts(out_vec, out_ptr, raw(index), all_pred) + + store_view(out_tile, out_view) + return out_view + + +__all__ = ["tsel", "tsels"] diff --git a/ptodsl/lib/a5/tsort.py b/ptodsl/lib/a5/tsort.py index bc972bd8..0415ae44 100644 --- a/ptodsl/lib/a5/tsort.py +++ b/ptodsl/lib/a5/tsort.py @@ -1,18 +1,13 @@ -"""Implement gather/sort tile ops with PTO vector micro instructions.""" +"""Implement sort-style tile ops with PTO vector micro instructions.""" -from mlir.dialects import arith, pto -from mlir.ir import IndexType +from mlir.dialects import pto from ._common import ( alloc_tile_buffer, - check_gather_operands, check_mrgsort_operands, check_sort32_operands, const_i64, dtype_byte_width, - mask_for_chunk, - matrix_active_lanes, - micro_lane_count, ptr, raw, range_constexpr, @@ -21,101 +16,8 @@ store_view, load_view, uint32_type, - vreg_type, ) - - -def tgather( - src_view, - indices_view, - out_view, - *, - dtype, - index_dtype=None, - tile_shape=None, - shape=None, - valid_row=None, - valid_col=None, - valid_shape=None, - base_addr=0, -): - index_dtype = uint32_type() if index_dtype is None else index_dtype - rows, cols, valid_row, valid_col, type_valid_shape = resolve_tile_spec( - tile_shape=tile_shape, - shape=shape, - valid_row=valid_row, - valid_col=valid_col, - valid_shape=valid_shape, - context="TGATHER", - ) - rows, cols = check_gather_operands( - src_view, - indices_view, - out_view, - dtype=dtype, - index_dtype=index_dtype, - shape=[rows, cols], - ) - src_bytes = rows * cols * dtype_byte_width(dtype) - idx_bytes = rows * cols * dtype_byte_width(index_dtype) - - src_addr = const_i64(base_addr) - idx_addr = const_i64(base_addr + src_bytes) - out_addr = const_i64(base_addr + src_bytes + idx_bytes) - - src_tile = alloc_tile_buffer( - dtype, - [rows, cols], - space="VEC", - addr=src_addr, - valid_shape=type_valid_shape, - valid_row=valid_row, - valid_col=valid_col, - ) - idx_tile = alloc_tile_buffer( - index_dtype, - [rows, cols], - space="VEC", - addr=idx_addr, - valid_shape=type_valid_shape, - valid_row=valid_row, - valid_col=valid_col, - ) - out_tile = alloc_tile_buffer( - dtype, - [rows, cols], - space="VEC", - addr=out_addr, - valid_shape=type_valid_shape, - valid_row=valid_row, - valid_col=valid_col, - ) - load_view(src_view, src_tile) - load_view(indices_view, idx_tile) - - src_ptr = pto.castptr(ptr(dtype, space="VEC"), src_addr) - idx_ptr = pto.castptr(ptr(index_dtype, space="VEC"), idx_addr) - out_ptr = pto.castptr(ptr(dtype, space="VEC"), out_addr) - lanes = micro_lane_count(dtype) - vector_type = vreg_type(lanes, dtype) - index_vector_type = vreg_type(micro_lane_count(index_dtype), index_dtype) - - for row in range_constexpr(rows): - row_base = row * cols - for col in range_constexpr(0, cols, lanes): - active = matrix_active_lanes(valid_row, valid_col, row, col, lanes) - offset = s.const(row_base + col) - mask = mask_for_chunk(dtype, active) - idx_vec = pto.vlds(index_vector_type, idx_ptr, raw(offset)) - if isinstance(active, int): - active_lanes = raw(s.const(active)) - else: - active_lanes = arith.IndexCastOp(IndexType.get(), raw(active)).result - out_vec = pto.vgather2(vector_type, src_ptr, idx_vec, active_lanes) - pto.vsts(out_vec, out_ptr, raw(offset), mask) - - store_view(out_tile, out_view) - return out_view +from .tindex import tgather def tmrgsort( diff --git a/tests/regression/test_a5_lib_regression.py b/tests/regression/test_a5_lib_regression.py index 3a65a79e..ca59d6aa 100644 --- a/tests/regression/test_a5_lib_regression.py +++ b/tests/regression/test_a5_lib_regression.py @@ -54,11 +54,17 @@ def _slice_tensor(source, *, offsets, sizes, dtype): def test_a5_split_modules_are_publicly_exposed(): assert a5.tbinary.tadd is a5.tadd + assert a5.tbinary.tprelu is a5.tprelu assert a5.tscalar.tadds is a5.tadds + assert a5.tscalar.texpands is a5.texpands assert a5.tunary.trsqrt is a5.trsqrt assert a5.texpand.trow_expand is a5.trow_expand assert a5.treduce.trow_sum is a5.trow_sum - assert a5.tsort.tgather is a5.tgather + assert a5.tindex.tgather is a5.tgather + assert a5.tindex.tgatherb is a5.tgatherb + assert a5.tindex.tscatter is a5.tscatter + assert a5.tselect.tsel is a5.tsel + assert a5.tselect.tsels is a5.tsels def test_public_pto_ptr_supports_explicit_memory_spaces(): @@ -305,6 +311,79 @@ def a5_tgather_dynamic_valid( assert "arith.index_cast" in text +def test_a5_tgatherb_emits_byte_gather_micro_opcodes(): + def uint32(): + return IntegerType.get_unsigned(32) + + def meta_data(): + return { + "ptr_src": pto.PtrType(pto.float32), + "ptr_idx": pto.PtrType(uint32()), + } + + @to_ir_module(meta_data=meta_data) + def a5_tgatherb(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None: + src_view = _make_tensor(src, shape=[8, 64], dtype=pto.float32) + idx_view = _make_tensor(idx, shape=[8, 64], dtype=uint32()) + dst_view = _make_tensor(dst, shape=[8, 64], dtype=pto.float32) + with pto.vector_section(): + a5.tgatherb( + _slice_tensor( + src_view, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32 + ), + _slice_tensor(idx_view, offsets=[0, 0], sizes=[8, 64], dtype=uint32()), + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32 + ), + dtype=pto.float32, + index_dtype=uint32(), + shape=[8, 64], + ) + + text = str(a5_tgatherb) + + assert "func.func @a5_tgatherb" in text + assert "pto.vgatherb" in text + assert "pto.tgatherb" not in text + + +def test_a5_tscatter_emits_zero_fill_then_vscatter(): + def uint32(): + return IntegerType.get_unsigned(32) + + def meta_data(): + return { + "ptr_src": pto.PtrType(pto.float32), + "ptr_idx": pto.PtrType(uint32()), + } + + @to_ir_module(meta_data=meta_data) + def a5_tscatter(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None: + src_view = _make_tensor(src, shape=[8, 64], dtype=pto.float32) + idx_view = _make_tensor(idx, shape=[8, 64], dtype=uint32()) + dst_view = _make_tensor(dst, shape=[8, 64], dtype=pto.float32) + with pto.vector_section(): + a5.tscatter( + _slice_tensor( + src_view, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32 + ), + _slice_tensor(idx_view, offsets=[0, 0], sizes=[8, 64], dtype=uint32()), + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32 + ), + dtype=pto.float32, + index_dtype=uint32(), + shape=[8, 64], + ) + + text = str(a5_tscatter) + + assert "func.func @a5_tscatter" in text + assert "pto.vbr" in text + assert "pto.vscatter" in text + assert "pto.tscatter" not in text + + def test_a5_trow_expand_emits_broadcast_micro_ops(): def meta_data(): return { @@ -523,6 +602,121 @@ def a5_taxpy(src: "ptr_t", dst: "ptr_t") -> None: assert "pto.taxpy" not in text +def test_a5_texpands_emits_vbr_and_vsts(): + def meta_data(): + return {"ptr_t": pto.PtrType(pto.float32)} + + @to_ir_module(meta_data=meta_data) + def a5_texpands(dst: "ptr_t") -> None: + dst_view = _make_tensor(dst, shape=[8, 64], dtype=pto.float32) + scalar = arith.ConstantOp(pto.float32, 1.5).result + with pto.vector_section(): + a5.texpands( + scalar, + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32 + ), + dtype=pto.float32, + shape=[8, 64], + ) + + text = str(a5_texpands) + + assert "func.func @a5_texpands" in text + assert "pto.vbr" in text + assert "pto.vsts" in text + assert "pto.texpands" not in text + + +def test_a5_tprelu_emits_vcmps_vmul_and_vsel(): + def meta_data(): + return {"ptr_t": pto.PtrType(pto.float32)} + + @to_ir_module(meta_data=meta_data) + def a5_tprelu(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t") -> None: + lhs = _make_tensor(src0, shape=[8, 64], dtype=pto.float32) + rhs = _make_tensor(src1, shape=[8, 64], dtype=pto.float32) + out = _make_tensor(dst, shape=[8, 64], dtype=pto.float32) + with pto.vector_section(): + a5.tprelu( + _slice_tensor(lhs, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32), + _slice_tensor(rhs, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32), + _slice_tensor(out, offsets=[0, 0], sizes=[8, 64], dtype=pto.float32), + dtype=pto.float32, + shape=[8, 64], + ) + + text = str(a5_tprelu) + + assert "func.func @a5_tprelu" in text + assert "pto.vcmps" in text + assert "pto.vmul" in text + assert "pto.vsel" in text + assert "pto.tprelu" not in text + + +def test_a5_tsel_emits_plds_pintlv_and_vsel(): + def meta_data(): + return { + "ptr_mask": pto.PtrType(pto.int8), + "ptr_data": pto.PtrType(pto.float32), + } + + @to_ir_module(meta_data=meta_data) + def a5_tsel(mask: "ptr_mask", src0: "ptr_data", src1: "ptr_data", dst: "ptr_data") -> None: + mask_view = _make_tensor(mask, shape=[1, 128], dtype=pto.int8) + lhs = _make_tensor(src0, shape=[1, 128], dtype=pto.float32) + rhs = _make_tensor(src1, shape=[1, 128], dtype=pto.float32) + out = _make_tensor(dst, shape=[1, 128], dtype=pto.float32) + with pto.vector_section(): + a5.tsel( + _slice_tensor(mask_view, offsets=[0, 0], sizes=[1, 128], dtype=pto.int8), + _slice_tensor(lhs, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), + _slice_tensor(rhs, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), + _slice_tensor(out, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), + dtype=pto.float32, + shape=[1, 128], + ) + + text = str(a5_tsel) + + assert "func.func @a5_tsel" in text + assert "pto.pset_b16" in text + assert "pto.plds" in text + assert "pto.pintlv_b16" in text + assert "pto.vsel" in text + assert "pto.tsel" not in text + + +def test_a5_tsels_emits_vcmps_vdup_and_vsel(): + def meta_data(): + return {"ptr_t": pto.PtrType(pto.float32)} + + @to_ir_module(meta_data=meta_data) + def a5_tsels(mask: "ptr_t", src: "ptr_t", dst: "ptr_t") -> None: + mask_view = _make_tensor(mask, shape=[1, 64], dtype=pto.float32) + src_view = _make_tensor(src, shape=[1, 64], dtype=pto.float32) + dst_view = _make_tensor(dst, shape=[1, 64], dtype=pto.float32) + scalar = arith.ConstantOp(pto.float32, 3.0).result + with pto.vector_section(): + a5.tsels( + _slice_tensor(mask_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), + _slice_tensor(src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), + scalar, + _slice_tensor(dst_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), + dtype=pto.float32, + shape=[1, 64], + ) + + text = str(a5_tsels) + + assert "func.func @a5_tsels" in text + assert "pto.vcmps" in text + assert "pto.vdup" in text + assert "pto.vsel" in text + assert "pto.tsels" not in text + + @pytest.mark.parametrize( ("helper_name", "micro_op", "tile_op"), [ @@ -583,15 +777,16 @@ def a5_expand_helper(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t") -> None: @pytest.mark.parametrize( - ("helper_name", "reduce_op", "combine_op", "tile_op"), + ("helper_name", "reduce_op", "combine_op", "extra_token", "tile_op"), [ - ("trow_sum", "pto.vcadd", "pto.vadd", "pto.trowsum"), - ("trow_max", "pto.vcmax", "pto.vmax", "pto.trowmax"), - ("trow_min", "pto.vcmin", "pto.vmin", "pto.trowmin"), + ("trow_sum", "pto.vcadd", "pto.vadd", None, "pto.trowsum"), + ("trow_max", "pto.vcmax", "pto.vmax", None, "pto.trowmax"), + ("trow_min", "pto.vcmin", "pto.vmin", None, "pto.trowmin"), + ("trow_prod", "pto.vmul", "pto.vmul", "pto.vintlv", "pto.trowprod"), ], ) def test_a5_trow_reduce_emits_reduction_micro_ops( - helper_name, reduce_op, combine_op, tile_op + helper_name, reduce_op, combine_op, extra_token, tile_op ): def meta_data(): return { @@ -621,6 +816,8 @@ def a5_trow_reduce(src: "ptr_t", dst: "ptr_t") -> None: assert reduce_op in text assert combine_op in text + if extra_token is not None: + assert extra_token in text assert 'dist = "ONEPT_B32"' in text assert tile_op not in text @@ -631,6 +828,7 @@ def a5_trow_reduce(src: "ptr_t", dst: "ptr_t") -> None: ("tcol_sum", "pto.vadd", "pto.tcolsum", a5.VF_IMPL_1D_POST_UPDATE), ("tcol_max", "pto.vmax", "pto.tcolmax", a5.VF_IMPL_1D_NO_POST_UPDATE), ("tcol_min", "pto.vmin", "pto.tcolmin", a5.VF_IMPL_1D_POST_UPDATE), + ("tcol_prod", "pto.vmul", "pto.tcolprod", a5.VF_IMPL_1D_POST_UPDATE), ], ) def test_a5_tcol_reduce_emits_template_lowering(helper_name, reduce_op, tile_op, impl): From 439001b9ef26edfcf2489c52356734cb6b671c5f Mon Sep 17 00:00:00 2001 From: RuoyuZhou Date: Wed, 1 Apr 2026 10:45:05 +0800 Subject: [PATCH 2/2] Fix A5 micro coverage formatting --- ptodsl/lib/a5/__init__.py | 13 +++++++++++- ptodsl/lib/a5/_common.py | 3 ++- ptodsl/lib/a5/generated/tile_ops/abs.pto | 1 - ptodsl/lib/a5/generated/tile_ops/add.pto | 1 - .../lib/a5/generated/tile_ops/col_expand.pto | 1 - ptodsl/lib/a5/generated/tile_ops/col_max.pto | 1 - ptodsl/lib/a5/generated/tile_ops/col_min.pto | 1 - ptodsl/lib/a5/generated/tile_ops/col_prod.pto | 1 - ptodsl/lib/a5/generated/tile_ops/col_sum.pto | 1 - ptodsl/lib/a5/generated/tile_ops/div.pto | 1 - ptodsl/lib/a5/generated/tile_ops/exp.pto | 1 - ptodsl/lib/a5/generated/tile_ops/gather.pto | 1 - ptodsl/lib/a5/generated/tile_ops/log.pto | 1 - ptodsl/lib/a5/generated/tile_ops/mov.pto | 1 - ptodsl/lib/a5/generated/tile_ops/mrgsort.pto | 1 - ptodsl/lib/a5/generated/tile_ops/mul.pto | 1 - ptodsl/lib/a5/generated/tile_ops/or_.pto | 1 - .../lib/a5/generated/tile_ops/reciprocal.pto | 1 - ptodsl/lib/a5/generated/tile_ops/relu.pto | 1 - .../lib/a5/generated/tile_ops/row_expand.pto | 1 - .../a5/generated/tile_ops/row_expand_div.pto | 1 - .../a5/generated/tile_ops/row_expand_mul.pto | 1 - .../a5/generated/tile_ops/row_expand_sub.pto | 1 - ptodsl/lib/a5/generated/tile_ops/row_max.pto | 1 - ptodsl/lib/a5/generated/tile_ops/row_min.pto | 1 - ptodsl/lib/a5/generated/tile_ops/row_prod.pto | 1 - ptodsl/lib/a5/generated/tile_ops/row_sum.pto | 1 - ptodsl/lib/a5/generated/tile_ops/rsqrt.pto | 1 - ptodsl/lib/a5/generated/tile_ops/sort32.pto | 1 - ptodsl/lib/a5/generated/tile_ops/sqrt.pto | 1 - ptodsl/lib/a5/generated/tile_ops/sub.pto | 1 - tests/regression/test_a5_lib_regression.py | 20 ++++++++++++++----- 32 files changed, 29 insertions(+), 36 deletions(-) diff --git a/ptodsl/lib/a5/__init__.py b/ptodsl/lib/a5/__init__.py index 31baafd6..b66277d1 100644 --- a/ptodsl/lib/a5/__init__.py +++ b/ptodsl/lib/a5/__init__.py @@ -1,4 +1,15 @@ -from . import native, ops, tbinary, texpand, tindex, tselect, treduce, tscalar, tsort, tunary +from . import ( + native, + ops, + tbinary, + texpand, + tindex, + tselect, + treduce, + tscalar, + tsort, + tunary, +) from .a5_header_coverage import A5_HEADER_COVERAGE, a5_header_coverage_markdown from .kernels import ( HIVM_LLVM_KERNELS, diff --git a/ptodsl/lib/a5/_common.py b/ptodsl/lib/a5/_common.py index f78684fb..ac2b87e3 100644 --- a/ptodsl/lib/a5/_common.py +++ b/ptodsl/lib/a5/_common.py @@ -515,7 +515,8 @@ def check_tbinop_operands( rows, cols = require_static_matrix_shape(shape, context=context) require_supported_dtype( dtype, - allowed=allowed or {"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + allowed=allowed + or {"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, message=f"Fix: {context} has invalid data type.", ) for view, label in ((lhs_view, "src0"), (rhs_view, "src1"), (out_view, "dst")): diff --git a/ptodsl/lib/a5/generated/tile_ops/abs.pto b/ptodsl/lib/a5/generated/tile_ops/abs.pto index aea13f8e..ef021091 100644 --- a/ptodsl/lib/a5/generated/tile_ops/abs.pto +++ b/ptodsl/lib/a5/generated/tile_ops/abs.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/add.pto b/ptodsl/lib/a5/generated/tile_ops/add.pto index b824e9fa..5b4d9507 100644 --- a/ptodsl/lib/a5/generated/tile_ops/add.pto +++ b/ptodsl/lib/a5/generated/tile_ops/add.pto @@ -88,4 +88,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/col_expand.pto b/ptodsl/lib/a5/generated/tile_ops/col_expand.pto index 346efa3c..7b987b15 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_expand.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_expand.pto @@ -52,4 +52,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/col_max.pto b/ptodsl/lib/a5/generated/tile_ops/col_max.pto index 773749a7..55459b11 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_max.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_max.pto @@ -91,4 +91,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/col_min.pto b/ptodsl/lib/a5/generated/tile_ops/col_min.pto index a393f251..665e0a4f 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_min.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_min.pto @@ -91,4 +91,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/col_prod.pto b/ptodsl/lib/a5/generated/tile_ops/col_prod.pto index 752f6983..cc039554 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_prod.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_prod.pto @@ -91,4 +91,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/col_sum.pto b/ptodsl/lib/a5/generated/tile_ops/col_sum.pto index 4142875f..78b09c22 100644 --- a/ptodsl/lib/a5/generated/tile_ops/col_sum.pto +++ b/ptodsl/lib/a5/generated/tile_ops/col_sum.pto @@ -91,4 +91,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/div.pto b/ptodsl/lib/a5/generated/tile_ops/div.pto index ede78981..ac2fb1e4 100644 --- a/ptodsl/lib/a5/generated/tile_ops/div.pto +++ b/ptodsl/lib/a5/generated/tile_ops/div.pto @@ -88,4 +88,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/exp.pto b/ptodsl/lib/a5/generated/tile_ops/exp.pto index 58911188..0f47cc9a 100644 --- a/ptodsl/lib/a5/generated/tile_ops/exp.pto +++ b/ptodsl/lib/a5/generated/tile_ops/exp.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/gather.pto b/ptodsl/lib/a5/generated/tile_ops/gather.pto index 2c226fe7..01045b3a 100644 --- a/ptodsl/lib/a5/generated/tile_ops/gather.pto +++ b/ptodsl/lib/a5/generated/tile_ops/gather.pto @@ -53,4 +53,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/log.pto b/ptodsl/lib/a5/generated/tile_ops/log.pto index 08d0143c..5b4f62bd 100644 --- a/ptodsl/lib/a5/generated/tile_ops/log.pto +++ b/ptodsl/lib/a5/generated/tile_ops/log.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/mov.pto b/ptodsl/lib/a5/generated/tile_ops/mov.pto index 737af953..ade28826 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mov.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mov.pto @@ -65,4 +65,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto b/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto index bdb0a9d9..0f0afe8e 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mrgsort.pto @@ -42,4 +42,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/mul.pto b/ptodsl/lib/a5/generated/tile_ops/mul.pto index 926adc5c..86af874f 100644 --- a/ptodsl/lib/a5/generated/tile_ops/mul.pto +++ b/ptodsl/lib/a5/generated/tile_ops/mul.pto @@ -88,4 +88,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/or_.pto b/ptodsl/lib/a5/generated/tile_ops/or_.pto index ea25374b..66add0fe 100644 --- a/ptodsl/lib/a5/generated/tile_ops/or_.pto +++ b/ptodsl/lib/a5/generated/tile_ops/or_.pto @@ -88,4 +88,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto b/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto index d5ae3f9e..3ed78ddc 100644 --- a/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto +++ b/ptodsl/lib/a5/generated/tile_ops/reciprocal.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/relu.pto b/ptodsl/lib/a5/generated/tile_ops/relu.pto index 703535ef..001aa127 100644 --- a/ptodsl/lib/a5/generated/tile_ops/relu.pto +++ b/ptodsl/lib/a5/generated/tile_ops/relu.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand.pto index 6254d70a..e3fd1be3 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand.pto @@ -89,4 +89,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto index e1d37373..7f1574e2 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_div.pto @@ -119,4 +119,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto index 34fc0119..8f9631d5 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_mul.pto @@ -119,4 +119,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto b/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto index 9922961c..910a0c80 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_expand_sub.pto @@ -119,4 +119,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_max.pto b/ptodsl/lib/a5/generated/tile_ops/row_max.pto index c01e53dc..8bc4bafd 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_max.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_max.pto @@ -116,4 +116,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_min.pto b/ptodsl/lib/a5/generated/tile_ops/row_min.pto index bde5c99e..394d2a95 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_min.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_min.pto @@ -116,4 +116,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_prod.pto b/ptodsl/lib/a5/generated/tile_ops/row_prod.pto index 87a02975..e7b37530 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_prod.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_prod.pto @@ -204,4 +204,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/row_sum.pto b/ptodsl/lib/a5/generated/tile_ops/row_sum.pto index c369a883..61c20c50 100644 --- a/ptodsl/lib/a5/generated/tile_ops/row_sum.pto +++ b/ptodsl/lib/a5/generated/tile_ops/row_sum.pto @@ -116,4 +116,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto b/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto index 92da13ae..17752171 100644 --- a/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto +++ b/ptodsl/lib/a5/generated/tile_ops/rsqrt.pto @@ -81,4 +81,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/sort32.pto b/ptodsl/lib/a5/generated/tile_ops/sort32.pto index 48e38209..c739860e 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sort32.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sort32.pto @@ -55,4 +55,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/sqrt.pto b/ptodsl/lib/a5/generated/tile_ops/sqrt.pto index ffc2b89b..4de655e0 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sqrt.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sqrt.pto @@ -73,4 +73,3 @@ module { return } } - diff --git a/ptodsl/lib/a5/generated/tile_ops/sub.pto b/ptodsl/lib/a5/generated/tile_ops/sub.pto index 1caefd1a..e81c72ae 100644 --- a/ptodsl/lib/a5/generated/tile_ops/sub.pto +++ b/ptodsl/lib/a5/generated/tile_ops/sub.pto @@ -88,4 +88,3 @@ module { return } } - diff --git a/tests/regression/test_a5_lib_regression.py b/tests/regression/test_a5_lib_regression.py index ca59d6aa..c53a3249 100644 --- a/tests/regression/test_a5_lib_regression.py +++ b/tests/regression/test_a5_lib_regression.py @@ -663,14 +663,18 @@ def meta_data(): } @to_ir_module(meta_data=meta_data) - def a5_tsel(mask: "ptr_mask", src0: "ptr_data", src1: "ptr_data", dst: "ptr_data") -> None: + def a5_tsel( + mask: "ptr_mask", src0: "ptr_data", src1: "ptr_data", dst: "ptr_data" + ) -> None: mask_view = _make_tensor(mask, shape=[1, 128], dtype=pto.int8) lhs = _make_tensor(src0, shape=[1, 128], dtype=pto.float32) rhs = _make_tensor(src1, shape=[1, 128], dtype=pto.float32) out = _make_tensor(dst, shape=[1, 128], dtype=pto.float32) with pto.vector_section(): a5.tsel( - _slice_tensor(mask_view, offsets=[0, 0], sizes=[1, 128], dtype=pto.int8), + _slice_tensor( + mask_view, offsets=[0, 0], sizes=[1, 128], dtype=pto.int8 + ), _slice_tensor(lhs, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), _slice_tensor(rhs, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), _slice_tensor(out, offsets=[0, 0], sizes=[1, 128], dtype=pto.float32), @@ -700,10 +704,16 @@ def a5_tsels(mask: "ptr_t", src: "ptr_t", dst: "ptr_t") -> None: scalar = arith.ConstantOp(pto.float32, 3.0).result with pto.vector_section(): a5.tsels( - _slice_tensor(mask_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), - _slice_tensor(src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), + _slice_tensor( + mask_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32 + ), + _slice_tensor( + src_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32 + ), scalar, - _slice_tensor(dst_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32), + _slice_tensor( + dst_view, offsets=[0, 0], sizes=[1, 64], dtype=pto.float32 + ), dtype=pto.float32, shape=[1, 64], )