diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 491c52ef..8e5b3105 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,27 +1,11 @@
-name: CI
+name: Full Regression CI
on:
push:
branches: [main]
- pull_request:
- branches: [main]
workflow_dispatch:
jobs:
- pre-commit:
- name: pre-commit
- runs-on: ubuntu-24.04
- steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v5
- with:
- python-version: "3.11"
- - name: Run pre-commit checks
- run: |
- python -m pip install --upgrade pip
- python -m pip install pre-commit
- pre-commit run --all-files
-
test:
name: test (${{ matrix.arch }}, ${{ matrix.install-mode }})
strategy:
@@ -58,6 +42,12 @@ jobs:
- uses: actions/checkout@v4
+ - uses: actions/checkout@v4
+ with:
+ repository: ${{ env.RELEASE_REPO }}
+ ref: ${{ env.RELEASE_TAG }}
+ path: ptoas-src
+
- name: Install Python packages
run: |
pip install --no-cache-dir torch==2.9.0 --index-url https://download.pytorch.org/whl/cpu
@@ -84,7 +74,7 @@ jobs:
- name: Clone pto-isa headers
run: |
- git clone https://gitcode.com/cann/pto-isa.git /sources/pto-isa
+ git clone https://github.com/PTO-ISA/pto-isa.git /sources/pto-isa
cd /sources/pto-isa && git checkout ${PTOISA_COMMIT}
- name: Install ptodsl (${{ matrix.install-mode }})
@@ -95,8 +85,10 @@ jobs:
pip install -e .
fi
- - name: Run frontend tests
- run: pytest -v ./tests/frontend
+ - name: Run host API and regression tests
+ env:
+ PTOAS_VPTO_MANIFEST: ${{ github.workspace }}/ptoas-src/docs/vpto-manifest.json
+ run: pytest -v ./tests/api ./tests/frontend ./tests/regression
- name: Run NPU build tests
run: |
diff --git a/.github/workflows/fast-ci.yml b/.github/workflows/fast-ci.yml
new file mode 100644
index 00000000..05253d20
--- /dev/null
+++ b/.github/workflows/fast-ci.yml
@@ -0,0 +1,74 @@
+name: Fast CI
+
+on:
+ pull_request:
+ branches: [main]
+ push:
+ branches: [main]
+ workflow_dispatch:
+
+env:
+ RELEASE_REPO: zhangstevenunity/PTOAS
+ RELEASE_VER: 0.9
+ RELEASE_TAG: v0.9
+
+jobs:
+ pre-commit:
+ name: pre-commit
+ runs-on: ubuntu-24.04
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+ - name: Run pre-commit checks
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install pre-commit
+ pre-commit run --all-files
+
+ host-tests:
+ name: host tests (${{ matrix.install-mode }})
+ runs-on: ubuntu-24.04
+ strategy:
+ fail-fast: false
+ matrix:
+ install-mode: [standard, editable]
+ steps:
+ - uses: actions/checkout@v4
+
+ - uses: actions/checkout@v4
+ with:
+ repository: ${{ env.RELEASE_REPO }}
+ ref: ${{ env.RELEASE_TAG }}
+ path: ptoas-src
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+
+ - name: Install Python packages
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install pytest
+
+ - name: Install ptoas wheel
+ run: |
+ WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_x86_64.whl
+ wget https://github.com/${RELEASE_REPO}/releases/download/${RELEASE_TAG}/${WHEEL_NAME}
+ python -m pip install ./${WHEEL_NAME}
+ python -c "import mlir.ir; from mlir.dialects import pto"
+
+ - name: Install ptodsl (${{ matrix.install-mode }})
+ run: |
+ if [ "${{ matrix.install-mode }}" = "standard" ]; then
+ python -m pip install .
+ else
+ python -m pip install -e .
+ fi
+
+ - name: Run host API and regression tests
+ env:
+ PTOAS_VPTO_MANIFEST: ${{ github.workspace }}/ptoas-src/docs/vpto-manifest.json
+ run: |
+ pytest -v ./tests/api ./tests/frontend ./tests/regression
diff --git a/README.md b/README.md
index 84ab41cc..eb474af1 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# PTO-DSL
-Pythonic interface and JIT compiler for [PTO-ISA](https://gitcode.com/cann/pto-isa)
+Pythonic interface and JIT compiler for [PTO-ISA](https://github.com/PTO-ISA/pto-isa)
PTO-DSL provides a programming abstraction similar to [cuTile](https://docs.nvidia.com/cuda/cutile-python/), but native to [NPU](https://www.hiascend.com/).
@@ -37,6 +37,37 @@ pip install -e .
See [examples](./examples) and [tests](./tests)
+Preferred frontend style keeps the existing low-level ops available, but adds a thinner
+object-centric layer for common tensor and tile flows:
+
+```python
+from ptodsl import pto, tile
+
+
+def vec_add(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t", rows: "index_t", cols: "index_t"):
+ x = pto.make_tensor(src0, shape=[rows, cols], dtype=pto.float32)
+ y = pto.make_tensor(src1, shape=[rows, cols], dtype=pto.float32)
+ z = pto.make_tensor(dst, shape=[rows, cols], dtype=pto.float32)
+
+ x_tile = x.slice([0, 0], [32, 32])
+ y_tile = y.slice([0, 0], [32, 32])
+ z_tile = z.slice([0, 0], [32, 32])
+
+ with pto.vector_section():
+ tile_buf = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC")
+ lhs = tile_buf.alloc()
+ rhs = tile_buf.alloc()
+ out = tile_buf.alloc()
+ lhs.load_from(x_tile)
+ rhs.load_from(y_tile)
+ tile.add(lhs, rhs, out)
+ out.store_to(z_tile)
+```
+
+The lower-level `PtrType`, `TensorType`, `SubTensorType`, `TileBufType`, `as_tensor`,
+`slice_view`, and `alloc_tile` APIs remain supported for cases where explicit control is
+preferred.
+
## Contribute
See [contribute_guide.md](./contribute_guide.md)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 7eda8bc5..7c7dea8a 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -61,3 +61,9 @@ RUN ptoas ./tmatmulk.pto -o ./tmatmulk.cpp
RUN python ./abs.py > ./abs.pto
RUN ptoas --enable-insert-sync ./abs.pto -o ./abs.cpp
+# certain operations need latest isa header, not CANN 8.5.0 default
+# header on 2026/02/14
+ARG PTOISA_COMMIT=672ee54cb8905bb9f9abbe80ec26ed2054b7a0cc
+WORKDIR /sources
+RUN git clone https://github.com/PTO-ISA/pto-isa.git \
+ && cd pto-isa && git checkout $PTOISA_COMMIT
diff --git a/examples/aot/matmul_mxfp8/matmul_mxfp8_builder.py b/examples/aot/matmul_mxfp8/matmul_mxfp8_builder.py
new file mode 100644
index 00000000..5cf713fa
--- /dev/null
+++ b/examples/aot/matmul_mxfp8/matmul_mxfp8_builder.py
@@ -0,0 +1,84 @@
+from ptodsl import to_ir_module
+import ptodsl.language as pto
+
+
+def build(M=16, K=64, N=32, lhs_variant="e5m2", rhs_variant="e5m2"):
+ def meta_data():
+ mx = pto.make_mxfp8(lhs=lhs_variant, rhs=rhs_variant)
+ scale_k = mx.scale_k(K)
+
+ ptr_lhs = pto.PtrType(mx.lhs)
+ ptr_rhs = pto.PtrType(mx.rhs)
+ ptr_scale = pto.PtrType(mx.scale)
+ ptr_bias = pto.PtrType(mx.acc)
+
+ lhs_tensor = pto.TensorType(rank=2, dtype=mx.lhs)
+ rhs_tensor = pto.TensorType(rank=2, dtype=mx.rhs)
+ lhs_scale_tensor = pto.TensorType(rank=2, dtype=mx.scale)
+ rhs_scale_tensor = pto.TensorType(rank=2, dtype=mx.scale)
+ bias_tensor = pto.TensorType(rank=2, dtype=mx.acc)
+
+ lhs_tile_view = pto.SubTensorType(shape=[M, K], dtype=mx.lhs)
+ rhs_tile_view = pto.SubTensorType(shape=[K, N], dtype=mx.rhs)
+ lhs_scale_tile_view = pto.SubTensorType(shape=[M, scale_k], dtype=mx.scale)
+ rhs_scale_tile_view = pto.SubTensorType(shape=[scale_k, N], dtype=mx.scale)
+ bias_tile_view = pto.SubTensorType(shape=[1, N], dtype=mx.acc)
+
+ lhs_tile = pto.TileBufType(shape=[M, K], dtype=mx.lhs, memory_space="LEFT")
+ rhs_tile = pto.TileBufType(shape=[K, N], dtype=mx.rhs, memory_space="RIGHT")
+ lhs_scale_tile = pto.LeftScaleTileBufType(shape=[M, scale_k], dtype=mx.scale)
+ rhs_scale_tile = pto.RightScaleTileBufType(shape=[scale_k, N], dtype=mx.scale)
+ bias_tile = pto.TileBufType(shape=[1, N], dtype=mx.acc, memory_space="BIAS")
+ acc_tile = pto.TileBufType(shape=[M, N], dtype=mx.acc, memory_space="ACC")
+
+ return locals()
+
+ const = pto.const
+
+ @to_ir_module(meta_data=meta_data)
+ def matmul_mxfp8(
+ a_ptr: "ptr_lhs",
+ a_scale_ptr: "ptr_scale",
+ b_ptr: "ptr_rhs",
+ b_scale_ptr: "ptr_scale",
+ bias_ptr: "ptr_bias",
+ ) -> None:
+ c0 = const(0)
+ c1 = const(1)
+ cM = const(M)
+ cK = const(K)
+ cN = const(N)
+ cScaleK = const(scale_k)
+
+ tv_a = pto.as_tensor(lhs_tensor, ptr=a_ptr, shape=[cM, cK], strides=[cK, c1])
+ tv_b = pto.as_tensor(rhs_tensor, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1])
+ tv_scale_a = pto.as_tensor(lhs_scale_tensor, ptr=a_scale_ptr, shape=[cM, cScaleK], strides=[cScaleK, c1])
+ tv_scale_b = pto.as_tensor(rhs_scale_tensor, ptr=b_scale_ptr, shape=[cScaleK, cN], strides=[cN, c1])
+ tv_bias = pto.as_tensor(bias_tensor, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1])
+
+ sv_a = pto.slice_view(lhs_tile_view, source=tv_a, offsets=[c0, c0], sizes=[cM, cK])
+ sv_b = pto.slice_view(rhs_tile_view, source=tv_b, offsets=[c0, c0], sizes=[cK, cN])
+ sv_scale_a = pto.slice_view(lhs_scale_tile_view, source=tv_scale_a, offsets=[c0, c0], sizes=[cM, cScaleK])
+ sv_scale_b = pto.slice_view(rhs_scale_tile_view, source=tv_scale_b, offsets=[c0, c0], sizes=[cScaleK, cN])
+ sv_bias = pto.slice_view(bias_tile_view, source=tv_bias, offsets=[c0, c0], sizes=[c1, cN])
+
+ with pto.cube_section():
+ a_tile = pto.alloc_tile(lhs_tile)
+ b_tile = pto.alloc_tile(rhs_tile)
+ a_scale_tile = pto.alloc_tile(lhs_scale_tile)
+ b_scale_tile = pto.alloc_tile(rhs_scale_tile)
+ bias_tile_buf = pto.alloc_tile(bias_tile)
+ acc_tile_buf = pto.alloc_tile(acc_tile)
+
+ pto.load(sv_a, a_tile)
+ pto.load(sv_b, b_tile)
+ pto.load(sv_scale_a, a_scale_tile)
+ pto.load(sv_scale_b, b_scale_tile)
+ pto.load(sv_bias, bias_tile_buf)
+ pto.matmul_mx_bias(a_tile, a_scale_tile, b_tile, b_scale_tile, bias_tile_buf, acc_tile_buf)
+
+ return matmul_mxfp8
+
+
+if __name__ == "__main__":
+ print(build())
diff --git a/examples/aot/matmul_mxfp8/mxfp8_ppt_example.py b/examples/aot/matmul_mxfp8/mxfp8_ppt_example.py
new file mode 100644
index 00000000..b7988c9e
--- /dev/null
+++ b/examples/aot/matmul_mxfp8/mxfp8_ppt_example.py
@@ -0,0 +1,80 @@
+from ptodsl import to_ir_module
+import ptodsl.language as pto
+
+
+M, K, N = 16, 64, 32
+
+
+def meta_data():
+ # 1) 选择 MXFP8 组合。默认是 lhs=e5m2, rhs=e5m2, scale=e8m0, acc=f32。
+ mx = pto.make_mxfp8(lhs="e5m2", rhs="e5m2")
+ scale_k = mx.scale_k(K) # MXFP8 的 scale 张量沿 K 维按 32:1 压缩
+
+ # 2) 全局输入指针类型
+ a_ptr = pto.PtrType(mx.lhs)
+ b_ptr = pto.PtrType(mx.rhs)
+ scale_ptr = pto.PtrType(mx.scale)
+
+ # 3) TensorView 类型
+ a_tensor = pto.TensorType(rank=2, dtype=mx.lhs)
+ b_tensor = pto.TensorType(rank=2, dtype=mx.rhs)
+ scale_a_tensor = pto.TensorType(rank=2, dtype=mx.scale)
+ scale_b_tensor = pto.TensorType(rank=2, dtype=mx.scale)
+
+ # 4) TileView / TileBuf 类型
+ a_view = pto.SubTensorType(shape=[M, K], dtype=mx.lhs)
+ b_view = pto.SubTensorType(shape=[K, N], dtype=mx.rhs)
+ scale_a_view = pto.SubTensorType(shape=[M, scale_k], dtype=mx.scale)
+ scale_b_view = pto.SubTensorType(shape=[scale_k, N], dtype=mx.scale)
+
+ a_tile = pto.TileBufType(shape=[M, K], dtype=mx.lhs, memory_space="LEFT")
+ b_tile = pto.TileBufType(shape=[K, N], dtype=mx.rhs, memory_space="RIGHT")
+ scale_a_tile = pto.LeftScaleTileBufType(shape=[M, scale_k], dtype=mx.scale)
+ scale_b_tile = pto.RightScaleTileBufType(shape=[scale_k, N], dtype=mx.scale)
+ acc_tile = pto.TileBufType(shape=[M, N], dtype=mx.acc, memory_space="ACC")
+
+ return locals()
+
+
+@to_ir_module(meta_data=meta_data)
+def matmul_mxfp8_core(
+ a: "a_ptr",
+ scale_a: "scale_ptr",
+ b: "b_ptr",
+ scale_b: "scale_ptr",
+) -> None:
+ c0 = pto.const(0)
+ c1 = pto.const(1)
+ cM = pto.const(M)
+ cK = pto.const(K)
+ cN = pto.const(N)
+ cScaleK = pto.const(scale_k)
+
+ tv_a = pto.as_tensor(a_tensor, ptr=a, shape=[cM, cK], strides=[cK, c1])
+ tv_b = pto.as_tensor(b_tensor, ptr=b, shape=[cK, cN], strides=[cN, c1])
+ tv_scale_a = pto.as_tensor(scale_a_tensor, ptr=scale_a, shape=[cM, cScaleK], strides=[cScaleK, c1])
+ tv_scale_b = pto.as_tensor(scale_b_tensor, ptr=scale_b, shape=[cScaleK, cN], strides=[cN, c1])
+
+ sv_a = pto.slice_view(a_view, source=tv_a, offsets=[c0, c0], sizes=[cM, cK])
+ sv_b = pto.slice_view(b_view, source=tv_b, offsets=[c0, c0], sizes=[cK, cN])
+ sv_scale_a = pto.slice_view(scale_a_view, source=tv_scale_a, offsets=[c0, c0], sizes=[cM, cScaleK])
+ sv_scale_b = pto.slice_view(scale_b_view, source=tv_scale_b, offsets=[c0, c0], sizes=[cScaleK, cN])
+
+ with pto.cube_section():
+ ta = pto.alloc_tile(a_tile)
+ tb = pto.alloc_tile(b_tile)
+ tsa = pto.alloc_tile(scale_a_tile)
+ tsb = pto.alloc_tile(scale_b_tile)
+ tc = pto.alloc_tile(acc_tile)
+
+ pto.load(sv_a, ta)
+ pto.load(sv_b, tb)
+ pto.load(sv_scale_a, tsa)
+ pto.load(sv_scale_b, tsb)
+
+ # 核心调用:MXFP8 data tile + scale tile -> Acc tile
+ pto.matmul_mx(ta, tsa, tb, tsb, tc)
+
+
+if __name__ == "__main__":
+ print(matmul_mxfp8_core)
diff --git a/examples/aot/template_arithmetic/constexpr_tile_builder.py b/examples/aot/template_arithmetic/constexpr_tile_builder.py
new file mode 100644
index 00000000..a6d1d167
--- /dev/null
+++ b/examples/aot/template_arithmetic/constexpr_tile_builder.py
@@ -0,0 +1,36 @@
+from ptodsl import Constexpr, const_expr, pto, range_constexpr, to_ir_module
+from ptodsl import scalar as s
+
+
+const = s.const
+
+
+def meta_data(TILE_K, UNROLL=2):
+ dtype = pto.float32
+ return {
+ "index_dtype": pto.int32,
+ "tile_type": pto.TileBufType(
+ shape=[1, TILE_K // 2],
+ valid_shape=[1, TILE_K // 2],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ }
+
+
+@to_ir_module(meta_data=meta_data)
+def constexpr_tile_kernel(
+ n: "index_dtype",
+ TILE_K: Constexpr[int],
+ UNROLL: Constexpr[int] = 2,
+) -> None:
+ with pto.vector_section():
+ if const_expr(TILE_K % 128 == 0):
+ for _ in range_constexpr(UNROLL):
+ pto.alloc_tile(tile_type)
+ else:
+ pto.alloc_tile(tile_type)
+
+
+if __name__ == "__main__":
+ print(constexpr_tile_kernel(TILE_K=128, UNROLL=3))
diff --git a/examples/ppt/mixed_pto_vector_slide.md b/examples/ppt/mixed_pto_vector_slide.md
new file mode 100644
index 00000000..803180dd
--- /dev/null
+++ b/examples/ppt/mixed_pto_vector_slide.md
@@ -0,0 +1,77 @@
+# PTO `t*` + `v*` 混合示例
+
+## 一页版表达
+
+```text
+Outer PTO tile flow:
+ make_tensor_view -> partition_view -> tload -> [vector inner loop] -> tstore
+
+Inner vector loop:
+ vlds -> vlds -> vadd -> vsts
+```
+
+## PPT 版伪 IR
+
+```mlir
+module {
+ func.func @vec_add_mixed(
+ %a: !pto.ptr,
+ %b: !pto.ptr,
+ %c: !pto.ptr) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c1024 = arith.constant 1024 : index
+
+ // 1) 先用 PTO tile op 选出一个 32x32 工作块
+ %A = pto.make_tensor_view %a, shape = [%c32, %c32], strides = [%c32, %c1]
+ : !pto.tensor_view
+ %B = pto.make_tensor_view %b, shape = [%c32, %c32], strides = [%c32, %c1]
+ : !pto.tensor_view
+ %C = pto.make_tensor_view %c, shape = [%c32, %c32], strides = [%c32, %c1]
+ : !pto.tensor_view
+
+ %tileA = pto.partition_view %A, offsets = [%c0, %c0], sizes = [%c32, %c32]
+ : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+ %tileB = pto.partition_view %B, offsets = [%c0, %c0], sizes = [%c32, %c32]
+ : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+ %tileC = pto.partition_view %C, offsets = [%c0, %c0], sizes = [%c32, %c32]
+ : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32>
+
+ // 统一记号:!tile 表示 vec-local 32x32 f32 tile_buf
+ %bufA = pto.alloc_tile : !pto.tile_buf
+ %bufB = pto.alloc_tile : !pto.tile_buf
+ %bufC = pto.alloc_tile : !pto.tile_buf
+
+ // 2) tile 级搬运:GM -> local tile
+ pto.tload ins(%tileA : !pto.partition_tensor_view<32x32xf32>)
+ outs(%bufA : !pto.tile_buf)
+ pto.tload ins(%tileB : !pto.partition_tensor_view<32x32xf32>)
+ outs(%bufB : !pto.tile_buf)
+
+ // 3) vector 级计算:在 local tile 内部按 64-lane 分块
+ %ptrA = pto.tile_buf_addr %bufA : !pto.tile_buf<...> -> !llvm.ptr<6>
+ %ptrB = pto.tile_buf_addr %bufB : !pto.tile_buf<...> -> !llvm.ptr<6>
+ %ptrC = pto.tile_buf_addr %bufC : !pto.tile_buf<...> -> !llvm.ptr<6>
+
+ scf.for %i = %c0 to %c1024 step %c64 {
+ %va = pto.vlds %ptrA[%i] : !llvm.ptr<6> -> !pto.vreg<64xf32>
+ %vb = pto.vlds %ptrB[%i] : !llvm.ptr<6> -> !pto.vreg<64xf32>
+ %vc = pto.vadd %va, %vb
+ : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>
+ pto.vsts %vc, %ptrC[%i] : !pto.vreg<64xf32>, !llvm.ptr<6>
+ }
+
+ // 4) tile 级写回:local tile -> GM
+ pto.tstore ins(%bufC : !pto.tile_buf)
+ outs(%tileC : !pto.partition_tensor_view<32x32xf32>)
+ return
+ }
+}
+```
+
+## 讲解时只强调这两层
+
+- `pto.t*` 负责选 tile 和搬 tile:`make_tensor_view -> partition_view -> tload -> tstore`
+- `pto.v*` 负责在 tile 内做向量计算:`vlds -> vadd -> vsts`
diff --git a/ptodsl/__init__.py b/ptodsl/__init__.py
index 55333e65..c9fff529 100644
--- a/ptodsl/__init__.py
+++ b/ptodsl/__init__.py
@@ -1,6 +1,19 @@
-from . import pto, scalar, tile
+from . import micro, pto, scalar, tile
+from ._constexpr import Constexpr, const_expr, range_constexpr
from .bench import do_bench
from .compiler.ir import to_ir_module
from .compiler.jit import JitWrapper, jit
-__all__ = ["JitWrapper", "do_bench", "jit", "pto", "scalar", "tile", "to_ir_module"]
+__all__ = [
+ "Constexpr",
+ "JitWrapper",
+ "const_expr",
+ "do_bench",
+ "jit",
+ "micro",
+ "pto",
+ "range_constexpr",
+ "scalar",
+ "tile",
+ "to_ir_module",
+]
diff --git a/ptodsl/_constexpr.py b/ptodsl/_constexpr.py
new file mode 100644
index 00000000..295e5c75
--- /dev/null
+++ b/ptodsl/_constexpr.py
@@ -0,0 +1,284 @@
+import hashlib
+import inspect
+from dataclasses import dataclass
+from typing import Generic, TypeVar, get_origin
+
+
+T = TypeVar("T")
+_MISSING = object()
+
+
+class Constexpr(Generic[T]):
+ """Marker annotation for compile-time-only parameters."""
+
+
+@dataclass(frozen=True)
+class SignatureAnalysis:
+ signature: inspect.Signature
+ constexpr_params: tuple[inspect.Parameter, ...]
+ runtime_params: tuple[inspect.Parameter, ...]
+
+ @property
+ def has_constexpr_params(self):
+ return bool(self.constexpr_params)
+
+
+@dataclass(frozen=True)
+class BoundArguments:
+ all_arguments: dict[str, object]
+ constexpr_arguments: dict[str, object]
+ runtime_arguments: tuple[object, ...]
+ missing_runtime: tuple[str, ...]
+
+
+def analyze_signature(fn_or_signature):
+ signature = (
+ fn_or_signature
+ if isinstance(fn_or_signature, inspect.Signature)
+ else inspect.signature(fn_or_signature)
+ )
+ constexpr_params = []
+ runtime_params = []
+ for param in signature.parameters.values():
+ if is_constexpr_annotation(param.annotation):
+ constexpr_params.append(param)
+ else:
+ runtime_params.append(param)
+ return SignatureAnalysis(
+ signature=signature,
+ constexpr_params=tuple(constexpr_params),
+ runtime_params=tuple(runtime_params),
+ )
+
+
+def is_constexpr_annotation(annotation):
+ if isinstance(annotation, str):
+ return annotation.startswith("Constexpr[") and annotation.endswith("]")
+ return get_origin(annotation) is Constexpr
+
+
+def unwrap_constexpr_annotation(annotation):
+ if isinstance(annotation, str):
+ if is_constexpr_annotation(annotation):
+ return annotation[len("Constexpr[") : -1]
+ return annotation
+ if get_origin(annotation) is Constexpr:
+ args = getattr(annotation, "__args__", ())
+ if args:
+ return args[0]
+ return inspect._empty
+ return annotation
+
+
+def bind_constexpr_arguments(analysis, *args, **kwargs):
+ bound = analysis.signature.bind_partial(*args, **kwargs)
+ provided_runtime = [
+ name for name in bound.arguments if name in {p.name for p in analysis.runtime_params}
+ ]
+ if provided_runtime:
+ joined = ", ".join(provided_runtime)
+ raise TypeError(
+ "Specialization only accepts constexpr arguments; "
+ f"got runtime arguments: {joined}."
+ )
+
+ values = {}
+ missing = []
+ for param in analysis.constexpr_params:
+ if param.name in bound.arguments:
+ values[param.name] = bound.arguments[param.name]
+ elif param.default is not inspect._empty:
+ values[param.name] = param.default
+ else:
+ missing.append(param.name)
+
+ if missing:
+ joined = ", ".join(missing)
+ raise TypeError(f"Missing required constexpr arguments: {joined}.")
+
+ return values
+
+
+def bind_kernel_arguments(analysis, *args, **kwargs):
+ bound = analysis.signature.bind_partial(*args, **kwargs)
+ values = dict(bound.arguments)
+
+ missing_runtime = []
+ runtime_arguments = []
+ constexpr_arguments = {}
+
+ for param in analysis.signature.parameters.values():
+ if param.name not in values and param.default is not inspect._empty:
+ values[param.name] = param.default
+
+ for param in analysis.constexpr_params:
+ if param.name not in values:
+ raise TypeError(f"Missing required constexpr argument '{param.name}'.")
+ constexpr_arguments[param.name] = values[param.name]
+
+ for param in analysis.runtime_params:
+ if param.name in values:
+ runtime_arguments.append(values[param.name])
+ else:
+ runtime_arguments.append(_MISSING)
+ missing_runtime.append(param.name)
+
+ return BoundArguments(
+ all_arguments=values,
+ constexpr_arguments=constexpr_arguments,
+ runtime_arguments=tuple(runtime_arguments),
+ missing_runtime=tuple(missing_runtime),
+ )
+
+
+def normalize_constexpr_bindings(bindings):
+ return tuple((name, _normalize_constexpr_value(value)) for name, value in bindings.items())
+
+
+def specialization_suffix(bindings):
+ if not bindings:
+ return "default"
+ digest = hashlib.sha256(repr(normalize_constexpr_bindings(bindings)).encode("utf-8"))
+ return digest.hexdigest()[:16]
+
+
+def meta_kwargs_for(meta_fn, constexpr_bindings):
+ signature = inspect.signature(meta_fn)
+ if not signature.parameters:
+ return None
+
+ kwargs = {}
+ missing = []
+ for param in signature.parameters.values():
+ if param.kind in (
+ inspect.Parameter.VAR_POSITIONAL,
+ inspect.Parameter.VAR_KEYWORD,
+ ):
+ continue
+ if param.kind is inspect.Parameter.POSITIONAL_ONLY:
+ raise TypeError("`meta_data` does not support positional-only parameters.")
+ if param.name in constexpr_bindings:
+ kwargs[param.name] = constexpr_bindings[param.name]
+ elif param.default is inspect._empty:
+ missing.append(param.name)
+
+ if missing:
+ joined = ", ".join(missing)
+ raise TypeError(
+ "`meta_data` requires unresolved constexpr parameters: "
+ f"{joined}."
+ )
+ return kwargs
+
+
+def is_dynamic_value(value):
+ if isinstance(value, (list, tuple)):
+ return any(is_dynamic_value(item) for item in value)
+ if isinstance(value, dict):
+ return any(
+ is_dynamic_value(key) or is_dynamic_value(item)
+ for key, item in value.items()
+ )
+ cls = value.__class__
+ module = getattr(cls, "__module__", "")
+ name = getattr(cls, "__name__", "")
+ if module.startswith("ptodsl") and name == "Value":
+ return True
+ if module.startswith("mlir.") and (
+ name in {"Value", "OpResult", "BlockArgument"} or "Value" in name
+ ):
+ return True
+ return False
+
+
+def require_constexpr_value(value, *, context):
+ if is_dynamic_value(value):
+ raise TypeError(f"`{context}` requires compile-time values, got dynamic PTODSL/MLIR values.")
+ return value
+
+
+def require_static_int(value, *, context):
+ require_constexpr_value(value, context=context)
+ if isinstance(value, bool) or not isinstance(value, int):
+ raise TypeError(f"`{context}` requires a Python int, got {type(value)!r}.")
+ return value
+
+
+def require_static_int_sequence(values, *, context):
+ return [require_static_int(value, context=context) for value in values]
+
+
+def const_expr(value):
+ require_constexpr_value(value, context="const_expr")
+ return bool(value)
+
+
+def range_constexpr(start, stop=None, step=1):
+ if stop is None:
+ return range(
+ require_static_int(start, context="range_constexpr"),
+ )
+ return range(
+ require_static_int(start, context="range_constexpr"),
+ require_static_int(stop, context="range_constexpr"),
+ require_static_int(step, context="range_constexpr"),
+ )
+
+
+def _normalize_constexpr_value(value):
+ require_constexpr_value(value, context="constexpr specialization")
+ if value is None:
+ return ("none", None)
+ if isinstance(value, bool):
+ return ("bool", value)
+ if isinstance(value, int):
+ return ("int", value)
+ if isinstance(value, float):
+ return ("float", value)
+ if isinstance(value, str):
+ return ("str", value)
+ if isinstance(value, type):
+ return ("type", value.__module__, value.__qualname__)
+ if isinstance(value, tuple):
+ return ("tuple", tuple(_normalize_constexpr_value(item) for item in value))
+ if isinstance(value, list):
+ return ("list", tuple(_normalize_constexpr_value(item) for item in value))
+ if isinstance(value, dict):
+ items = [
+ (_normalize_constexpr_value(key), _normalize_constexpr_value(item))
+ for key, item in value.items()
+ ]
+ items.sort(key=repr)
+ return ("dict", tuple(items))
+
+ cls = value.__class__
+ module = getattr(cls, "__module__", "")
+ if module.startswith("mlir.") or module.startswith("ptodsl."):
+ return ("object", module, cls.__qualname__, str(value))
+
+ raise TypeError(
+ "Unsupported constexpr value type for specialization key: "
+ f"{type(value)!r}."
+ )
+
+
+__all__ = [
+ "BoundArguments",
+ "Constexpr",
+ "SignatureAnalysis",
+ "_MISSING",
+ "analyze_signature",
+ "bind_constexpr_arguments",
+ "bind_kernel_arguments",
+ "const_expr",
+ "is_constexpr_annotation",
+ "is_dynamic_value",
+ "meta_kwargs_for",
+ "normalize_constexpr_bindings",
+ "range_constexpr",
+ "require_constexpr_value",
+ "require_static_int",
+ "require_static_int_sequence",
+ "specialization_suffix",
+ "unwrap_constexpr_annotation",
+]
diff --git a/ptodsl/api/__init__.py b/ptodsl/api/__init__.py
index ca7e01f9..c2a06941 100644
--- a/ptodsl/api/__init__.py
+++ b/ptodsl/api/__init__.py
@@ -1,3 +1,3 @@
-from . import pto, scalar, tile
+from . import micro, pto, scalar, tile
-__all__ = ["pto", "scalar", "tile"]
+__all__ = ["micro", "pto", "scalar", "tile"]
diff --git a/ptodsl/api/_micro_registry.py b/ptodsl/api/_micro_registry.py
new file mode 100644
index 00000000..c9c5c7cc
--- /dev/null
+++ b/ptodsl/api/_micro_registry.py
@@ -0,0 +1,115 @@
+MICRO_OPS = [
+ "castptr",
+ "addptr",
+ "set_flag",
+ "wait_flag",
+ "barrier",
+ "get_buf",
+ "rls_buf",
+ "set_loop2_stride_outtoub",
+ "set_loop1_stride_outtoub",
+ "set_loop_size_outtoub",
+ "set_loop2_stride_ubtoout",
+ "set_loop1_stride_ubtoout",
+ "set_loop_size_ubtoout",
+ "copy_gm_to_ubuf",
+ "copy_ubuf_to_ubuf",
+ "copy_ubuf_to_gm",
+ "vlds",
+ "vldas",
+ "vldus",
+ "plds",
+ "pld",
+ "pldi",
+ "vldx2",
+ "vgather2",
+ "vgatherb",
+ "vgather2_bc",
+ "vsld",
+ "vsldb",
+ "vbr",
+ "vdup",
+ "pset_b8",
+ "pset_b16",
+ "pset_b32",
+ "pge_b8",
+ "pge_b16",
+ "pge_b32",
+ "ppack",
+ "punpack",
+ "vabs",
+ "vexp",
+ "vln",
+ "vsqrt",
+ "vrec",
+ "vrelu",
+ "vnot",
+ "vcadd",
+ "vcmax",
+ "vcmin",
+ "vbcnt",
+ "vcls",
+ "vadd",
+ "vsub",
+ "vmul",
+ "vdiv",
+ "vmax",
+ "vmin",
+ "vand",
+ "vor",
+ "vxor",
+ "vshl",
+ "vshr",
+ "vmuls",
+ "vadds",
+ "vmaxs",
+ "vmins",
+ "vlrelu",
+ "vshls",
+ "vshrs",
+ "vaddc",
+ "vsubc",
+ "vaddcs",
+ "vsubcs",
+ "vsel",
+ "vselr",
+ "vselrv2",
+ "vcmp",
+ "vcmps",
+ "pnot",
+ "psel",
+ "pdintlv_b8",
+ "pintlv_b16",
+ "vintlv",
+ "vdintlv",
+ "vintlvv2",
+ "vdintlvv2",
+ "vtrc",
+ "vcvt",
+ "vci",
+ "vbitsort",
+ "vmrgsort4",
+ "vmull",
+ "vmula",
+ "vsts",
+ "vscatter",
+ "psts",
+ "pst",
+ "psti",
+ "vsst",
+ "vstx2",
+ "vsstb",
+ "vsta",
+ "vstas",
+ "vstar",
+ "pstu",
+ "vstu",
+ "vstus",
+ "vstur",
+ "vlds_post",
+ "uvld",
+ "plt_b8",
+ "plt_b16",
+ "plt_b32",
+ "vsts_post",
+]
diff --git a/ptodsl/api/control_flow.py b/ptodsl/api/control_flow.py
index 457fade8..7d239a33 100644
--- a/ptodsl/api/control_flow.py
+++ b/ptodsl/api/control_flow.py
@@ -3,6 +3,7 @@
from mlir.dialects import scf
from mlir.ir import InsertionPoint
+from .._constexpr import const_expr, range_constexpr
from .scalar import Value, _unwrap
@@ -48,5 +49,4 @@ def cond(condition, then_builder, else_builder):
scf.YieldOp([])
return op
-
-__all__ = ["cond", "range", "if_context"]
+__all__ = ["cond", "const_expr", "range", "range_constexpr", "if_context"]
diff --git a/ptodsl/api/micro.py b/ptodsl/api/micro.py
new file mode 100644
index 00000000..d6d48d2f
--- /dev/null
+++ b/ptodsl/api/micro.py
@@ -0,0 +1,50 @@
+from mlir.dialects import pto as _pto
+
+from ._micro_registry import MICRO_OPS
+from .scalar import Value
+
+
+def _unwrap(value):
+ if isinstance(value, Value):
+ return _unwrap(value.raw)
+ if hasattr(value, "raw"):
+ return _unwrap(value.raw)
+ if isinstance(value, list):
+ return [_unwrap(v) for v in value]
+ if isinstance(value, tuple):
+ return tuple(_unwrap(v) for v in value)
+ if isinstance(value, dict):
+ return {k: _unwrap(v) for k, v in value.items()}
+ return value
+
+
+def _micro_barrier(op, *, loc=None, ip=None):
+ if isinstance(op, str):
+ normalized = op.strip().upper()
+ if normalized.startswith("PIPE"):
+ op = _pto.PipeAttr.get(getattr(_pto.PIPE, normalized))
+ return _pto.barrier(_unwrap(op), loc=loc, ip=ip)
+
+
+def _make_wrapper(name):
+ if name == "barrier":
+ return _micro_barrier
+
+ op = getattr(_pto, name, None)
+ if op is None:
+ raise AttributeError(f"mlir.dialects.pto has no builder for '{name}'")
+
+ def _wrapper(*args, **kwargs):
+ return op(*(_unwrap(arg) for arg in args), **{k: _unwrap(v) for k, v in kwargs.items()})
+
+ _wrapper.__name__ = name
+ _wrapper.__qualname__ = name
+ _wrapper.__doc__ = f"Emit `pto.{name}`."
+ return _wrapper
+
+
+for _mnemonic in MICRO_OPS:
+ globals()[_mnemonic] = _make_wrapper(_mnemonic)
+
+
+__all__ = ["MICRO_OPS", *MICRO_OPS]
diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py
index f2e2d0ac..ece7ae29 100644
--- a/ptodsl/api/pto.py
+++ b/ptodsl/api/pto.py
@@ -1,6 +1,12 @@
-from .control_flow import cond, range, if_context
+from .._constexpr import Constexpr
+from ._micro_registry import MICRO_OPS
+from .control_flow import cond, const_expr, range, range_constexpr, if_context
+from . import micro as _micro
from .scalar import Value, wrap_value
from .pto_general import (
+ TensorView,
+ TileBuffer,
+ TileBufferSpec,
alloc_tile,
as_tensor,
cube_section,
@@ -9,31 +15,45 @@
get_subblock_idx,
get_subblock_num,
load,
+ make_tensor,
+ make_tile_buffer,
+ ptr,
slice_view,
store,
vector_section,
print,
)
-from .synchronization import barrier, record_event, record_wait_pair, wait_event
+from .synchronization import barrier, barrier_sync, record_event, record_wait_pair, wait_event
from .type_def import (
+ AlignType,
+ MaskType,
PtrType,
SubTensorType,
TensorType,
TileBufConfig,
TileBufType,
+ VRegType,
__getattr__,
)
__all__ = [
"Value",
+ "Constexpr",
+ "TensorView",
+ "TileBuffer",
+ "TileBufferSpec",
"wrap_value",
"bool",
"float16",
"float32",
"int16",
"int32",
+ "ptr",
"PtrType",
+ "VRegType",
+ "MaskType",
+ "AlignType",
"TensorType",
"SubTensorType",
"TileBufConfig",
@@ -42,19 +62,28 @@
"get_subblock_idx",
"get_subblock_num",
"get_block_num",
+ "make_tensor",
+ "make_tile_buffer",
"as_tensor",
"slice_view",
"vector_section",
"cube_section",
"range",
+ "const_expr",
+ "range_constexpr",
"if_context",
"cond",
"alloc_tile",
"load",
"store",
"print",
+ "barrier_sync",
"record_event",
"wait_event",
"record_wait_pair",
- "barrier",
+ *MICRO_OPS,
]
+
+
+for _name in MICRO_OPS:
+ globals()[_name] = getattr(_micro, _name)
diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py
index c8f649ea..e67d2f8e 100644
--- a/ptodsl/api/pto_general.py
+++ b/ptodsl/api/pto_general.py
@@ -3,6 +3,8 @@
from mlir.dialects import pto as _pto
from mlir.ir import InsertionPoint
+from .._constexpr import require_static_int_sequence
+from . import scalar, type_def
from .scalar import Value, _unwrap
@@ -30,6 +32,181 @@ def _resolve_layout_attr(layout):
return layout
+def _is_static_int(value):
+ return isinstance(value, int) and not isinstance(value, bool)
+
+
+def _shape_factor(value):
+ if _is_static_int(value):
+ return value
+ if isinstance(value, Value):
+ return value
+ if hasattr(value, "raw"):
+ return value.raw
+ return value
+
+
+def _mul_shape_values(lhs, rhs):
+ lhs = _shape_factor(lhs)
+ rhs = _shape_factor(rhs)
+ if _is_static_int(lhs) and _is_static_int(rhs):
+ return lhs * rhs
+ if _is_static_int(lhs) and lhs == 1:
+ return rhs
+ if _is_static_int(rhs) and rhs == 1:
+ return lhs
+ lhs_value = scalar.const(lhs) if _is_static_int(lhs) else scalar.wrap_value(lhs)
+ rhs_value = scalar.const(rhs) if _is_static_int(rhs) else scalar.wrap_value(rhs)
+ return lhs_value * rhs_value
+
+
+def _as_index_operand(value):
+ if _is_static_int(value):
+ return scalar.const(value)
+ return value
+
+
+def _normalize_index_operands(values):
+ return [_as_index_operand(value) for value in values]
+
+
+def _infer_compact_row_major_strides(shape):
+ shape = list(shape)
+ if not shape:
+ raise ValueError("`make_tensor` requires a non-empty shape.")
+ strides = [1] * len(shape)
+ running = 1
+ for idx in range(len(shape) - 1, -1, -1):
+ strides[idx] = running
+ if idx > 0:
+ running = _mul_shape_values(running, shape[idx])
+ return strides
+
+
+def _unwrap_tile_type(tile_type):
+ if hasattr(tile_type, "raw_type"):
+ return tile_type.raw_type
+ return tile_type
+
+
+def _resolve_tensor_element_type(tensor_type, dtype):
+ if dtype is not None:
+ return scalar.resolve_type(dtype)
+ if hasattr(tensor_type, "element_type"):
+ return tensor_type.element_type
+ raise TypeError(
+ "`make_tensor` could not infer the element type from `type=`; pass `dtype=` explicitly."
+ )
+
+
+class TensorView:
+ def __init__(self, raw, *, element_type):
+ self.raw = _unwrap(raw)
+ self._element_type = element_type
+
+ def __getattr__(self, item):
+ return getattr(self.raw, item)
+
+ def __repr__(self):
+ return str(self.raw)
+
+ def slice(self, offsets, sizes, *, static_shape=None):
+ raw_sizes = list(sizes)
+ offsets = _normalize_index_operands(offsets)
+ sizes = _normalize_index_operands(raw_sizes)
+ if static_shape is None:
+ if all(_is_static_int(size) for size in raw_sizes):
+ inferred_shape = require_static_int_sequence(
+ raw_sizes, context="TensorView.slice.sizes"
+ )
+ else:
+ raise TypeError(
+ "`TensorView.slice(..., sizes=...)` requires `static_shape=` when "
+ "sizes include dynamic PTODSL/MLIR values."
+ )
+ else:
+ inferred_shape = require_static_int_sequence(
+ static_shape, context="TensorView.slice.static_shape"
+ )
+ subtensor_type = type_def.SubTensorType(
+ shape=inferred_shape,
+ dtype=self._element_type,
+ )
+ return TensorView(
+ slice_view(subtensor_type, source=self.raw, offsets=offsets, sizes=sizes),
+ element_type=self._element_type,
+ )
+
+
+class TileBufferSpec:
+ def __init__(self, raw_type):
+ self.raw_type = raw_type
+
+ def __repr__(self):
+ return str(self.raw_type)
+
+ def alloc(self, *, addr=None, valid_row=None, valid_col=None):
+ return TileBuffer(
+ alloc_tile(
+ self.raw_type,
+ addr=addr,
+ valid_row=valid_row,
+ valid_col=valid_col,
+ )
+ )
+
+
+class TileBuffer:
+ def __init__(self, raw):
+ self.raw = _unwrap(raw)
+
+ def __getattr__(self, item):
+ return getattr(self.raw, item)
+
+ def __repr__(self):
+ return str(self.raw)
+
+ def load_from(self, view):
+ load(view, self)
+ return self
+
+ def store_to(self, view):
+ store(self, view)
+ return self
+
+
+def ptr(dtype, space=None):
+ return type_def.PtrType(dtype, memory_space=space)
+
+
+def make_tensor(ptr, *, shape, strides=None, dtype=None, type=None, layout=None):
+ if type is None:
+ if dtype is None:
+ raise TypeError("`make_tensor` requires `dtype=` when `type=` is omitted.")
+ type = type_def.TensorType(rank=len(shape), dtype=dtype)
+ element_type = _resolve_tensor_element_type(type, dtype)
+ shape = _normalize_index_operands(shape)
+ if strides is None:
+ strides = _infer_compact_row_major_strides(shape)
+ strides = _normalize_index_operands(strides)
+ return TensorView(
+ as_tensor(type, ptr=ptr, shape=shape, strides=strides, layout=layout),
+ element_type=element_type,
+ )
+
+
+def make_tile_buffer(dtype, shape, *, space, valid_shape=None, config=None):
+ return TileBufferSpec(
+ type_def.TileBufType(
+ shape=shape,
+ dtype=dtype,
+ memory_space=space,
+ valid_shape=valid_shape,
+ config=config,
+ )
+ )
+
+
def as_tensor(tensor_type, *, ptr, shape, strides, layout=None):
shape_vals = [_unwrap(v) for v in shape]
stride_vals = [_unwrap(v) for v in strides]
@@ -46,7 +223,7 @@ def slice_view(subtensor_type, *, source, offsets, sizes):
offset_vals = [_unwrap(v) for v in offsets]
size_vals = [_unwrap(v) for v in sizes]
return _pto.PartitionViewOp(
- subtensor_type, source, offsets=offset_vals, sizes=size_vals
+ subtensor_type, _unwrap(source), offsets=offset_vals, sizes=size_vals
).result
@@ -74,15 +251,15 @@ def alloc_tile(tile_type, *, addr=None, valid_row=None, valid_col=None):
kwargs["valid_row"] = _unwrap(valid_row)
if valid_col is not None:
kwargs["valid_col"] = _unwrap(valid_col)
- return _pto.AllocTileOp(tile_type, **kwargs).result
+ return _pto.AllocTileOp(_unwrap_tile_type(tile_type), **kwargs).result
def load(source, dest):
- _pto.TLoadOp(None, source, dest)
+ _pto.TLoadOp(None, _unwrap(source), _unwrap(dest))
def store(source, dest):
- _pto.TStoreOp(None, source, dest)
+ _pto.TStoreOp(None, _unwrap(source), _unwrap(dest))
def print(format, scalar):
@@ -102,10 +279,16 @@ def print(format, scalar):
__all__ = [
+ "TensorView",
+ "TileBuffer",
+ "TileBufferSpec",
"get_block_idx",
"get_subblock_idx",
"get_subblock_num",
"get_block_num",
+ "ptr",
+ "make_tensor",
+ "make_tile_buffer",
"as_tensor",
"slice_view",
"vector_section",
diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py
index 7f4e9d54..d4e125c1 100644
--- a/ptodsl/api/scalar.py
+++ b/ptodsl/api/scalar.py
@@ -2,9 +2,31 @@
from mlir.ir import F16Type, F32Type, IndexType, IntegerType
+class LazyTypeAlias:
+ def __init__(self, name, resolver):
+ self._name = name
+ self._resolver = resolver
+
+ def resolve(self):
+ return self._resolver()
+
+ def __repr__(self):
+ return self._name
+
+ __str__ = __repr__
+
+
+def resolve_type(value):
+ if isinstance(value, LazyTypeAlias):
+ return value.resolve()
+ return value
+
+
def _unwrap(value):
if isinstance(value, Value):
- return value.raw
+ return _unwrap(value.raw)
+ if hasattr(value, "raw"):
+ return _unwrap(value.raw)
return value
@@ -86,17 +108,17 @@ def __getattr__(name):
# TODO: add more builtin dtype aliases (for example float16/bfloat16/int8/int64)
# when they are validated against PTO type support.
if name == "bool":
- return IntegerType.get_signless(1)
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(1))
if name == "float32":
- return F32Type.get()
+ return LazyTypeAlias(name, lambda: F32Type.get())
if name == "float16":
- return F16Type.get()
+ return LazyTypeAlias(name, lambda: F16Type.get())
if name == "int32":
- return IntegerType.get_signless(32)
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(32))
if name == "int16":
- return IntegerType.get_signless(16)
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(16))
if name == "uint32":
- return IntegerType.get_unsigned(32)
+ return LazyTypeAlias(name, lambda: IntegerType.get_unsigned(32))
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -151,9 +173,9 @@ def select(cond, true_val, false_val):
__all__ = [
+ "LazyTypeAlias",
"Value",
"_unwrap",
- "wrap_value",
"const",
"index_cast",
"ceil_div",
@@ -164,5 +186,7 @@ def select(cond, true_val, false_val):
"lt",
"gt",
"ge",
+ "resolve_type",
"select",
+ "wrap_value",
]
diff --git a/ptodsl/api/synchronization.py b/ptodsl/api/synchronization.py
index 1a0801ee..7b5bef1f 100644
--- a/ptodsl/api/synchronization.py
+++ b/ptodsl/api/synchronization.py
@@ -15,6 +15,13 @@ def _resolve_sync_op(sync_op):
return sync_op
+def _resolve_sync_op_attr(sync_op):
+ resolved = _resolve_sync_op(sync_op)
+ if isinstance(resolved, _pto.SyncOpType):
+ return _pto.SyncOpTypeAttr.get(resolved)
+ return resolved
+
+
def _resolve_event_id(event_id):
if isinstance(event_id, int):
if event_id < 0 or event_id > 7:
@@ -67,4 +74,8 @@ def barrier(sync_op):
_pto.barrier(_resolve_sync_op(sync_op))
-__all__ = ["record_event", "wait_event", "record_wait_pair", "barrier"]
+def barrier_sync(sync_op):
+ _pto.barrier_sync(_resolve_sync_op_attr(sync_op))
+
+
+__all__ = ["record_event", "wait_event", "record_wait_pair", "barrier", "barrier_sync"]
diff --git a/ptodsl/api/tile.py b/ptodsl/api/tile.py
index 2cffe513..dbeb70e9 100644
--- a/ptodsl/api/tile.py
+++ b/ptodsl/api/tile.py
@@ -5,148 +5,158 @@
from .scalar import _unwrap
+def _call(op, *args, **kwargs):
+ return op(
+ *(_unwrap(arg) for arg in args),
+ **{name: _unwrap(value) for name, value in kwargs.items()},
+ )
+
+
def mov(source, dest):
- _pto.TMovOp(None, source, dest)
+ _call(_pto.TMovOp, None, source, dest)
def add(lhs, rhs, out):
- _pto.TAddOp(lhs, rhs, out)
+ _call(_pto.TAddOp, lhs, rhs, out)
def sub(lhs, rhs, out):
- _pto.TSubOp(lhs, rhs, out)
+ _call(_pto.TSubOp, lhs, rhs, out)
def div(lhs, rhs, out):
- _pto.TDivOp(lhs, rhs, out)
+ _call(_pto.TDivOp, lhs, rhs, out)
def mul(lhs, rhs, out):
- _pto.TMulOp(lhs, rhs, out)
+ _call(_pto.TMulOp, lhs, rhs, out)
def or_(lhs, rhs, out):
- _pto.TOrOp(lhs, rhs, out)
+ _call(_pto.TOrOp, lhs, rhs, out)
def min(lhs, rhs, out):
- _pto.TMinOp(lhs, rhs, out)
+ _call(_pto.TMinOp, lhs, rhs, out)
def max(lhs, rhs, out):
- _pto.TMaxOp(lhs, rhs, out)
+ _call(_pto.TMaxOp, lhs, rhs, out)
def gather(src, out, indices=None, *, mask_pattern=None):
if mask_pattern is not None:
mask = _pto.MaskPatternAttr.get(getattr(_pto.MaskPattern, mask_pattern))
- _pto.TGatherOp(src, out, maskPattern=mask)
+ _call(_pto.TGatherOp, src, out, maskPattern=mask)
else:
- _pto.TGatherOp(src, out, indices=indices)
+ _call(_pto.TGatherOp, src, out, indices=indices)
def exp(inp, out):
- _pto.TExpOp(inp, out)
+ _call(_pto.TExpOp, inp, out)
def log(inp, out):
- _pto.TLogOp(inp, out)
+ _call(_pto.TLogOp, inp, out)
def relu(inp, out):
- _pto.TReluOp(inp, out)
+ _call(_pto.TReluOp, inp, out)
def abs(inp, out):
- _pto.TAbsOp(inp, out)
+ _call(_pto.TAbsOp, inp, out)
def sqrt(inp, out):
- _pto.TSqrtOp(inp, out)
+ _call(_pto.TSqrtOp, inp, out)
def rsqrt(inp, out):
- _pto.TRsqrtOp(inp, out)
+ _call(_pto.TRsqrtOp, inp, out)
def reciprocal(inp, out):
- _pto.TRecipOp(inp, out)
+ _call(_pto.TRecipOp, inp, out)
def matmul(lhs, rhs, out):
- _pto.TMatmulOp(None, lhs, rhs, out)
+ _call(_pto.TMatmulOp, None, lhs, rhs, out)
def matmul_bias(lhs, rhs, bias, out):
- _pto.TMatmulBiasOp(None, lhs, rhs, bias, out)
+ _call(_pto.TMatmulBiasOp, None, lhs, rhs, bias, out)
def matmul_acc(acc, lhs, rhs, out):
- _pto.TMatmulAccOp(None, acc, lhs, rhs, out)
+ _call(_pto.TMatmulAccOp, None, acc, lhs, rhs, out)
def extract(source, index_row, index_col, out):
_pto.TExtractOp(
- src=source, indexRow=_unwrap(index_row), indexCol=_unwrap(index_col), dst=out
+ src=_unwrap(source),
+ indexRow=_unwrap(index_row),
+ indexCol=_unwrap(index_col),
+ dst=_unwrap(out),
)
def row_sum(src, tmp, dst):
- _pto.TRowSumOp(src=src, tmp=tmp, dst=dst)
+ _call(_pto.TRowSumOp, src=src, tmp=tmp, dst=dst)
def row_min(src, tmp, dst):
- _pto.TRowMinOp(src=src, tmp=tmp, dst=dst)
+ _call(_pto.TRowMinOp, src=src, tmp=tmp, dst=dst)
def row_max(src, tmp, dst):
- _pto.TRowMaxOp(src=src, tmp=tmp, dst=dst)
+ _call(_pto.TRowMaxOp, src=src, tmp=tmp, dst=dst)
def row_prod(src, tmp, dst):
- _pto.TRowProdOp(src=src, tmp=tmp, dst=dst)
+ _call(_pto.TRowProdOp, src=src, tmp=tmp, dst=dst)
def row_expand(src, dst):
- _pto.TRowExpandOp(src=src, dst=dst)
+ _call(_pto.TRowExpandOp, src=src, dst=dst)
def row_expand_sub(src0, src1, dst):
- _pto.TRowExpandSubOp(src0=src0, src1=src1, dst=dst)
+ _call(_pto.TRowExpandSubOp, src0=src0, src1=src1, dst=dst)
def row_expand_div(src0, src1, dst):
- _pto.TRowExpandDivOp(src0=src0, src1=src1, dst=dst)
+ _call(_pto.TRowExpandDivOp, src0=src0, src1=src1, dst=dst)
def row_expand_mul(src0, src1, dst):
- _pto.TRowExpandMulOp(src0=src0, src1=src1, dst=dst)
+ _call(_pto.TRowExpandMulOp, src0=src0, src1=src1, dst=dst)
def col_sum(src, tmp, dst, is_binary=True):
- _pto.TColSumOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
+ _call(_pto.TColSumOp, src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
def col_min(src, dst):
- _pto.TColMinOp(src=src, dst=dst)
+ _call(_pto.TColMinOp, src=src, dst=dst)
def col_max(src, dst):
- _pto.TColMaxOp(src=src, dst=dst)
+ _call(_pto.TColMaxOp, src=src, dst=dst)
def col_prod(src, tmp, dst, is_binary=True):
- _pto.TColProdOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
+ _call(_pto.TColProdOp, src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
def col_expand(src, dst):
- _pto.TColExpandOp(src=src, dst=dst)
+ _call(_pto.TColExpandOp, src=src, dst=dst)
def mrgsort(src, dst, block_len):
i32 = IntegerType.get_signless(32)
block_len_i32 = _arith.IndexCastOp(i32, _unwrap(block_len)).result
- _pto.TMrgSortOp(srcs=[src], dsts=[dst], blockLen=block_len_i32)
+ _pto.TMrgSortOp(srcs=[_unwrap(src)], dsts=[_unwrap(dst)], blockLen=block_len_i32)
def sort32(src, dst, idx):
@@ -154,16 +164,16 @@ def sort32(src, dst, idx):
(score, index) pairs to dst. idx is an input tile of uint32 indices
attached to each src element. For float16 src, dst must have 4x the
columns of src (each element expands to 4 float16 words)."""
- _pto.TSort32Op(src, dst, idx)
+ _call(_pto.TSort32Op, src, dst, idx)
def subset(source, offsets, sizes):
offset_vals = [_unwrap(v) for v in offsets]
- return _pto.subset(source, offset_vals, sizes)
+ return _pto.subset(_unwrap(source), offset_vals, sizes)
def print(source):
- _pto.tprint(source)
+ _pto.tprint(_unwrap(source))
__all__ = [
@@ -188,7 +198,6 @@ def print(source):
"row_sum",
"row_min",
"row_max",
- "row_prod",
"row_expand",
"row_expand_sub",
"row_expand_div",
@@ -196,7 +205,6 @@ def print(source):
"col_sum",
"col_min",
"col_max",
- "col_prod",
"col_expand",
"mrgsort",
"sort32",
diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py
index 251303f6..b544ffee 100644
--- a/ptodsl/api/type_def.py
+++ b/ptodsl/api/type_def.py
@@ -1,5 +1,6 @@
from mlir.dialects import pto as _pto
+from .._constexpr import require_static_int, require_static_int_sequence
from . import scalar
@@ -11,16 +12,48 @@ def __getattr__(name):
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-def PtrType(dtype):
- return _pto.PtrType.get(dtype)
+def _resolve_memory_space(memory_space):
+ if memory_space is None:
+ return None
+ if isinstance(memory_space, str):
+ return getattr(_pto.AddressSpace, memory_space)
+ if hasattr(memory_space, "value") and not isinstance(memory_space, int):
+ return int(memory_space.value)
+ return memory_space
+
+
+def PtrType(dtype, memory_space=None):
+ resolved = scalar.resolve_type(dtype)
+ memory_space = _resolve_memory_space(memory_space)
+ if memory_space is None:
+ return _pto.PtrType.get(resolved)
+ return _pto.PtrType.get(resolved, memory_space)
+
+
+def VRegType(lanes, dtype):
+ return _pto.VRegType.get(require_static_int(lanes, context="VRegType.lanes"), scalar.resolve_type(dtype))
+
+
+def MaskType():
+ return _pto.MaskType.get()
+
+
+def AlignType():
+ return _pto.AlignType.get()
def TensorType(*, rank, dtype):
- return _pto.TensorViewType.get(rank, dtype)
+ return _pto.TensorViewType.get(
+ require_static_int(rank, context="TensorType.rank"),
+ scalar.resolve_type(dtype),
+ )
def SubTensorType(*, shape, dtype):
- return _pto.PartitionTensorViewType.get(shape, dtype)
+ return _pto.PartitionTensorViewType.get(
+ require_static_int_sequence(shape, context="SubTensorType.shape"),
+ scalar.resolve_type(dtype),
+ )
class TileBufConfig:
@@ -32,7 +65,9 @@ def __init__(
self._bl = _pto.BLayoutAttr.get(getattr(_pto.BLayout, blayout))
self._sl = _pto.SLayoutAttr.get(getattr(_pto.SLayout, slayout))
self._pd = _pto.PadValueAttr.get(getattr(_pto.PadValue, pad))
- self._s_fractal_size = s_fractal_size
+ self._s_fractal_size = require_static_int(
+ s_fractal_size, context="TileBufConfig.s_fractal_size"
+ )
@property
def attr(self):
@@ -88,17 +123,27 @@ def _default_tile_config(memory_space, shape):
def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None):
+ shape = require_static_int_sequence(shape, context="TileBufType.shape")
space = _pto.AddressSpaceAttr.get(getattr(_pto.AddressSpace, memory_space))
if valid_shape is None:
valid_shape = shape
+ else:
+ valid_shape = require_static_int_sequence(
+ valid_shape, context="TileBufType.valid_shape"
+ )
if config is None:
config = _default_tile_config(memory_space, shape)
cfg = config.attr if isinstance(config, TileBufConfig) else config
- return _pto.TileBufType.get(shape, dtype, space, valid_shape, cfg)
+ return _pto.TileBufType.get(
+ shape, scalar.resolve_type(dtype), space, valid_shape, cfg
+ )
__all__ = [
"PtrType",
+ "VRegType",
+ "MaskType",
+ "AlignType",
"TensorType",
"SubTensorType",
"TileBufConfig",
diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py
index b32730ef..4ce19543 100644
--- a/ptodsl/compiler/ir.py
+++ b/ptodsl/compiler/ir.py
@@ -1,13 +1,22 @@
import inspect
+from functools import update_wrapper
from mlir.dialects import func, pto as _pto
from mlir.ir import Context, InsertionPoint, Location, Module
-from ..api.scalar import wrap_value
+from .._constexpr import (
+ analyze_signature,
+ bind_constexpr_arguments,
+ meta_kwargs_for,
+ unwrap_constexpr_annotation,
+)
+from ..api.scalar import resolve_type, wrap_value
-def _resolve_meta(meta_fn):
- values = meta_fn()
+def _resolve_meta(meta_fn, constexpr_bindings=None):
+ constexpr_bindings = constexpr_bindings or {}
+ kwargs = meta_kwargs_for(meta_fn, constexpr_bindings)
+ values = meta_fn() if kwargs is None else meta_fn(**kwargs)
if not isinstance(values, dict):
raise ValueError(
"`meta_data()` must return a dict of named symbols to MLIR/PTO types."
@@ -15,18 +24,18 @@ def _resolve_meta(meta_fn):
return dict(values)
-def _resolve_arg_types(signature, meta_map):
+def _resolve_arg_types(parameters, meta_map):
arg_types = []
- for param in signature.parameters.values():
- annot = param.annotation
+ for param in parameters:
+ annot = unwrap_constexpr_annotation(param.annotation)
if isinstance(annot, str):
if annot not in meta_map:
raise ValueError(f"Unknown annotation '{annot}'.")
- arg_types.append(meta_map[annot])
+ arg_types.append(resolve_type(meta_map[annot]))
elif annot is inspect._empty:
raise ValueError(f"Missing annotation for argument '{param.name}'.")
else:
- arg_types.append(annot)
+ arg_types.append(resolve_type(annot))
return arg_types
@@ -37,16 +46,16 @@ def _resolve_ret_types(signature, meta_map):
if isinstance(ret_annot, str):
if ret_annot not in meta_map:
raise ValueError(f"Unknown return annotation '{ret_annot}'.")
- return [meta_map[ret_annot]]
+ return [resolve_type(meta_map[ret_annot])]
if isinstance(ret_annot, (list, tuple)):
out = []
for elem in ret_annot:
if isinstance(elem, str):
- out.append(meta_map[elem])
+ out.append(resolve_type(meta_map[elem]))
else:
- out.append(elem)
+ out.append(resolve_type(elem))
return out
- return [ret_annot]
+ return [resolve_type(ret_annot)]
def _has_func_return(block):
@@ -72,36 +81,55 @@ def _restore_globals(fn, old, injected_names):
fn.__globals__[name] = old[name]
+def _build_ir_module(fn, analysis, meta_data, constexpr_bindings):
+ constexpr_bindings = constexpr_bindings or {}
+
+ with Context() as ctx, Location.unknown():
+ _pto.register_dialect(ctx, load=True)
+ meta_map = _resolve_meta(meta_data, constexpr_bindings)
+ arg_types = _resolve_arg_types(analysis.runtime_params, meta_map)
+ ret_types = _resolve_ret_types(analysis.signature, meta_map)
+ module = Module.create()
+ fn_ty = func.FunctionType.get(arg_types, ret_types)
+
+ with InsertionPoint(module.body):
+ ir_func = func.FuncOp(fn.__name__, fn_ty)
+ entry = ir_func.add_entry_block()
+
+ with InsertionPoint(entry):
+ wrapped_runtime_args = iter(wrap_value(arg) for arg in entry.arguments)
+ call_args = []
+ for param in analysis.signature.parameters.values():
+ if param.name in constexpr_bindings:
+ call_args.append(constexpr_bindings[param.name])
+ else:
+ call_args.append(next(wrapped_runtime_args))
+
+ injected = set(meta_map.keys())
+ old_globals = _inject_globals(fn, meta_map)
+ try:
+ fn(*call_args)
+ finally:
+ _restore_globals(fn, old_globals, injected)
+
+ if not ret_types and not _has_func_return(entry):
+ func.ReturnOp([])
+
+ module.operation.verify()
+ return module
+
+
def to_ir_module(*, meta_data):
def decorator(fn):
- sig = inspect.signature(fn)
-
- with Context() as ctx, Location.unknown():
- _pto.register_dialect(ctx, load=True)
- meta_map = _resolve_meta(meta_data)
- arg_types = _resolve_arg_types(sig, meta_map)
- ret_types = _resolve_ret_types(sig, meta_map)
- module = Module.create()
- fn_ty = func.FunctionType.get(arg_types, ret_types)
-
- with InsertionPoint(module.body):
- ir_func = func.FuncOp(fn.__name__, fn_ty)
- entry = ir_func.add_entry_block()
-
- with InsertionPoint(entry):
- wrapped_args = [wrap_value(arg) for arg in entry.arguments]
- injected = set(meta_map.keys())
- old_globals = _inject_globals(fn, meta_map)
- try:
- fn(*wrapped_args)
- finally:
- _restore_globals(fn, old_globals, injected)
-
- if not ret_types and not _has_func_return(entry):
- func.ReturnOp([])
-
- module.operation.verify()
- return module
+ analysis = analyze_signature(fn)
+ if not analysis.has_constexpr_params:
+ return _build_ir_module(fn, analysis, meta_data, {})
+
+ def specialize(*args, **kwargs):
+ constexpr_bindings = bind_constexpr_arguments(analysis, *args, **kwargs)
+ return _build_ir_module(fn, analysis, meta_data, constexpr_bindings)
+
+ return update_wrapper(specialize, fn)
return decorator
diff --git a/ptodsl/compiler/jit.py b/ptodsl/compiler/jit.py
index 820fc00b..5f1cfa8d 100644
--- a/ptodsl/compiler/jit.py
+++ b/ptodsl/compiler/jit.py
@@ -1,14 +1,21 @@
import ctypes
-import inspect
import os
import pathlib
import subprocess
+from dataclasses import dataclass
from functools import update_wrapper
from mlir.dialects import pto as _pto
from mlir.ir import Context, Location
-from .ir import to_ir_module
+from .._constexpr import (
+ _MISSING,
+ analyze_signature,
+ bind_kernel_arguments,
+ normalize_constexpr_bindings,
+ specialization_suffix,
+)
+from .ir import _build_ir_module
def _type_repr(type_obj):
@@ -21,6 +28,12 @@ def _is_ptr_type(type_obj):
def _ptr_elem_cpp_type(type_obj):
type_repr = _type_repr(type_obj)
+ if "e8m0" in type_repr:
+ return "float8_e8m0_t"
+ if "e4m3" in type_repr:
+ return "float8_e4m3_t"
+ if "e5m2" in type_repr:
+ return "float8_e5m2_t"
if "f32" in type_repr:
return "float"
if "f16" in type_repr:
@@ -52,6 +65,8 @@ def _scalar_cpp_type(type_obj):
return "int32_t"
if "i64" in type_repr or "index" in type_repr:
return "int64_t"
+ if "e8m0" in type_repr or "e4m3" in type_repr or "e5m2" in type_repr:
+ return "uint8_t"
if "f32" in type_repr:
return "float"
if "f16" in type_repr:
@@ -63,6 +78,8 @@ def _scalar_ctype(type_obj):
type_repr = _type_repr(type_obj)
if "i64" in type_repr or "index" in type_repr:
return ctypes.c_int64
+ if "e8m0" in type_repr or "e4m3" in type_repr or "e5m2" in type_repr:
+ return ctypes.c_uint8
if "f32" in type_repr:
return ctypes.c_float
if "f16" in type_repr:
@@ -80,6 +97,14 @@ def _normalize_stream_ptr(stream_ptr):
return stream_ptr
+@dataclass
+class _CompiledSpecialization:
+ arg_types: list[object]
+ lib: object
+ lib_path: pathlib.Path
+ output_dir: pathlib.Path
+
+
class JitWrapper:
def __init__(
self,
@@ -93,7 +118,9 @@ def __init__(
):
self._fn = fn
self._meta_data = meta_data
- self._sig = inspect.signature(fn)
+ self._analysis = analyze_signature(fn)
+ self._sig = self._analysis.signature
+ self._runtime_params = list(self._analysis.runtime_params)
self._arg_types = None
self._output_dir = (
pathlib.Path(output_dir)
@@ -105,17 +132,21 @@ def __init__(
self._npu_arch = npu_arch
self._compiled = False
self._lib = None
+ self._compiled_specializations = {}
self._lib_path = self._output_dir / "kernel.so"
update_wrapper(self, fn)
- def _artifact_paths(self):
- pto_path = self._output_dir / "kernel.pto"
- cpp_path = self._output_dir / "kernel.cpp"
- caller_path = self._output_dir / "caller.cpp"
- return pto_path, cpp_path, caller_path, self._lib_path
+ def _artifact_paths(self, output_dir):
+ pto_path = output_dir / "kernel.pto"
+ cpp_path = output_dir / "kernel.cpp"
+ caller_path = output_dir / "caller.cpp"
+ lib_path = output_dir / "kernel.so"
+ return pto_path, cpp_path, caller_path, lib_path
- def _generate_caller_cpp(self, kernel_cpp_name):
- params = list(self._sig.parameters.values())
+ def _generate_caller_cpp(self, kernel_cpp_name, runtime_params=None):
+ params = list(
+ self._runtime_params if runtime_params is None else runtime_params
+ )
cpp_args = []
launch_args = []
for param, arg_type in zip(params, self._arg_types):
@@ -138,7 +169,7 @@ def _generate_caller_cpp(self, kernel_cpp_name):
"}\n"
)
- def _compile_shared_library(self, caller_cpp_path, lib_path):
+ def _compile_shared_library(self, caller_cpp_path, lib_path, *, cwd):
toolkit_home = os.environ.get("ASCEND_TOOLKIT_HOME")
if not toolkit_home:
raise RuntimeError(
@@ -175,34 +206,42 @@ def _compile_shared_library(self, caller_cpp_path, lib_path):
"-o",
str(lib_path),
]
- subprocess.run(cmd, check=True, cwd=str(self._output_dir))
+ subprocess.run(cmd, check=True, cwd=str(cwd))
- def _resolve_runtime_arg_types(self):
+ def _resolve_runtime_arg_types(self, constexpr_bindings):
from .ir import _resolve_arg_types, _resolve_meta
with Context() as ctx, Location.unknown():
_pto.register_dialect(ctx, load=True)
- meta_map = _resolve_meta(self._meta_data)
- return _resolve_arg_types(self._sig, meta_map)
-
- def _build(self):
- self._output_dir.mkdir(parents=True, exist_ok=True)
- pto_path, cpp_path, caller_path, lib_path = self._artifact_paths()
- self._arg_types = self._resolve_runtime_arg_types()
-
- ir_module = to_ir_module(meta_data=self._meta_data)(self._fn)
+ meta_map = _resolve_meta(self._meta_data, constexpr_bindings)
+ return _resolve_arg_types(self._runtime_params, meta_map)
+
+ def _specialization_output_dir(self, constexpr_bindings):
+ if not self._analysis.has_constexpr_params:
+ return self._output_dir
+ return self._output_dir / f"spec_{specialization_suffix(constexpr_bindings)}"
+
+ def _build(self, constexpr_bindings):
+ output_dir = self._specialization_output_dir(constexpr_bindings)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ pto_path, cpp_path, caller_path, lib_path = self._artifact_paths(output_dir)
+ self._arg_types = self._resolve_runtime_arg_types(constexpr_bindings)
+
+ ir_module = _build_ir_module(
+ self._fn, self._analysis, self._meta_data, constexpr_bindings
+ )
pto_path.write_text(f"{ir_module}\n", encoding="utf-8")
ptoas_cmd = ["ptoas"]
if self._enable_insert_sync:
ptoas_cmd.append("--enable-insert-sync")
ptoas_cmd += [str(pto_path), "-o", str(cpp_path)]
- subprocess.run(ptoas_cmd, check=True, cwd=str(self._output_dir))
+ subprocess.run(ptoas_cmd, check=True, cwd=str(output_dir))
caller_path.write_text(
self._generate_caller_cpp(cpp_path.name), encoding="utf-8"
)
- self._compile_shared_library(caller_path, lib_path)
+ self._compile_shared_library(caller_path, lib_path, cwd=output_dir)
self._lib = ctypes.CDLL(str(lib_path))
self._lib.call_kernel.argtypes = [ctypes.c_uint32, ctypes.c_void_p] + [
@@ -210,6 +249,13 @@ def _build(self):
for arg_type in self._arg_types
]
self._compiled = True
+ self._lib_path = lib_path
+ return _CompiledSpecialization(
+ arg_types=list(self._arg_types),
+ lib=self._lib,
+ lib_path=lib_path,
+ output_dir=output_dir,
+ )
def _convert_ptr(self, value):
if isinstance(value, ctypes.c_void_p):
@@ -220,43 +266,37 @@ def _convert_ptr(self, value):
return ctypes.c_void_p(value)
raise TypeError(f"Pointer-like argument expected, got {type(value)!r}.")
- def _prepare_call_args(self, args):
- params = list(self._sig.parameters.values())
- if len(args) > len(params):
- raise TypeError(
- f"Expected at most {len(params)} arguments, got {len(args)}."
- )
-
- filled_args = list(args)
- for idx in range(len(args), len(params)):
- param = params[idx]
- if param.default is not inspect._empty:
- filled_args.append(param.default)
- continue
- arg_type = self._arg_types[idx]
- if _is_ptr_type(arg_type):
- raise TypeError(f"Missing required pointer argument '{param.name}'.")
-
+ def _prepare_call_args(self, runtime_values, arg_types):
converted = []
- for value, arg_type in zip(filled_args, self._arg_types):
+ for param, value, arg_type in zip(
+ self._runtime_params, runtime_values, arg_types
+ ):
+ if value is _MISSING:
+ raise TypeError(f"Missing required argument '{param.name}'.")
if _is_ptr_type(arg_type):
converted.append(self._convert_ptr(value))
else:
converted.append(value)
return converted
- # TODO: also allow taking named `kwargs`
- def __call__(self, *args, stream_ptr=None):
- if not self._compiled:
- self._build()
+ def __call__(self, *args, stream_ptr=None, **kwargs):
+ bound = bind_kernel_arguments(self._analysis, *args, **kwargs)
+ constexpr_key = normalize_constexpr_bindings(bound.constexpr_arguments)
+
+ specialization = self._compiled_specializations.get(constexpr_key)
+ if specialization is None:
+ specialization = self._build(bound.constexpr_arguments)
+ self._compiled_specializations[constexpr_key] = specialization
if stream_ptr is None:
import torch
stream_ptr = torch.npu.current_stream()._as_parameter_
- call_args = self._prepare_call_args(args)
- self._lib.call_kernel(
+ call_args = self._prepare_call_args(
+ bound.runtime_arguments, specialization.arg_types
+ )
+ specialization.lib.call_kernel(
ctypes.c_uint32(self._block_dim),
_normalize_stream_ptr(stream_ptr),
*call_args,
diff --git a/ptodsl/language.py b/ptodsl/language.py
new file mode 100644
index 00000000..f178ba73
--- /dev/null
+++ b/ptodsl/language.py
@@ -0,0 +1,600 @@
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Sequence
+
+from mlir import ir as mlir_ir
+from mlir.dialects import arith, pto, scf
+from mlir.ir import F16Type, F32Type, IndexType, InsertionPoint, IntegerType
+
+from ._constexpr import (
+ Constexpr,
+ const_expr,
+ range_constexpr,
+ require_static_int,
+ require_static_int_sequence,
+)
+from .api import micro as _micro_api
+from .api._micro_registry import MICRO_OPS
+from .api.scalar import LazyTypeAlias, resolve_type
+
+
+def _unwrap(value):
+ if isinstance(value, Value):
+ return value.raw
+ return value
+
+
+class Value:
+ # TODO: generalize to more comprehensive wrappers like https://github.com/makslevental/mlir-python-extras/blob/0.0.8.2/mlir/extras/dialects/ext/arith.py
+ def __init__(self, raw):
+ self.raw = raw
+
+ def __mul__(self, other):
+ return Value(arith.MulIOp(_unwrap(self), _unwrap(other)).result)
+
+ def __rmul__(self, other):
+ return Value(arith.MulIOp(_unwrap(other), _unwrap(self)).result)
+
+ def __add__(self, other):
+ return Value(arith.AddIOp(_unwrap(self), _unwrap(other)).result)
+
+ def __radd__(self, other):
+ return Value(arith.AddIOp(_unwrap(other), _unwrap(self)).result)
+
+ def __sub__(self, other):
+ return Value(arith.SubIOp(_unwrap(self), _unwrap(other)).result)
+
+ def __rsub__(self, other):
+ return Value(arith.SubIOp(_unwrap(other), _unwrap(self)).result)
+
+ def __floordiv__(self, other):
+ return Value(arith.DivSIOp(_unwrap(self), _unwrap(other)).result)
+
+ def __rfloordiv__(self, other):
+ return Value(arith.DivSIOp(_unwrap(other), _unwrap(self)).result)
+
+ def __truediv__(self, other):
+ return Value(arith.DivFOp(_unwrap(self), _unwrap(other)).result)
+
+ def __rtruediv__(self, other):
+ return Value(arith.DivFOp(_unwrap(other), _unwrap(self)).result)
+
+ def __mod__(self, other):
+ return Value(arith.RemSIOp(_unwrap(self), _unwrap(other)).result)
+
+ def __rmod__(self, other):
+ return Value(arith.RemSIOp(_unwrap(other), _unwrap(self)).result)
+
+ @staticmethod
+ def _cmp(lhs, rhs, predicate):
+ return Value(arith.CmpIOp(predicate, _unwrap(lhs), _unwrap(rhs)).result)
+
+ def __lt__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.slt)
+
+ def __gt__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.sgt)
+
+ def __le__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.sle)
+
+ def __ge__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.sge)
+
+ def __eq__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.eq)
+
+ def __ne__(self, other):
+ return Value._cmp(self, other, arith.CmpIPredicate.ne)
+
+ def __getattr__(self, item):
+ return getattr(self.raw, item)
+
+
+def wrap_value(value):
+ if isinstance(value, Value):
+ return value
+ return Value(value)
+
+
+@dataclass(frozen=True)
+class MXFP8DType:
+ lhs: object
+ rhs: object
+ scale: object
+ acc: object
+ scale_factor: int = 32
+
+ @property
+ def data(self):
+ return self.lhs
+
+ def scale_k(self, k):
+ if k % self.scale_factor != 0:
+ raise ValueError(f"k={k} must be divisible by scale_factor={self.scale_factor} for MXFP8.")
+ return k // self.scale_factor
+
+
+def _get_mlir_float_type(alias_name, *type_names):
+ def _resolve():
+ for type_name in type_names:
+ type_ctor = getattr(mlir_ir, type_name, None)
+ if type_ctor is not None:
+ return type_ctor.get()
+ supported = ", ".join(type_names)
+ raise AttributeError(
+ f"module '{__name__}' has no attribute '{alias_name}' because the active MLIR "
+ f"Python bindings do not expose any of: {supported}"
+ )
+
+ return LazyTypeAlias(alias_name, _resolve)
+
+
+def make_mxfp8(*, lhs="e5m2", rhs="e5m2", acc=None, scale_factor=32):
+ variants = {
+ "e4m3": __getattr__("fp8_e4m3"),
+ "e5m2": __getattr__("fp8_e5m2"),
+ }
+ if lhs not in variants:
+ raise ValueError(f"Unsupported lhs variant '{lhs}'. Expected one of: {', '.join(sorted(variants))}.")
+ if rhs not in variants:
+ raise ValueError(f"Unsupported rhs variant '{rhs}'. Expected one of: {', '.join(sorted(variants))}.")
+ return MXFP8DType(
+ lhs=variants[lhs],
+ rhs=variants[rhs],
+ scale=__getattr__("fp8_e8m0"),
+ acc=__getattr__("float32") if acc is None else acc,
+ scale_factor=scale_factor,
+ )
+
+
+def __getattr__(name):
+ # Keep aliases conservative and only expose types that map cleanly to MLIR/PTO.
+ if name == "bool":
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(1))
+ if name == "float32":
+ return LazyTypeAlias(name, lambda: F32Type.get())
+ if name == "float16":
+ return LazyTypeAlias(name, lambda: F16Type.get())
+ if name == "bfloat16":
+ return _get_mlir_float_type(name, "BF16Type")
+ if name in ("fp8_e4m3", "float8_e4m3"):
+ return _get_mlir_float_type(name, "Float8E4M3FNType", "Float8E4M3FNUZType")
+ if name in ("fp8_e5m2", "float8_e5m2"):
+ return _get_mlir_float_type(name, "Float8E5M2Type", "Float8E5M2FNUZType")
+ if name in ("fp8_e8m0", "float8_e8m0"):
+ return _get_mlir_float_type(name, "Float8E8M0FNUType", "Float8E8M0FNType")
+ if name == "mxfp8":
+ return make_mxfp8(lhs="e5m2", rhs="e5m2")
+ if name == "mxfp8_e4m3":
+ return make_mxfp8(lhs="e4m3", rhs="e4m3")
+ if name == "mxfp8_e5m2":
+ return make_mxfp8(lhs="e5m2", rhs="e5m2")
+ if name == "int32":
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(32))
+ if name == "int16":
+ return LazyTypeAlias(name, lambda: IntegerType.get_signless(16))
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+def _resolve_memory_space(memory_space):
+ if memory_space is None:
+ return None
+ if isinstance(memory_space, str):
+ return getattr(pto.AddressSpace, memory_space)
+ if hasattr(memory_space, "value") and not isinstance(memory_space, int):
+ return int(memory_space.value)
+ return memory_space
+
+
+def PtrType(dtype, memory_space=None):
+ resolved = resolve_type(dtype)
+ memory_space = _resolve_memory_space(memory_space)
+ if memory_space is None:
+ return pto.PtrType.get(resolved)
+ return pto.PtrType.get(resolved, memory_space)
+
+
+def VRegType(lanes, dtype):
+ return pto.VRegType.get(
+ require_static_int(lanes, context="VRegType.lanes"),
+ resolve_type(dtype),
+ )
+
+
+def MaskType():
+ return pto.MaskType.get()
+
+
+def AlignType():
+ return pto.AlignType.get()
+
+
+def TensorType(*, rank, dtype):
+ return pto.TensorViewType.get(
+ require_static_int(rank, context="TensorType.rank"),
+ resolve_type(dtype),
+ )
+
+
+def SubTensorType(*, shape, dtype):
+ return pto.PartitionTensorViewType.get(
+ require_static_int_sequence(shape, context="SubTensorType.shape"),
+ resolve_type(dtype),
+ )
+
+
+class TileBufConfig:
+ def __init__(self, blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null"):
+ # TODO: expose and validate a broader set of tile buffer knobs if PTO adds
+ # more layout/padding/fractal settings that should be configurable here.
+ self._bl = pto.BLayoutAttr.get(getattr(pto.BLayout, blayout))
+ self._sl = pto.SLayoutAttr.get(getattr(pto.SLayout, slayout))
+ self._pd = pto.PadValueAttr.get(getattr(pto.PadValue, pad))
+ self._s_fractal_size = require_static_int(
+ s_fractal_size, context="TileBufConfig.s_fractal_size"
+ )
+
+ @property
+ def attr(self):
+ return pto.TileBufConfigAttr.get(self._bl, self._sl, self._s_fractal_size, self._pd)
+
+
+def _default_tile_config(memory_space, shape):
+ space = memory_space.upper()
+ # Defaults mirror the explicit configs used by the verbose matmul builder.
+ if space == "MAT":
+ if len(shape) >= 1 and shape[0] == 1:
+ return TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=pto.TileConfig.fractalABSize)
+ return TileBufConfig(blayout="ColMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalABSize)
+ if space == "LEFT":
+ return TileBufConfig(blayout="RowMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalABSize)
+ if space == "RIGHT":
+ return TileBufConfig(blayout="RowMajor", slayout="ColMajor", s_fractal_size=pto.TileConfig.fractalABSize)
+ if space == "ACC":
+ return TileBufConfig(blayout="ColMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalCSize)
+ if space == "BIAS":
+ return TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=pto.TileConfig.fractalABSize)
+ if space == "SCALING":
+ return TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=pto.TileConfig.fractalABSize)
+ if space == "VEC":
+ return TileBufConfig()
+ raise ValueError(f"Unsupported memory_space '{memory_space}' for default tile config.")
+
+
+def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None):
+ shape = require_static_int_sequence(shape, context="TileBufType.shape")
+ space = pto.AddressSpaceAttr.get(getattr(pto.AddressSpace, memory_space))
+ if valid_shape is None:
+ valid_shape = shape
+ else:
+ valid_shape = require_static_int_sequence(
+ valid_shape, context="TileBufType.valid_shape"
+ )
+ if config is None:
+ config = _default_tile_config(memory_space, shape)
+ cfg = config.attr if isinstance(config, TileBufConfig) else config
+ return pto.TileBufType.get(shape, resolve_type(dtype), space, valid_shape, cfg)
+
+
+def LeftScaleTileBufType(*, shape, dtype, valid_shape=None, config=None):
+ if config is None:
+ config = TileBufConfig(
+ blayout="RowMajor",
+ slayout="RowMajor",
+ s_fractal_size=pto.TileConfig.fractalMxSize,
+ )
+ return TileBufType(shape=shape, dtype=dtype, memory_space="SCALING", valid_shape=valid_shape, config=config)
+
+
+def RightScaleTileBufType(*, shape, dtype, valid_shape=None, config=None):
+ if config is None:
+ config = TileBufConfig(
+ blayout="ColMajor",
+ slayout="ColMajor",
+ s_fractal_size=pto.TileConfig.fractalMxSize,
+ )
+ return TileBufType(shape=shape, dtype=dtype, memory_space="SCALING", valid_shape=valid_shape, config=config)
+
+
+def const(value):
+ return Value(arith.ConstantOp(IndexType.get(), value).result)
+
+
+def get_block_idx():
+ return Value(pto.GetBlockIdxOp().result)
+
+
+def get_subblock_idx():
+ return Value(pto.GetSubBlockIdxOp().result)
+
+
+def get_subblock_num():
+ return Value(pto.GetSubBlockNumOp().result)
+
+
+def get_block_num():
+ return Value(pto.GetBlockNumOp().result)
+
+
+def index_cast(value, index_type=IndexType):
+ if hasattr(index_type, "get"):
+ dst = index_type.get()
+ else:
+ dst = index_type
+ return Value(arith.IndexCastOp(dst, _unwrap(value)).result)
+
+
+def as_tensor(tensor_type, *, ptr, shape, strides):
+ shape_vals = [_unwrap(v) for v in shape]
+ stride_vals = [_unwrap(v) for v in strides]
+ return pto.MakeTensorViewOp(tensor_type, _unwrap(ptr), shape_vals, stride_vals).result
+
+
+def slice_view(subtensor_type, *, source, offsets, sizes):
+ offset_vals = [_unwrap(v) for v in offsets]
+ size_vals = [_unwrap(v) for v in sizes]
+ return pto.PartitionViewOp(subtensor_type, source, offsets=offset_vals, sizes=size_vals).result
+
+
+@contextmanager
+def vector_section():
+ section = pto.SectionVectorOp()
+ block = section.body.blocks.append()
+ with InsertionPoint(block):
+ yield
+
+
+@contextmanager
+def cube_section():
+ section = pto.SectionCubeOp()
+ block = section.body.blocks.append()
+ with InsertionPoint(block):
+ yield
+
+
+def for_range(start, stop, step):
+ loop = scf.ForOp(_unwrap(start), _unwrap(stop), _unwrap(step))
+ with InsertionPoint(loop.body):
+ yield Value(loop.induction_variable)
+ scf.YieldOp([])
+
+
+def alloc_tile(tile_type, *, valid_row=None, valid_col=None):
+ kwargs = {}
+ if valid_row is not None:
+ kwargs["valid_row"] = _unwrap(valid_row)
+ if valid_col is not None:
+ kwargs["valid_col"] = _unwrap(valid_col)
+ return pto.AllocTileOp(tile_type, **kwargs).result
+
+
+def subset(source, offsets, sizes):
+ offset_vals = [_unwrap(v) for v in offsets]
+ return pto.subset(source, offset_vals, sizes)
+
+
+def load(source, dest):
+ pto.TLoadOp(None, source, dest)
+
+
+def mov(source, dest):
+ pto.TMovOp(None, source, dest)
+
+
+def add(lhs, rhs, out):
+ pto.TAddOp(lhs, rhs, out)
+
+
+def sub(lhs, rhs, out):
+ pto.TSubOp(lhs, rhs, out)
+
+
+def div(lhs, rhs, out):
+ pto.TDivOp(lhs, rhs, out)
+
+
+def mul(lhs, rhs, out):
+ pto.TMulOp(lhs, rhs, out)
+
+
+def or_(lhs, rhs, out):
+ pto.TOrOp(lhs, rhs, out)
+
+
+def gather(src, out, indices=None, *, mask_pattern=None):
+ if mask_pattern is not None:
+ mp = pto.MaskPatternAttr.get(getattr(pto.MaskPattern, mask_pattern))
+ pto.TGatherOp(src, out, maskPattern=mp)
+ else:
+ pto.TGatherOp(src, out, indices=indices)
+
+
+def exp(inp, out):
+ pto.TExpOp(inp, out)
+
+
+def log(inp, out):
+ pto.TLogOp(inp, out)
+
+
+def relu(inp, out):
+ pto.TReluOp(inp, out)
+
+
+def abs(inp, out):
+ pto.TAbsOp(inp, out)
+
+
+def sqrt(inp, out):
+ pto.TSqrtOp(inp, out)
+
+
+def store(source, dest):
+ pto.TStoreOp(None, source, dest)
+
+
+def matmul(lhs, rhs, out):
+ pto.TMatmulOp(None, lhs, rhs, out)
+
+
+def matmul_bias(lhs, rhs, bias, out):
+ pto.TMatmulBiasOp(None, lhs, rhs, bias, out)
+
+
+def matmul_acc(acc, lhs, rhs, out):
+ pto.TMatmulAccOp(None, acc, lhs, rhs, out)
+
+
+def _emit_dps_op(op_name, *operands):
+ op_ctor = getattr(pto, op_name, None)
+ if op_ctor is not None:
+ return op_ctor(None, *operands)
+ generic_name = {
+ "TMatmulMxOp": "pto.tmatmul.mx",
+ "TMatmulMxAccOp": "pto.tmatmul.mx.acc",
+ "TMatmulMxBiasOp": "pto.tmatmul.mx.bias",
+ }[op_name]
+ return mlir_ir.Operation.create(generic_name, operands=list(operands))
+
+
+def matmul_mx(lhs, lhs_scale, rhs, rhs_scale, out):
+ _emit_dps_op("TMatmulMxOp", lhs, lhs_scale, rhs, rhs_scale, out)
+
+
+def matmul_mx_acc(acc, lhs, lhs_scale, rhs, rhs_scale, out):
+ _emit_dps_op("TMatmulMxAccOp", acc, lhs, lhs_scale, rhs, rhs_scale, out)
+
+
+def matmul_mx_bias(lhs, lhs_scale, rhs, rhs_scale, bias, out):
+ _emit_dps_op("TMatmulMxBiasOp", lhs, lhs_scale, rhs, rhs_scale, bias, out)
+
+
+def ceil_div(a, b):
+ return Value(arith.CeilDivSIOp(_unwrap(a), _unwrap(b)).result)
+
+
+def div_s(a, b):
+ return Value(arith.DivSIOp(_unwrap(a), _unwrap(b)).result)
+
+
+def rem_s(a, b):
+ return Value(arith.RemSIOp(_unwrap(a), _unwrap(b)).result)
+
+
+def min_u(a, b):
+ return Value(arith.MinUIOp(_unwrap(a), _unwrap(b)).result)
+
+
+def eq(a, b):
+ return Value(arith.CmpIOp(arith.CmpIPredicate.eq, _unwrap(a), _unwrap(b)).result)
+
+
+def lt(a, b):
+ return Value(arith.CmpIOp(arith.CmpIPredicate.slt, _unwrap(a), _unwrap(b)).result)
+
+
+def gt(a, b):
+ return Value(arith.CmpIOp(arith.CmpIPredicate.sgt, _unwrap(a), _unwrap(b)).result)
+
+
+def ge(a, b):
+ return Value(arith.CmpIOp(arith.CmpIPredicate.sge, _unwrap(a), _unwrap(b)).result)
+
+
+def select(cond, true_val, false_val):
+ return Value(arith.SelectOp(_unwrap(cond), _unwrap(true_val), _unwrap(false_val)).result)
+
+
+class _IfElseBranch:
+ def __init__(self, if_op):
+ self._if_op = if_op
+ @contextmanager
+ def else_context(self):
+ with InsertionPoint(self._if_op.else_block):
+ yield
+ scf.YieldOp([])
+
+@contextmanager
+def if_context(condition, has_else=False):
+ if has_else:
+ op = scf.IfOp(_unwrap(condition), [], hasElse=True)
+ branch = _IfElseBranch(op)
+ else:
+ op = scf.IfOp(_unwrap(condition))
+ branch = None
+
+ with InsertionPoint(op.then_block):
+ yield branch
+ scf.YieldOp([])
+
+
+def cond(condition, then_builder, else_builder):
+ op = scf.IfOp(_unwrap(condition), [], hasElse=True)
+ with InsertionPoint(op.then_block):
+ then_builder()
+ scf.YieldOp([])
+ with InsertionPoint(op.else_block):
+ else_builder()
+ scf.YieldOp([])
+ return op
+
+def _resolve_sync_op(sync_op):
+ if isinstance(sync_op, str):
+ normalized = sync_op.strip().upper()
+ if not normalized.startswith("T"):
+ normalized = f"T{normalized}"
+ try:
+ return getattr(pto, normalized)
+ except AttributeError as exc:
+ raise ValueError(f"Unsupported sync op type '{sync_op}'.") from exc
+ return sync_op
+
+
+def _resolve_event_id(event_id):
+ if isinstance(event_id, int):
+ if event_id < 0 or event_id > 7:
+ raise ValueError(f"event_id must be in range [0, 7], got {event_id}.")
+ return getattr(pto, f"EVENT_ID{event_id}")
+ return event_id
+
+
+def record_event(record_op, wait_op, event_id: int|Sequence[int]=0):
+ if not isinstance(event_id, int):
+ for eid in event_id:
+ pto.record_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(eid))
+ else:
+ pto.record_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(event_id))
+
+
+
+def wait_event(record_op, wait_op, event_id: int|Sequence[int]=0):
+ if not isinstance(event_id, int):
+ for eid in event_id:
+ pto.wait_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(eid))
+ else:
+ pto.wait_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(event_id))
+
+
+def record_wait_pair(record_op, wait_op, event_id: int|Sequence[int]=0):
+ rec = _resolve_sync_op(record_op)
+ w = _resolve_sync_op(wait_op)
+ ev = _resolve_event_id(event_id)
+ pto.record_event(rec, w, ev)
+ pto.wait_event(rec, w, ev)
+
+
+def barrier(sync_op):
+ pto.barrier(_resolve_sync_op(sync_op))
+
+
+def barrier_sync(sync_op):
+ pto.barrier_sync(pto.SyncOpTypeAttr.get(_resolve_sync_op(sync_op)))
+
+
+def row_sum(src, tmp, dst):
+ pto.TRowSumOp(src = src, tmp = tmp, dst = dst)
+
+
+for _name in MICRO_OPS:
+ globals()[_name] = getattr(_micro_api, _name)
diff --git a/ptodsl/micro.py b/ptodsl/micro.py
new file mode 100644
index 00000000..01b500a3
--- /dev/null
+++ b/ptodsl/micro.py
@@ -0,0 +1,6 @@
+from .api import micro as _micro
+from .api.micro import __all__
+
+
+def __getattr__(name):
+ return getattr(_micro, name)
diff --git a/scripts/update_public_api_snapshot.py b/scripts/update_public_api_snapshot.py
new file mode 100644
index 00000000..41f3ba85
--- /dev/null
+++ b/scripts/update_public_api_snapshot.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+import sys
+from pathlib import Path
+
+
+def main():
+ repo_root = Path(__file__).resolve().parents[1]
+ sys.path.insert(0, str(repo_root))
+ sys.path.insert(0, str(repo_root / "tests" / "api"))
+
+ from _contract import SNAPSHOT_PATH, collect_public_api_snapshot, snapshot_json
+
+ SNAPSHOT_PATH.write_text(
+ snapshot_json(collect_public_api_snapshot()),
+ encoding="utf-8",
+ )
+ print(f"Updated {SNAPSHOT_PATH.relative_to(repo_root)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..c47b532c
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Test helpers are imported across subpackages.
diff --git a/tests/_vpto_manifest.py b/tests/_vpto_manifest.py
new file mode 100644
index 00000000..520d57b4
--- /dev/null
+++ b/tests/_vpto_manifest.py
@@ -0,0 +1,37 @@
+import json
+import os
+from pathlib import Path
+
+
+ENV_VAR = "PTOAS_VPTO_MANIFEST"
+
+
+def vpto_manifest_path():
+ repo_root = Path(__file__).resolve().parents[1]
+ configured = os.environ.get(ENV_VAR)
+
+ candidates = []
+ if configured:
+ candidates.append(Path(configured).expanduser())
+
+ candidates.extend(
+ [
+ repo_root.parent / "PTOAS" / "docs" / "vpto-manifest.json",
+ repo_root / "PTOAS" / "docs" / "vpto-manifest.json",
+ ]
+ )
+
+ for candidate in candidates:
+ if candidate.is_file():
+ return candidate
+
+ searched = "\n".join(f"- {candidate}" for candidate in candidates)
+ raise FileNotFoundError(
+ "Could not locate PTOAS vPTO manifest. Set "
+ f"`{ENV_VAR}` or provide a sibling PTOAS checkout.\n"
+ f"Searched:\n{searched}"
+ )
+
+
+def load_vpto_manifest():
+ return json.loads(vpto_manifest_path().read_text(encoding="utf-8"))
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
new file mode 100644
index 00000000..b237d370
--- /dev/null
+++ b/tests/api/__init__.py
@@ -0,0 +1 @@
+# API contract and module-focused test package.
diff --git a/tests/api/_contract.py b/tests/api/_contract.py
new file mode 100644
index 00000000..d3a19cac
--- /dev/null
+++ b/tests/api/_contract.py
@@ -0,0 +1,50 @@
+import importlib
+import json
+from pathlib import Path
+
+
+PUBLIC_MODULES = [
+ "ptodsl",
+ "ptodsl.pto",
+ "ptodsl.tile",
+ "ptodsl.scalar",
+ "ptodsl.micro",
+ "ptodsl.compiler",
+ "ptodsl.api.pto",
+ "ptodsl.api.tile",
+ "ptodsl.api.scalar",
+ "ptodsl.api.micro",
+ "ptodsl.api.control_flow",
+ "ptodsl.api.synchronization",
+ "ptodsl.api.type_def",
+ "ptodsl.bench",
+]
+
+MIRROR_MODULES = [
+ ("ptodsl.pto", "ptodsl.api.pto"),
+ ("ptodsl.tile", "ptodsl.api.tile"),
+ ("ptodsl.scalar", "ptodsl.api.scalar"),
+ ("ptodsl.micro", "ptodsl.api.micro"),
+]
+
+SNAPSHOT_PATH = Path(__file__).with_name("public_api_snapshot.json")
+
+
+def exported_names(module_name):
+ module = importlib.import_module(module_name)
+ exports = getattr(module, "__all__", None)
+ if exports is None:
+ raise AssertionError(f"{module_name} is missing __all__.")
+ return sorted(dict.fromkeys(exports))
+
+
+def collect_public_api_snapshot():
+ return {module_name: exported_names(module_name) for module_name in PUBLIC_MODULES}
+
+
+def load_snapshot():
+ return json.loads(SNAPSHOT_PATH.read_text(encoding="utf-8"))
+
+
+def snapshot_json(snapshot):
+ return json.dumps(snapshot, indent=2, sort_keys=True) + "\n"
diff --git a/tests/api/conftest.py b/tests/api/conftest.py
new file mode 100644
index 00000000..d5652e85
--- /dev/null
+++ b/tests/api/conftest.py
@@ -0,0 +1,10 @@
+import pytest
+from mlir.dialects import pto as mlir_pto
+from mlir.ir import Context, Location
+
+
+@pytest.fixture
+def mlir_ctx():
+ with Context() as ctx, Location.unknown():
+ mlir_pto.register_dialect(ctx, load=True)
+ yield ctx
diff --git a/tests/api/public_api_snapshot.json b/tests/api/public_api_snapshot.json
new file mode 100644
index 00000000..50d87c9e
--- /dev/null
+++ b/tests/api/public_api_snapshot.json
@@ -0,0 +1,705 @@
+{
+ "ptodsl": [
+ "Constexpr",
+ "JitWrapper",
+ "const_expr",
+ "do_bench",
+ "jit",
+ "micro",
+ "pto",
+ "range_constexpr",
+ "scalar",
+ "tile",
+ "to_ir_module"
+ ],
+ "ptodsl.api.control_flow": [
+ "cond",
+ "const_expr",
+ "if_context",
+ "range",
+ "range_constexpr"
+ ],
+ "ptodsl.api.micro": [
+ "MICRO_OPS",
+ "addptr",
+ "barrier",
+ "castptr",
+ "copy_gm_to_ubuf",
+ "copy_ubuf_to_gm",
+ "copy_ubuf_to_ubuf",
+ "get_buf",
+ "pdintlv_b8",
+ "pge_b16",
+ "pge_b32",
+ "pge_b8",
+ "pintlv_b16",
+ "pld",
+ "pldi",
+ "plds",
+ "plt_b16",
+ "plt_b32",
+ "plt_b8",
+ "pnot",
+ "ppack",
+ "psel",
+ "pset_b16",
+ "pset_b32",
+ "pset_b8",
+ "pst",
+ "psti",
+ "psts",
+ "pstu",
+ "punpack",
+ "rls_buf",
+ "set_flag",
+ "set_loop1_stride_outtoub",
+ "set_loop1_stride_ubtoout",
+ "set_loop2_stride_outtoub",
+ "set_loop2_stride_ubtoout",
+ "set_loop_size_outtoub",
+ "set_loop_size_ubtoout",
+ "uvld",
+ "vabs",
+ "vadd",
+ "vaddc",
+ "vaddcs",
+ "vadds",
+ "vand",
+ "vbcnt",
+ "vbitsort",
+ "vbr",
+ "vcadd",
+ "vci",
+ "vcls",
+ "vcmax",
+ "vcmin",
+ "vcmp",
+ "vcmps",
+ "vcvt",
+ "vdintlv",
+ "vdintlvv2",
+ "vdiv",
+ "vdup",
+ "vexp",
+ "vgather2",
+ "vgather2_bc",
+ "vgatherb",
+ "vintlv",
+ "vintlvv2",
+ "vldas",
+ "vlds",
+ "vlds_post",
+ "vldus",
+ "vldx2",
+ "vln",
+ "vlrelu",
+ "vmax",
+ "vmaxs",
+ "vmin",
+ "vmins",
+ "vmrgsort4",
+ "vmul",
+ "vmula",
+ "vmull",
+ "vmuls",
+ "vnot",
+ "vor",
+ "vrec",
+ "vrelu",
+ "vscatter",
+ "vsel",
+ "vselr",
+ "vselrv2",
+ "vshl",
+ "vshls",
+ "vshr",
+ "vshrs",
+ "vsld",
+ "vsldb",
+ "vsqrt",
+ "vsst",
+ "vsstb",
+ "vsta",
+ "vstar",
+ "vstas",
+ "vsts",
+ "vsts_post",
+ "vstu",
+ "vstur",
+ "vstus",
+ "vstx2",
+ "vsub",
+ "vsubc",
+ "vsubcs",
+ "vtrc",
+ "vxor",
+ "wait_flag"
+ ],
+ "ptodsl.api.pto": [
+ "AlignType",
+ "Constexpr",
+ "MaskType",
+ "PtrType",
+ "SubTensorType",
+ "TensorType",
+ "TensorView",
+ "TileBufConfig",
+ "TileBufType",
+ "TileBuffer",
+ "TileBufferSpec",
+ "VRegType",
+ "Value",
+ "addptr",
+ "alloc_tile",
+ "as_tensor",
+ "barrier",
+ "barrier_sync",
+ "bool",
+ "castptr",
+ "cond",
+ "const_expr",
+ "copy_gm_to_ubuf",
+ "copy_ubuf_to_gm",
+ "copy_ubuf_to_ubuf",
+ "cube_section",
+ "float16",
+ "float32",
+ "get_block_idx",
+ "get_block_num",
+ "get_buf",
+ "get_subblock_idx",
+ "get_subblock_num",
+ "if_context",
+ "int16",
+ "int32",
+ "load",
+ "make_tensor",
+ "make_tile_buffer",
+ "pdintlv_b8",
+ "pge_b16",
+ "pge_b32",
+ "pge_b8",
+ "pintlv_b16",
+ "pld",
+ "pldi",
+ "plds",
+ "plt_b16",
+ "plt_b32",
+ "plt_b8",
+ "pnot",
+ "ppack",
+ "print",
+ "psel",
+ "pset_b16",
+ "pset_b32",
+ "pset_b8",
+ "pst",
+ "psti",
+ "psts",
+ "pstu",
+ "ptr",
+ "punpack",
+ "range",
+ "range_constexpr",
+ "record_event",
+ "record_wait_pair",
+ "rls_buf",
+ "set_flag",
+ "set_loop1_stride_outtoub",
+ "set_loop1_stride_ubtoout",
+ "set_loop2_stride_outtoub",
+ "set_loop2_stride_ubtoout",
+ "set_loop_size_outtoub",
+ "set_loop_size_ubtoout",
+ "slice_view",
+ "store",
+ "uvld",
+ "vabs",
+ "vadd",
+ "vaddc",
+ "vaddcs",
+ "vadds",
+ "vand",
+ "vbcnt",
+ "vbitsort",
+ "vbr",
+ "vcadd",
+ "vci",
+ "vcls",
+ "vcmax",
+ "vcmin",
+ "vcmp",
+ "vcmps",
+ "vcvt",
+ "vdintlv",
+ "vdintlvv2",
+ "vdiv",
+ "vdup",
+ "vector_section",
+ "vexp",
+ "vgather2",
+ "vgather2_bc",
+ "vgatherb",
+ "vintlv",
+ "vintlvv2",
+ "vldas",
+ "vlds",
+ "vlds_post",
+ "vldus",
+ "vldx2",
+ "vln",
+ "vlrelu",
+ "vmax",
+ "vmaxs",
+ "vmin",
+ "vmins",
+ "vmrgsort4",
+ "vmul",
+ "vmula",
+ "vmull",
+ "vmuls",
+ "vnot",
+ "vor",
+ "vrec",
+ "vrelu",
+ "vscatter",
+ "vsel",
+ "vselr",
+ "vselrv2",
+ "vshl",
+ "vshls",
+ "vshr",
+ "vshrs",
+ "vsld",
+ "vsldb",
+ "vsqrt",
+ "vsst",
+ "vsstb",
+ "vsta",
+ "vstar",
+ "vstas",
+ "vsts",
+ "vsts_post",
+ "vstu",
+ "vstur",
+ "vstus",
+ "vstx2",
+ "vsub",
+ "vsubc",
+ "vsubcs",
+ "vtrc",
+ "vxor",
+ "wait_event",
+ "wait_flag",
+ "wrap_value"
+ ],
+ "ptodsl.api.scalar": [
+ "LazyTypeAlias",
+ "Value",
+ "_unwrap",
+ "ceil_div",
+ "const",
+ "div_s",
+ "eq",
+ "ge",
+ "gt",
+ "index_cast",
+ "lt",
+ "min_u",
+ "rem_s",
+ "resolve_type",
+ "select",
+ "wrap_value"
+ ],
+ "ptodsl.api.synchronization": [
+ "barrier",
+ "barrier_sync",
+ "record_event",
+ "record_wait_pair",
+ "wait_event"
+ ],
+ "ptodsl.api.tile": [
+ "abs",
+ "add",
+ "col_expand",
+ "col_max",
+ "col_min",
+ "col_sum",
+ "div",
+ "exp",
+ "extract",
+ "gather",
+ "log",
+ "matmul",
+ "matmul_acc",
+ "matmul_bias",
+ "mov",
+ "mrgsort",
+ "mul",
+ "or_",
+ "reciprocal",
+ "relu",
+ "row_expand",
+ "row_expand_div",
+ "row_expand_mul",
+ "row_expand_sub",
+ "row_max",
+ "row_min",
+ "row_sum",
+ "rsqrt",
+ "sort32",
+ "sqrt",
+ "sub",
+ "subset"
+ ],
+ "ptodsl.api.type_def": [
+ "AlignType",
+ "MaskType",
+ "PtrType",
+ "SubTensorType",
+ "TensorType",
+ "TileBufConfig",
+ "TileBufType",
+ "VRegType",
+ "bool",
+ "float16",
+ "float32",
+ "int16",
+ "int32",
+ "uint32"
+ ],
+ "ptodsl.bench": [
+ "do_bench"
+ ],
+ "ptodsl.compiler": [
+ "JitWrapper",
+ "jit",
+ "to_ir_module"
+ ],
+ "ptodsl.micro": [
+ "MICRO_OPS",
+ "addptr",
+ "barrier",
+ "castptr",
+ "copy_gm_to_ubuf",
+ "copy_ubuf_to_gm",
+ "copy_ubuf_to_ubuf",
+ "get_buf",
+ "pdintlv_b8",
+ "pge_b16",
+ "pge_b32",
+ "pge_b8",
+ "pintlv_b16",
+ "pld",
+ "pldi",
+ "plds",
+ "plt_b16",
+ "plt_b32",
+ "plt_b8",
+ "pnot",
+ "ppack",
+ "psel",
+ "pset_b16",
+ "pset_b32",
+ "pset_b8",
+ "pst",
+ "psti",
+ "psts",
+ "pstu",
+ "punpack",
+ "rls_buf",
+ "set_flag",
+ "set_loop1_stride_outtoub",
+ "set_loop1_stride_ubtoout",
+ "set_loop2_stride_outtoub",
+ "set_loop2_stride_ubtoout",
+ "set_loop_size_outtoub",
+ "set_loop_size_ubtoout",
+ "uvld",
+ "vabs",
+ "vadd",
+ "vaddc",
+ "vaddcs",
+ "vadds",
+ "vand",
+ "vbcnt",
+ "vbitsort",
+ "vbr",
+ "vcadd",
+ "vci",
+ "vcls",
+ "vcmax",
+ "vcmin",
+ "vcmp",
+ "vcmps",
+ "vcvt",
+ "vdintlv",
+ "vdintlvv2",
+ "vdiv",
+ "vdup",
+ "vexp",
+ "vgather2",
+ "vgather2_bc",
+ "vgatherb",
+ "vintlv",
+ "vintlvv2",
+ "vldas",
+ "vlds",
+ "vlds_post",
+ "vldus",
+ "vldx2",
+ "vln",
+ "vlrelu",
+ "vmax",
+ "vmaxs",
+ "vmin",
+ "vmins",
+ "vmrgsort4",
+ "vmul",
+ "vmula",
+ "vmull",
+ "vmuls",
+ "vnot",
+ "vor",
+ "vrec",
+ "vrelu",
+ "vscatter",
+ "vsel",
+ "vselr",
+ "vselrv2",
+ "vshl",
+ "vshls",
+ "vshr",
+ "vshrs",
+ "vsld",
+ "vsldb",
+ "vsqrt",
+ "vsst",
+ "vsstb",
+ "vsta",
+ "vstar",
+ "vstas",
+ "vsts",
+ "vsts_post",
+ "vstu",
+ "vstur",
+ "vstus",
+ "vstx2",
+ "vsub",
+ "vsubc",
+ "vsubcs",
+ "vtrc",
+ "vxor",
+ "wait_flag"
+ ],
+ "ptodsl.pto": [
+ "AlignType",
+ "Constexpr",
+ "MaskType",
+ "PtrType",
+ "SubTensorType",
+ "TensorType",
+ "TensorView",
+ "TileBufConfig",
+ "TileBufType",
+ "TileBuffer",
+ "TileBufferSpec",
+ "VRegType",
+ "Value",
+ "addptr",
+ "alloc_tile",
+ "as_tensor",
+ "barrier",
+ "barrier_sync",
+ "bool",
+ "castptr",
+ "cond",
+ "const_expr",
+ "copy_gm_to_ubuf",
+ "copy_ubuf_to_gm",
+ "copy_ubuf_to_ubuf",
+ "cube_section",
+ "float16",
+ "float32",
+ "get_block_idx",
+ "get_block_num",
+ "get_buf",
+ "get_subblock_idx",
+ "get_subblock_num",
+ "if_context",
+ "int16",
+ "int32",
+ "load",
+ "make_tensor",
+ "make_tile_buffer",
+ "pdintlv_b8",
+ "pge_b16",
+ "pge_b32",
+ "pge_b8",
+ "pintlv_b16",
+ "pld",
+ "pldi",
+ "plds",
+ "plt_b16",
+ "plt_b32",
+ "plt_b8",
+ "pnot",
+ "ppack",
+ "print",
+ "psel",
+ "pset_b16",
+ "pset_b32",
+ "pset_b8",
+ "pst",
+ "psti",
+ "psts",
+ "pstu",
+ "ptr",
+ "punpack",
+ "range",
+ "range_constexpr",
+ "record_event",
+ "record_wait_pair",
+ "rls_buf",
+ "set_flag",
+ "set_loop1_stride_outtoub",
+ "set_loop1_stride_ubtoout",
+ "set_loop2_stride_outtoub",
+ "set_loop2_stride_ubtoout",
+ "set_loop_size_outtoub",
+ "set_loop_size_ubtoout",
+ "slice_view",
+ "store",
+ "uvld",
+ "vabs",
+ "vadd",
+ "vaddc",
+ "vaddcs",
+ "vadds",
+ "vand",
+ "vbcnt",
+ "vbitsort",
+ "vbr",
+ "vcadd",
+ "vci",
+ "vcls",
+ "vcmax",
+ "vcmin",
+ "vcmp",
+ "vcmps",
+ "vcvt",
+ "vdintlv",
+ "vdintlvv2",
+ "vdiv",
+ "vdup",
+ "vector_section",
+ "vexp",
+ "vgather2",
+ "vgather2_bc",
+ "vgatherb",
+ "vintlv",
+ "vintlvv2",
+ "vldas",
+ "vlds",
+ "vlds_post",
+ "vldus",
+ "vldx2",
+ "vln",
+ "vlrelu",
+ "vmax",
+ "vmaxs",
+ "vmin",
+ "vmins",
+ "vmrgsort4",
+ "vmul",
+ "vmula",
+ "vmull",
+ "vmuls",
+ "vnot",
+ "vor",
+ "vrec",
+ "vrelu",
+ "vscatter",
+ "vsel",
+ "vselr",
+ "vselrv2",
+ "vshl",
+ "vshls",
+ "vshr",
+ "vshrs",
+ "vsld",
+ "vsldb",
+ "vsqrt",
+ "vsst",
+ "vsstb",
+ "vsta",
+ "vstar",
+ "vstas",
+ "vsts",
+ "vsts_post",
+ "vstu",
+ "vstur",
+ "vstus",
+ "vstx2",
+ "vsub",
+ "vsubc",
+ "vsubcs",
+ "vtrc",
+ "vxor",
+ "wait_event",
+ "wait_flag",
+ "wrap_value"
+ ],
+ "ptodsl.scalar": [
+ "LazyTypeAlias",
+ "Value",
+ "_unwrap",
+ "ceil_div",
+ "const",
+ "div_s",
+ "eq",
+ "ge",
+ "gt",
+ "index_cast",
+ "lt",
+ "min_u",
+ "rem_s",
+ "resolve_type",
+ "select",
+ "wrap_value"
+ ],
+ "ptodsl.tile": [
+ "abs",
+ "add",
+ "col_expand",
+ "col_max",
+ "col_min",
+ "col_sum",
+ "div",
+ "exp",
+ "extract",
+ "gather",
+ "log",
+ "matmul",
+ "matmul_acc",
+ "matmul_bias",
+ "mov",
+ "mrgsort",
+ "mul",
+ "or_",
+ "reciprocal",
+ "relu",
+ "row_expand",
+ "row_expand_div",
+ "row_expand_mul",
+ "row_expand_sub",
+ "row_max",
+ "row_min",
+ "row_sum",
+ "rsqrt",
+ "sort32",
+ "sqrt",
+ "sub",
+ "subset"
+ ]
+}
diff --git a/tests/api/test_control_flow_api.py b/tests/api/test_control_flow_api.py
new file mode 100644
index 00000000..6995400f
--- /dev/null
+++ b/tests/api/test_control_flow_api.py
@@ -0,0 +1,56 @@
+import pytest
+from mlir.ir import IndexType
+
+from ptodsl import pto, scalar as s, to_ir_module
+from ptodsl.api import control_flow
+
+
+const = s.const
+
+
+def test_constexpr_helpers_accept_static_inputs_and_reject_dynamic_values():
+ assert control_flow.const_expr(3) is True
+ assert list(control_flow.range_constexpr(3)) == [0, 1, 2]
+ assert list(control_flow.range_constexpr(1, 5, 2)) == [1, 3]
+
+ dynamic = s.Value(object())
+ with pytest.raises(TypeError, match="const_expr"):
+ control_flow.const_expr(dynamic)
+
+ with pytest.raises(TypeError, match="range_constexpr"):
+ list(control_flow.range_constexpr(dynamic))
+
+
+def test_range_if_context_and_cond_emit_scf_ops():
+ def meta_data():
+ return {
+ "index_t": IndexType.get(),
+ "tile_t": pto.TileBufType(
+ shape=[1, 64],
+ valid_shape=[1, 64],
+ dtype=pto.float32,
+ memory_space="VEC",
+ ),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(n: "index_t") -> None:
+ c0 = const(0)
+ c1 = const(1)
+ c2 = const(2)
+
+ with pto.vector_section():
+ for i in control_flow.range(c0, c2, c1):
+ with control_flow.if_context(s.gt(i, c0)):
+ pto.alloc_tile(tile_t)
+ control_flow.cond(
+ s.gt(n, c0),
+ lambda: pto.alloc_tile(tile_t),
+ lambda: pto.alloc_tile(tile_t),
+ )
+
+ text = str(kernel)
+
+ assert "scf.for" in text
+ assert text.count("scf.if") == 2
+ assert text.count("pto.alloc_tile") == 3
diff --git a/tests/api/test_micro_api.py b/tests/api/test_micro_api.py
new file mode 100644
index 00000000..b723e633
--- /dev/null
+++ b/tests/api/test_micro_api.py
@@ -0,0 +1,86 @@
+from types import SimpleNamespace
+
+from mlir.dialects import pto as mlir_pto
+from mlir.ir import IndexType
+
+from ptodsl import micro, pto, to_ir_module
+from ptodsl.api import micro as micro_api
+from tests._vpto_manifest import load_vpto_manifest
+
+
+IMPLEMENTED_MICRO_OPS = sorted(
+ op["mnemonic"]
+ for op in load_vpto_manifest()["ops"]
+ if op.get("status") == "implemented"
+)
+
+
+class _Box:
+ def __init__(self, raw):
+ self.raw = raw
+
+
+def test_manifest_driven_micro_inventory_matches_public_exports():
+ assert set(IMPLEMENTED_MICRO_OPS) == set(micro.MICRO_OPS)
+ assert set(IMPLEMENTED_MICRO_OPS).issubset(set(micro.__all__))
+ assert set(IMPLEMENTED_MICRO_OPS).issubset(set(pto.__all__))
+
+
+def test_every_manifest_micro_op_has_a_callable_wrapper():
+ for name in IMPLEMENTED_MICRO_OPS:
+ wrapper = getattr(micro, name)
+
+ assert callable(wrapper)
+ assert getattr(pto, name) is wrapper
+ assert hasattr(mlir_pto, name)
+
+ if name == "barrier":
+ assert wrapper is micro_api._micro_barrier
+ else:
+ assert wrapper.__name__ == name
+ assert f"`pto.{name}`" in (wrapper.__doc__ or "")
+
+
+def test_barrier_normalizes_pipe_names(monkeypatch):
+ seen = {}
+
+ monkeypatch.setattr(
+ micro_api._pto,
+ "PIPE",
+ SimpleNamespace(PIPE_VECTOR="PIPE_VECTOR"),
+ raising=False,
+ )
+ monkeypatch.setattr(
+ micro_api._pto,
+ "PipeAttr",
+ SimpleNamespace(get=lambda value: f"pipe:{value}"),
+ )
+ monkeypatch.setattr(
+ micro_api._pto,
+ "barrier",
+ lambda op, loc=None, ip=None: seen.setdefault("call", (op, loc, ip)),
+ )
+
+ micro.barrier("pipe_vector")
+
+ assert seen["call"] == ("pipe:PIPE_VECTOR", None, None)
+
+
+def test_selected_micro_wrappers_emit_ir_and_accept_objects_with_raw_values():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32, space="VEC"),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ mask = micro.pset_b32(pto.MaskType(), "PAT_ALL")
+ vec = micro.vlds(pto.VRegType(64, pto.float32), _Box(src), offset)
+ micro.vsts(vec, _Box(dst), offset, _Box(mask))
+
+ text = str(kernel)
+
+ assert "pto.pset_b32" in text
+ assert "pto.vlds" in text
+ assert "pto.vsts" in text
diff --git a/tests/api/test_pto_api.py b/tests/api/test_pto_api.py
new file mode 100644
index 00000000..fbff5090
--- /dev/null
+++ b/tests/api/test_pto_api.py
@@ -0,0 +1,87 @@
+import pytest
+from mlir.ir import IndexType
+
+from ptodsl import pto, scalar as s, tile, to_ir_module
+
+
+const = s.const
+
+
+def test_pto_namespace_exports_pythonic_helpers():
+ exports = set(pto.__all__)
+ assert {
+ "ptr",
+ "make_tensor",
+ "make_tile_buffer",
+ "TensorView",
+ "TileBuffer",
+ "barrier_sync",
+ }.issubset(exports)
+
+
+def test_ptr_and_make_tile_buffer_match_low_level_type_builders(mlir_ctx):
+ assert str(pto.ptr(pto.float32)) == str(pto.PtrType(pto.float32))
+ assert str(pto.ptr(pto.float32, space="VEC")) == str(
+ pto.PtrType(pto.float32, memory_space="VEC")
+ )
+
+ tile_spec = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC")
+ assert str(tile_spec.raw_type) == str(
+ pto.TileBufType(
+ shape=[32, 32],
+ valid_shape=[32, 32],
+ dtype=pto.float32,
+ memory_space="VEC",
+ )
+ )
+
+
+def test_make_tensor_requires_dtype_when_type_is_omitted():
+ with pytest.raises(TypeError, match="requires `dtype=`"):
+ pto.make_tensor("ptr", shape=[32, 32])
+
+
+def test_tensor_view_slice_requires_static_shape_for_dynamic_sizes():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ def kernel(src: "ptr_t", total_elements: "index_t") -> None:
+ view = pto.make_tensor(src, shape=[total_elements], dtype=pto.float32)
+ view.slice([0], [total_elements])
+
+ with pytest.raises(TypeError, match="requires `static_shape=`"):
+ to_ir_module(meta_data=meta_data)(kernel)
+
+
+def test_pythonic_pto_wrappers_emit_tensor_and_tile_ops():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", rows: "index_t", cols: "index_t") -> None:
+ src_view = pto.make_tensor(src, shape=[rows, cols], dtype=pto.float32)
+ dst_view = pto.make_tensor(dst, shape=[rows, cols], dtype=pto.float32)
+ src_tile = src_view.slice([0, 0], [32, 32])
+ dst_tile = dst_view.slice([0, 0], [32, 32])
+
+ with pto.vector_section():
+ tile_spec = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC")
+ tmp = tile_spec.alloc()
+ tmp.load_from(src_tile)
+ tile.add(tmp, tmp, tmp)
+ tmp.store_to(dst_tile)
+
+ text = str(kernel)
+
+ assert "pto.make_tensor_view" in text
+ assert "pto.partition_view" in text
+ assert "pto.alloc_tile" in text
+ assert "pto.tload" in text
+ assert "pto.tadd" in text
+ assert "pto.tstore" in text
diff --git a/tests/api/test_public_api_contract.py b/tests/api/test_public_api_contract.py
new file mode 100644
index 00000000..0b779b76
--- /dev/null
+++ b/tests/api/test_public_api_contract.py
@@ -0,0 +1,52 @@
+import difflib
+import importlib
+
+import ptodsl
+import pytest
+
+from ._contract import (
+ MIRROR_MODULES,
+ PUBLIC_MODULES,
+ collect_public_api_snapshot,
+ exported_names,
+ load_snapshot,
+ snapshot_json,
+)
+
+
+def test_committed_public_api_snapshot_matches_current_exports():
+ expected = load_snapshot()
+ actual = collect_public_api_snapshot()
+
+ if actual != expected:
+ diff = "".join(
+ difflib.unified_diff(
+ snapshot_json(expected).splitlines(keepends=True),
+ snapshot_json(actual).splitlines(keepends=True),
+ fromfile="tests/api/public_api_snapshot.json",
+ tofile="current-public-api",
+ )
+ )
+ pytest.fail(
+ "Public API snapshot drifted.\n"
+ "Run `python scripts/update_public_api_snapshot.py` and commit the result.\n\n"
+ f"{diff}"
+ )
+
+
+@pytest.mark.parametrize("module_name", PUBLIC_MODULES)
+def test_all_entries_in___all___resolve(module_name):
+ module = importlib.import_module(module_name)
+
+ for name in getattr(module, "__all__", []):
+ assert getattr(module, name) is not None, f"{module_name}.{name} did not resolve"
+
+
+@pytest.mark.parametrize(("mirror_module", "source_module"), MIRROR_MODULES)
+def test_mirror_modules_match_backing_api_exports(mirror_module, source_module):
+ assert exported_names(mirror_module) == exported_names(source_module)
+
+
+def test_top_level_ptodsl_exports_import_cleanly():
+ for name in ptodsl.__all__:
+ assert getattr(ptodsl, name) is not None
diff --git a/tests/api/test_scalar_api.py b/tests/api/test_scalar_api.py
new file mode 100644
index 00000000..26667f14
--- /dev/null
+++ b/tests/api/test_scalar_api.py
@@ -0,0 +1,49 @@
+from mlir.dialects import arith
+from mlir.ir import IntegerType
+
+from ptodsl import scalar
+
+
+class _Box:
+ def __init__(self, raw):
+ self.raw = raw
+
+
+def _i32_const(value):
+ return scalar.Value(arith.ConstantOp(IntegerType.get_signless(32), value).result)
+
+
+def test_scalar_dtype_aliases_resolve_inside_mlir_context(mlir_ctx):
+ assert repr(scalar.float32) == "float32"
+ assert str(scalar.resolve_type(scalar.bool)) == "i1"
+ assert str(scalar.resolve_type(scalar.float16)) == "f16"
+ assert str(scalar.resolve_type(scalar.float32)) == "f32"
+ assert str(scalar.resolve_type(scalar.int16)) == "i16"
+ assert str(scalar.resolve_type(scalar.int32)) == "i32"
+ assert str(scalar.resolve_type(scalar.uint32)) == "ui32"
+
+
+def test_wrap_value_and_unwrap_handle_nested_raw_wrappers():
+ value = scalar.Value("inner")
+
+ assert scalar.wrap_value(value) is value
+ assert isinstance(scalar.wrap_value("raw"), scalar.Value)
+ assert scalar._unwrap(_Box(_Box(value))) == "inner"
+
+
+def test_scalar_arithmetic_and_helper_builders_emit_expected_arith_ops(mlir_ctx):
+ lhs = _i32_const(8)
+ rhs = _i32_const(2)
+
+ assert "arith.addi" in str((lhs + rhs).raw.owner)
+ assert "arith.subi" in str((lhs - rhs).raw.owner)
+ assert "arith.muli" in str((lhs * rhs).raw.owner)
+ assert "arith.divsi" in str((lhs // rhs).raw.owner)
+ assert "arith.remsi" in str((lhs % rhs).raw.owner)
+ assert "arith.cmpi" in str((lhs < rhs).raw.owner)
+ assert "arith.index_cast" in str(scalar.index_cast(lhs).raw.owner)
+ assert "arith.ceildivsi" in str(scalar.ceil_div(lhs, rhs).raw.owner)
+ assert "arith.minui" in str(scalar.min_u(lhs, rhs).raw.owner)
+ assert "arith.select" in str(
+ scalar.select(scalar.eq(lhs, rhs), lhs, rhs).raw.owner
+ )
diff --git a/tests/api/test_synchronization_api.py b/tests/api/test_synchronization_api.py
new file mode 100644
index 00000000..45c11e62
--- /dev/null
+++ b/tests/api/test_synchronization_api.py
@@ -0,0 +1,82 @@
+from types import SimpleNamespace
+
+from ptodsl.api import synchronization as sync
+
+
+def test_barrier_normalizes_string_sync_ops(monkeypatch):
+ seen = {}
+
+ monkeypatch.setattr(sync._pto, "TDMA", "TDMA", raising=False)
+ monkeypatch.setattr(sync._pto, "barrier", lambda op: seen.setdefault("op", op))
+
+ sync.barrier("dma")
+
+ assert seen["op"] == "TDMA"
+
+
+def test_barrier_sync_normalizes_string_sync_ops(monkeypatch):
+ seen = {}
+
+ monkeypatch.setattr(sync._pto, "TVEC", sync._pto.SyncOpType.TVEC, raising=False)
+ monkeypatch.setattr(
+ sync._pto,
+ "SyncOpTypeAttr",
+ SimpleNamespace(get=lambda value: f"sync:{value}"),
+ )
+ monkeypatch.setattr(
+ sync._pto, "barrier_sync", lambda op: seen.setdefault("op", op)
+ )
+
+ sync.barrier_sync("vec")
+
+ assert seen["op"] == f"sync:{sync._pto.SyncOpType.TVEC}"
+
+
+def test_record_event_expands_sequence_event_ids(monkeypatch):
+ calls = []
+
+ monkeypatch.setattr(sync._pto, "TDMA", "TDMA", raising=False)
+ monkeypatch.setattr(sync._pto, "TVEC", "TVEC", raising=False)
+ monkeypatch.setattr(sync._pto, "EVENT_ID1", "EVENT_ID1", raising=False)
+ monkeypatch.setattr(sync._pto, "EVENT_ID2", "EVENT_ID2", raising=False)
+ monkeypatch.setattr(sync._pto, "record_event", lambda *args: calls.append(args))
+
+ sync.record_event("dma", "vec", [1, 2])
+
+ assert calls == [
+ ("TDMA", "TVEC", "EVENT_ID1"),
+ ("TDMA", "TVEC", "EVENT_ID2"),
+ ]
+
+
+def test_wait_event_rejects_invalid_event_ids(monkeypatch):
+ monkeypatch.setattr(sync._pto, "TDMA", "TDMA", raising=False)
+ monkeypatch.setattr(sync._pto, "TVEC", "TVEC", raising=False)
+
+ try:
+ sync.wait_event("dma", "vec", 8)
+ except ValueError as exc:
+ assert "event_id must be in range [0, 7]" in str(exc)
+ else:
+ raise AssertionError("wait_event accepted an out-of-range event_id")
+
+
+def test_record_wait_pair_calls_record_and_wait_once(monkeypatch):
+ calls = []
+
+ monkeypatch.setattr(sync._pto, "TDMA", "TDMA", raising=False)
+ monkeypatch.setattr(sync._pto, "TVEC", "TVEC", raising=False)
+ monkeypatch.setattr(sync._pto, "EVENT_ID0", "EVENT_ID0", raising=False)
+ monkeypatch.setattr(
+ sync._pto, "record_event", lambda *args: calls.append(("record", args))
+ )
+ monkeypatch.setattr(
+ sync._pto, "wait_event", lambda *args: calls.append(("wait", args))
+ )
+
+ sync.record_wait_pair("dma", "vec")
+
+ assert calls == [
+ ("record", ("TDMA", "TVEC", "EVENT_ID0")),
+ ("wait", ("TDMA", "TVEC", "EVENT_ID0")),
+ ]
diff --git a/tests/api/test_tile_api.py b/tests/api/test_tile_api.py
new file mode 100644
index 00000000..a561e75b
--- /dev/null
+++ b/tests/api/test_tile_api.py
@@ -0,0 +1,162 @@
+from types import SimpleNamespace
+
+from ptodsl.api import tile as tile_api
+
+
+class _Box:
+ def __init__(self, raw):
+ self.raw = raw
+
+
+CALL_CASES = [
+ ("mov", lambda: tile_api.mov(_Box("src"), _Box("dst")), "TMovOp", (None, "src", "dst"), {}),
+ ("add", lambda: tile_api.add(_Box("lhs"), _Box("rhs"), _Box("out")), "TAddOp", ("lhs", "rhs", "out"), {}),
+ ("sub", lambda: tile_api.sub(_Box("lhs"), _Box("rhs"), _Box("out")), "TSubOp", ("lhs", "rhs", "out"), {}),
+ ("div", lambda: tile_api.div(_Box("lhs"), _Box("rhs"), _Box("out")), "TDivOp", ("lhs", "rhs", "out"), {}),
+ ("mul", lambda: tile_api.mul(_Box("lhs"), _Box("rhs"), _Box("out")), "TMulOp", ("lhs", "rhs", "out"), {}),
+ ("or_", lambda: tile_api.or_(_Box("lhs"), _Box("rhs"), _Box("out")), "TOrOp", ("lhs", "rhs", "out"), {}),
+ ("exp", lambda: tile_api.exp(_Box("inp"), _Box("out")), "TExpOp", ("inp", "out"), {}),
+ ("log", lambda: tile_api.log(_Box("inp"), _Box("out")), "TLogOp", ("inp", "out"), {}),
+ ("relu", lambda: tile_api.relu(_Box("inp"), _Box("out")), "TReluOp", ("inp", "out"), {}),
+ ("abs", lambda: tile_api.abs(_Box("inp"), _Box("out")), "TAbsOp", ("inp", "out"), {}),
+ ("sqrt", lambda: tile_api.sqrt(_Box("inp"), _Box("out")), "TSqrtOp", ("inp", "out"), {}),
+ ("rsqrt", lambda: tile_api.rsqrt(_Box("inp"), _Box("out")), "TRsqrtOp", ("inp", "out"), {}),
+ ("reciprocal", lambda: tile_api.reciprocal(_Box("inp"), _Box("out")), "TRecipOp", ("inp", "out"), {}),
+ ("matmul", lambda: tile_api.matmul(_Box("lhs"), _Box("rhs"), _Box("out")), "TMatmulOp", (None, "lhs", "rhs", "out"), {}),
+ ("matmul_bias", lambda: tile_api.matmul_bias(_Box("lhs"), _Box("rhs"), _Box("bias"), _Box("out")), "TMatmulBiasOp", (None, "lhs", "rhs", "bias", "out"), {}),
+ ("matmul_acc", lambda: tile_api.matmul_acc(_Box("acc"), _Box("lhs"), _Box("rhs"), _Box("out")), "TMatmulAccOp", (None, "acc", "lhs", "rhs", "out"), {}),
+ ("row_sum", lambda: tile_api.row_sum(_Box("src"), _Box("tmp"), _Box("dst")), "TRowSumOp", (), {"src": "src", "tmp": "tmp", "dst": "dst"}),
+ ("row_min", lambda: tile_api.row_min(_Box("src"), _Box("tmp"), _Box("dst")), "TRowMinOp", (), {"src": "src", "tmp": "tmp", "dst": "dst"}),
+ ("row_max", lambda: tile_api.row_max(_Box("src"), _Box("tmp"), _Box("dst")), "TRowMaxOp", (), {"src": "src", "tmp": "tmp", "dst": "dst"}),
+ ("row_expand", lambda: tile_api.row_expand(_Box("src"), _Box("dst")), "TRowExpandOp", (), {"src": "src", "dst": "dst"}),
+ ("row_expand_sub", lambda: tile_api.row_expand_sub(_Box("src0"), _Box("src1"), _Box("dst")), "TRowExpandSubOp", (), {"src0": "src0", "src1": "src1", "dst": "dst"}),
+ ("row_expand_div", lambda: tile_api.row_expand_div(_Box("src0"), _Box("src1"), _Box("dst")), "TRowExpandDivOp", (), {"src0": "src0", "src1": "src1", "dst": "dst"}),
+ ("row_expand_mul", lambda: tile_api.row_expand_mul(_Box("src0"), _Box("src1"), _Box("dst")), "TRowExpandMulOp", (), {"src0": "src0", "src1": "src1", "dst": "dst"}),
+ ("col_sum", lambda: tile_api.col_sum(_Box("src"), _Box("tmp"), _Box("dst"), is_binary=False), "TColSumOp", (), {"src": "src", "tmp": "tmp", "dst": "dst", "isBinary": "false"}),
+ ("col_min", lambda: tile_api.col_min(_Box("src"), _Box("dst")), "TColMinOp", (), {"src": "src", "dst": "dst"}),
+ ("col_max", lambda: tile_api.col_max(_Box("src"), _Box("dst")), "TColMaxOp", (), {"src": "src", "dst": "dst"}),
+ ("col_expand", lambda: tile_api.col_expand(_Box("src"), _Box("dst")), "TColExpandOp", (), {"src": "src", "dst": "dst"}),
+ ("sort32", lambda: tile_api.sort32(_Box("src"), _Box("dst"), _Box("idx")), "TSort32Op", ("src", "dst", "idx"), {}),
+]
+
+SPECIAL_CASES = {"gather", "extract", "mrgsort", "subset"}
+
+
+def test_tile_export_coverage_is_complete():
+ covered = {name for name, *_ in CALL_CASES} | SPECIAL_CASES
+ assert covered == set(tile_api.__all__)
+
+
+def test_call_based_tile_ops_dispatch_to_the_expected_underlying_builders(
+ monkeypatch, mlir_ctx
+):
+ seen = []
+
+ def fake_call(op, *args, **kwargs):
+ seen.append(
+ (
+ op,
+ tuple(tile_api._unwrap(arg) for arg in args),
+ {name: tile_api._unwrap(value) for name, value in kwargs.items()},
+ )
+ )
+
+ monkeypatch.setattr(tile_api, "_call", fake_call)
+
+ for _, invoker, expected_op_name, expected_args, expected_kwargs in CALL_CASES:
+ invoker()
+ op, args, kwargs = seen.pop(0)
+ assert op is getattr(tile_api._pto, expected_op_name)
+ assert args == expected_args
+ assert {key: str(value) for key, value in kwargs.items()} == expected_kwargs
+
+
+def test_gather_supports_indices_and_mask_pattern(monkeypatch):
+ seen = []
+
+ def fake_call(op, *args, **kwargs):
+ seen.append(
+ (
+ op,
+ tuple(tile_api._unwrap(arg) for arg in args),
+ {name: tile_api._unwrap(value) for name, value in kwargs.items()},
+ )
+ )
+
+ monkeypatch.setattr(tile_api, "_call", fake_call)
+ monkeypatch.setattr(tile_api._pto, "MaskPattern", SimpleNamespace(PAT_ALL="PAT_ALL"))
+ monkeypatch.setattr(
+ tile_api._pto,
+ "MaskPatternAttr",
+ SimpleNamespace(get=lambda value: f"mask:{value}"),
+ )
+
+ tile_api.gather(_Box("src"), _Box("dst"), indices=_Box("idx"))
+ tile_api.gather(_Box("src"), _Box("dst"), mask_pattern="PAT_ALL")
+
+ assert seen[0] == (
+ tile_api._pto.TGatherOp,
+ ("src", "dst"),
+ {"indices": "idx"},
+ )
+ assert seen[1] == (
+ tile_api._pto.TGatherOp,
+ ("src", "dst"),
+ {"maskPattern": "mask:PAT_ALL"},
+ )
+
+
+def test_extract_unwraps_source_and_indices(monkeypatch):
+ seen = {}
+
+ def fake_extract(**kwargs):
+ seen.update({name: tile_api._unwrap(value) for name, value in kwargs.items()})
+
+ monkeypatch.setattr(tile_api._pto, "TExtractOp", fake_extract)
+
+ tile_api.extract(_Box("src"), _Box("row"), _Box("col"), _Box("dst"))
+
+ assert seen == {
+ "src": "src",
+ "indexRow": "row",
+ "indexCol": "col",
+ "dst": "dst",
+ }
+
+
+def test_mrgsort_casts_block_length_and_wraps_src_dst_lists(monkeypatch, mlir_ctx):
+ seen = {}
+
+ def fake_index_cast(dtype, value):
+ seen["index_cast"] = (str(dtype), value)
+ return SimpleNamespace(result=f"cast:{value}")
+
+ def fake_mrgsort(**kwargs):
+ seen["mrgsort"] = kwargs
+
+ monkeypatch.setattr(tile_api._arith, "IndexCastOp", fake_index_cast)
+ monkeypatch.setattr(tile_api._pto, "TMrgSortOp", fake_mrgsort)
+
+ tile_api.mrgsort(_Box("src"), _Box("dst"), _Box("block"))
+
+ assert seen["index_cast"][1] == "block"
+ assert seen["mrgsort"] == {
+ "srcs": ["src"],
+ "dsts": ["dst"],
+ "blockLen": "cast:block",
+ }
+
+
+def test_subset_unwraps_offsets_and_returns_underlying_result(monkeypatch):
+ seen = {}
+
+ def fake_subset(source, offsets, sizes):
+ seen["call"] = (source, offsets, sizes)
+ return "subset-result"
+
+ monkeypatch.setattr(tile_api._pto, "subset", fake_subset)
+
+ result = tile_api.subset(_Box("src"), [_Box("r0"), _Box("r1")], [32, 64])
+
+ assert result == "subset-result"
+ assert seen["call"] == ("src", ["r0", "r1"], [32, 64])
diff --git a/tests/api/test_type_def_api.py b/tests/api/test_type_def_api.py
new file mode 100644
index 00000000..dca05848
--- /dev/null
+++ b/tests/api/test_type_def_api.py
@@ -0,0 +1,76 @@
+import pytest
+from mlir.dialects import pto as mlir_pto
+
+from ptodsl import scalar as s
+from ptodsl.api import type_def
+
+
+def test_basic_type_builders_render_expected_types(mlir_ctx):
+ assert str(type_def.PtrType(type_def.float32)) == "!pto.ptr"
+ assert str(type_def.PtrType(type_def.float32, memory_space="VEC")) == "!pto.ptr"
+ assert str(type_def.VRegType(64, type_def.float32)) == "!pto.vreg<64xf32>"
+ assert str(type_def.MaskType()) == "!pto.mask"
+ assert str(type_def.AlignType()) == "!pto.align"
+ assert str(type_def.TensorType(rank=2, dtype=type_def.float32)) == "!pto.tensor_view"
+ assert (
+ str(type_def.SubTensorType(shape=[32, 32], dtype=type_def.float32))
+ == "!pto.partition_tensor_view<32x32xf32>"
+ )
+
+
+def test_tile_buffer_defaults_match_explicit_configs(mlir_ctx):
+ vec_default = type_def.TileBufType(
+ shape=[32, 32],
+ dtype=type_def.float32,
+ memory_space="VEC",
+ )
+ vec_explicit = type_def.TileBufType(
+ shape=[32, 32],
+ dtype=type_def.float32,
+ memory_space="VEC",
+ config=type_def.TileBufConfig(),
+ )
+
+ mat_default = type_def.TileBufType(
+ shape=[1, 64],
+ dtype=type_def.float32,
+ memory_space="MAT",
+ )
+ mat_explicit = type_def.TileBufType(
+ shape=[1, 64],
+ dtype=type_def.float32,
+ memory_space="MAT",
+ config=type_def.TileBufConfig(
+ blayout="RowMajor",
+ slayout="NoneBox",
+ s_fractal_size=mlir_pto.TileConfig.fractalABSize,
+ ),
+ )
+
+ assert str(vec_default) == str(vec_explicit)
+ assert str(mat_default) == str(mat_explicit)
+
+
+def test_static_type_builders_reject_dynamic_values(mlir_ctx):
+ dynamic = s.Value(object())
+
+ with pytest.raises(TypeError, match="TensorType.rank"):
+ type_def.TensorType(rank=dynamic, dtype=type_def.float32)
+
+ with pytest.raises(TypeError, match="SubTensorType.shape"):
+ type_def.SubTensorType(shape=[32, dynamic], dtype=type_def.float32)
+
+ with pytest.raises(TypeError, match="TileBufType.shape"):
+ type_def.TileBufType(
+ shape=[1, dynamic],
+ dtype=type_def.float32,
+ memory_space="VEC",
+ )
+
+ with pytest.raises(TypeError, match="TileBufType.valid_shape"):
+ type_def.TileBufType(
+ shape=[1, 64],
+ valid_shape=[1, dynamic],
+ dtype=type_def.float32,
+ memory_space="VEC",
+ )
diff --git a/tests/frontend/test_caller_gen.py b/tests/frontend/test_caller_gen.py
index 47e01a02..5a2e50ea 100644
--- a/tests/frontend/test_caller_gen.py
+++ b/tests/frontend/test_caller_gen.py
@@ -65,6 +65,33 @@ def mixed_kernel(data: "ptr_i8", count: "i64_type", idx: "index_dtype") -> None:
)
+def test_generate_caller_cpp_maps_mxfp8_pointer_and_scalar_types():
+ def mixed_mxfp8_kernel(
+ lhs: "ptr_e5m2",
+ lhs_scale: "ptr_e8m0",
+ alpha: "e4m3_type",
+ ) -> None:
+ return None
+
+ wrapper = JitWrapper(mixed_mxfp8_kernel, meta_data=lambda: {}, block_dim=4)
+ wrapper._arg_types = [
+ _FakeType("!pto.ptr"),
+ _FakeType("!pto.ptr"),
+ _FakeType("f8E4M3FN"),
+ ]
+
+ caller_cpp = wrapper._generate_caller_cpp("generated.cpp")
+
+ assert (
+ 'extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *lhs, '
+ "uint8_t *lhs_scale, uint8_t alpha)"
+ ) in caller_cpp
+ assert (
+ "mixed_mxfp8_kernel<<>>((float8_e5m2_t *)lhs, "
+ "(float8_e8m0_t *)lhs_scale, alpha);"
+ ) in caller_cpp
+
+
def test_generate_caller_cpp_for_dynamic_1d_add_signature():
def vec_add_1d_dynamic(
arg0: "ptr_type",
diff --git a/tests/frontend/test_constexpr_frontend.py b/tests/frontend/test_constexpr_frontend.py
new file mode 100644
index 00000000..84531d32
--- /dev/null
+++ b/tests/frontend/test_constexpr_frontend.py
@@ -0,0 +1,157 @@
+import pathlib
+from types import SimpleNamespace
+
+import pytest
+
+from ptodsl import Constexpr, const_expr, pto, range_constexpr, to_ir_module
+from ptodsl import scalar as s
+from ptodsl.compiler.jit import JitWrapper
+
+
+const = s.const
+
+
+class _FakeType:
+ def __init__(self, text):
+ self._text = text
+
+ def __str__(self):
+ return self._text
+
+
+def test_to_ir_module_returns_specializer_and_prunes_constexpr_signature():
+ seen = {}
+
+ def meta_data(TILE, UNROLL=2):
+ seen["tile"] = TILE
+ seen["unroll"] = UNROLL
+ dtype = pto.float32
+ return {
+ "index_dtype": pto.int32,
+ "tile_type": pto.TileBufType(
+ shape=[1, TILE // 2],
+ valid_shape=[1, TILE // 2],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def templated_kernel(
+ n: "index_dtype",
+ TILE: Constexpr[int],
+ UNROLL: Constexpr[int] = 2,
+ ) -> None:
+ with pto.vector_section():
+ if const_expr(TILE % 128 == 0):
+ for _ in range_constexpr(UNROLL):
+ pto.alloc_tile(tile_type)
+ else:
+ pto.alloc_tile(tile_type)
+
+ assert callable(templated_kernel)
+
+ module = templated_kernel(TILE=128, UNROLL=3)
+ text = str(module)
+
+ assert seen == {"tile": 128, "unroll": 3}
+ assert "func.func @templated_kernel(%arg0: i32)" in text
+ assert "scf.if" not in text
+ assert "scf.for" not in text
+ assert text.count("pto.alloc_tile") == 3
+
+
+def test_to_ir_module_rejects_missing_constexpr_arguments():
+ def meta_data(TILE):
+ return {"index_dtype": pto.int32}
+
+ @to_ir_module(meta_data=meta_data)
+ def templated_kernel(n: "index_dtype", TILE: Constexpr[int]) -> None:
+ return None
+
+ with pytest.raises(TypeError, match="Missing required constexpr arguments: TILE"):
+ templated_kernel()
+
+
+def test_constexpr_helpers_reject_dynamic_values():
+ dynamic = s.Value(object())
+
+ with pytest.raises(TypeError, match="const_expr"):
+ const_expr(dynamic)
+
+ with pytest.raises(TypeError, match="range_constexpr"):
+ list(range_constexpr(dynamic))
+
+
+def test_type_builders_reject_dynamic_static_dimensions():
+ with pytest.raises(TypeError, match="TensorType.rank"):
+ pto.TensorType(rank=s.Value(object()), dtype=pto.float32)
+
+ with pytest.raises(TypeError, match="TileBufType.shape"):
+ pto.TileBufType(
+ shape=[1, s.Value(object())],
+ valid_shape=[1, 1],
+ dtype=pto.float32,
+ memory_space="VEC",
+ )
+
+
+def test_jit_prunes_constexpr_parameters_from_generated_caller_cpp():
+ def mixed_kernel(
+ data: "ptr_i8",
+ count: "i64_type",
+ TILE: Constexpr[int],
+ ) -> None:
+ return None
+
+ wrapper = JitWrapper(mixed_kernel, meta_data=lambda TILE: {}, block_dim=7)
+ wrapper._arg_types = [
+ _FakeType("!pto.ptr"),
+ _FakeType("i64"),
+ ]
+
+ caller_cpp = wrapper._generate_caller_cpp("generated.cpp")
+
+ assert "TILE" not in caller_cpp
+ assert (
+ 'extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *data, '
+ "int64_t count)"
+ ) in caller_cpp
+ assert "mixed_kernel<<>>((int8_t *)data, count);" in caller_cpp
+
+
+def test_jit_reuses_specialization_for_identical_constexpr_bindings(monkeypatch):
+ class _FakeLib:
+ def __init__(self):
+ self.calls = []
+
+ def call_kernel(self, *args):
+ self.calls.append(args)
+
+ def templated_kernel(n: "index_dtype", TILE: Constexpr[int]) -> None:
+ return None
+
+ wrapper = JitWrapper(templated_kernel, meta_data=lambda TILE: {}, block_dim=4)
+ build_calls = []
+ built_libs = []
+
+ def fake_build(constexpr_bindings):
+ build_calls.append(dict(constexpr_bindings))
+ lib = _FakeLib()
+ built_libs.append(lib)
+ return SimpleNamespace(
+ arg_types=[_FakeType("i32")],
+ lib=lib,
+ lib_path=pathlib.Path("/tmp/fake.so"),
+ output_dir=pathlib.Path("/tmp"),
+ )
+
+ monkeypatch.setattr(wrapper, "_build", fake_build)
+
+ wrapper(7, TILE=32, stream_ptr=0)
+ wrapper(9, TILE=32, stream_ptr=0)
+ wrapper(11, TILE=64, stream_ptr=0)
+
+ assert build_calls == [{"TILE": 32}, {"TILE": 64}]
+ assert len(built_libs[0].calls) == 2
+ assert len(built_libs[1].calls) == 1
diff --git a/tests/frontend/test_micro_frontend.py b/tests/frontend/test_micro_frontend.py
new file mode 100644
index 00000000..b1e3bf08
--- /dev/null
+++ b/tests/frontend/test_micro_frontend.py
@@ -0,0 +1,77 @@
+from mlir.dialects import pto as mlir_pto
+from mlir.ir import Context, IndexType, Location
+
+from ptodsl import micro, pto, to_ir_module
+from ptodsl.api._micro_registry import MICRO_OPS
+from tests._vpto_manifest import load_vpto_manifest
+
+
+def test_micro_type_builders_support_memory_space_strings():
+ with Context() as ctx, Location.unknown():
+ mlir_pto.register_dialect(ctx, load=True)
+
+ assert str(pto.PtrType(pto.float32, memory_space="VEC")) == "!pto.ptr"
+ assert str(pto.VRegType(64, pto.float32)) == "!pto.vreg<64xf32>"
+ assert str(pto.MaskType()) == "!pto.mask"
+ assert str(pto.AlignType()) == "!pto.align"
+
+
+def test_pure_micro_kernel_emits_vpto_ops():
+ def meta_data():
+ return {
+ "ptr_t": pto.PtrType(pto.float32, memory_space="VEC"),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def micro_kernel(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ mask = micro.pset_b32(pto.MaskType(), "PAT_ALL")
+ vec = micro.vlds(pto.VRegType(64, pto.float32), src, offset)
+ micro.vsts(vec, dst, offset, mask)
+
+ text = str(micro_kernel)
+
+ assert "func.func @micro_kernel" in text
+ assert "pto.pset_b32" in text
+ assert "pto.vlds" in text
+ assert "pto.vsts" in text
+ assert "!pto.vreg<64xf32>" in text
+
+
+def test_mixed_tile_and_micro_kernel_emit_together():
+ def meta_data():
+ dtype = pto.float32
+ return {
+ "ptr_t": pto.PtrType(dtype, memory_space="VEC"),
+ "index_t": IndexType.get(),
+ "tile_type": pto.TileBufType(
+ shape=[1, 64],
+ valid_shape=[1, 64],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def mixed_kernel(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ pto.alloc_tile(tile_type)
+ mask = pto.pset_b32(pto.MaskType(), "PAT_ALL")
+ vec = pto.vlds(pto.VRegType(64, pto.float32), src, offset)
+ pto.vsts(vec, dst, offset, mask)
+
+ text = str(mixed_kernel)
+
+ assert "pto.alloc_tile" in text
+ assert "pto.vlds" in text
+ assert "pto.vsts" in text
+
+
+def test_micro_registry_matches_manifest_inventory():
+ manifest = load_vpto_manifest()
+ implemented = {
+ op["mnemonic"] for op in manifest["ops"] if op.get("status") == "implemented"
+ }
+
+ assert set(MICRO_OPS) == implemented
+ assert all(hasattr(micro, name) for name in implemented)
+ assert all(hasattr(pto, name) for name in implemented)
diff --git a/tests/frontend/test_mxfp8_frontend.py b/tests/frontend/test_mxfp8_frontend.py
new file mode 100644
index 00000000..03b0a70e
--- /dev/null
+++ b/tests/frontend/test_mxfp8_frontend.py
@@ -0,0 +1,55 @@
+import types
+
+import ptodsl.language as pto
+
+
+class _StubType:
+ @staticmethod
+ def get():
+ return object()
+
+
+def test_mxfp8_family_uses_e5m2_data_and_e8m0_scale(monkeypatch):
+ stub_ir = types.SimpleNamespace(
+ Float8E5M2Type=_StubType,
+ Float8E8M0FNUType=_StubType,
+ Float8E4M3FNType=_StubType,
+ )
+ monkeypatch.setattr(pto, "mlir_ir", stub_ir)
+
+ mx = pto.mxfp8
+
+ assert mx.lhs is not None
+ assert mx.rhs is not None
+ assert mx.data is not None
+ assert mx.scale is not None
+ assert mx.acc is not None
+ assert mx.scale_k(64) == 2
+
+
+def test_float8_aliases_accept_common_mlir_ctor_names(monkeypatch):
+ stub_ir = types.SimpleNamespace(
+ Float8E4M3FNType=_StubType,
+ Float8E5M2Type=_StubType,
+ Float8E8M0FNUType=_StubType,
+ )
+ monkeypatch.setattr(pto, "mlir_ir", stub_ir)
+
+ assert pto.fp8_e4m3 is not None
+ assert pto.fp8_e5m2 is not None
+ assert pto.fp8_e8m0 is not None
+
+
+def test_make_mxfp8_accepts_mixed_lhs_rhs_variants(monkeypatch):
+ stub_ir = types.SimpleNamespace(
+ Float8E4M3FNType=_StubType,
+ Float8E5M2Type=_StubType,
+ Float8E8M0FNUType=_StubType,
+ )
+ monkeypatch.setattr(pto, "mlir_ir", stub_ir)
+
+ mx = pto.make_mxfp8(lhs="e4m3", rhs="e5m2")
+
+ assert mx.lhs is not None
+ assert mx.rhs is not None
+ assert mx.scale is not None
diff --git a/tests/frontend/test_pythonic_frontend.py b/tests/frontend/test_pythonic_frontend.py
new file mode 100644
index 00000000..283f5ea0
--- /dev/null
+++ b/tests/frontend/test_pythonic_frontend.py
@@ -0,0 +1,196 @@
+import re
+
+import pytest
+
+from mlir.dialects import pto as mlir_pto
+from mlir.ir import Context, IndexType, Location
+
+from ptodsl import micro, pto, tile, to_ir_module
+from ptodsl import scalar as s
+
+
+const = s.const
+
+
+class _Box:
+ def __init__(self, raw):
+ self.raw = raw
+
+
+def _normalized_pto_lines(text):
+ return [
+ re.sub(r"%[A-Za-z0-9._]+", "%v", line.strip())
+ for line in text.splitlines()
+ if "pto." in line
+ ]
+
+
+def test_pythonic_type_factories_match_low_level_builders():
+ with Context() as ctx, Location.unknown():
+ mlir_pto.register_dialect(ctx, load=True)
+
+ assert str(pto.ptr(pto.float32)) == str(pto.PtrType(pto.float32))
+ assert str(pto.ptr(pto.float32, space="VEC")) == str(
+ pto.PtrType(pto.float32, memory_space="VEC")
+ )
+
+ tile_spec = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC")
+ assert str(tile_spec.raw_type) == str(
+ pto.TileBufType(
+ shape=[32, 32],
+ valid_shape=[32, 32],
+ dtype=pto.float32,
+ memory_space="VEC",
+ )
+ )
+
+
+def _meta_data():
+ dtype = pto.float32
+ return {
+ "ptr_t": pto.ptr(dtype),
+ "index_t": IndexType.get(),
+ "tensor_t": pto.TensorType(rank=2, dtype=dtype),
+ "subtensor_t": pto.SubTensorType(shape=[32, 32], dtype=dtype),
+ "tile_t": pto.TileBufType(
+ shape=[32, 32],
+ valid_shape=[32, 32],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ }
+
+
+def _build_pythonic_module():
+ def kernel(
+ src0: "ptr_t",
+ src1: "ptr_t",
+ dst: "ptr_t",
+ rows: "index_t",
+ cols: "index_t",
+ ) -> None:
+ c0 = const(0)
+
+ tv0 = pto.make_tensor(src0, shape=[rows, cols], dtype=pto.float32)
+ tv1 = pto.make_tensor(src1, shape=[rows, cols], dtype=pto.float32)
+ tv2 = pto.make_tensor(dst, shape=[rows, cols], dtype=pto.float32)
+
+ sv0 = tv0.slice([c0, c0], [32, 32])
+ sv1 = tv1.slice([c0, c0], [32, 32])
+ sv2 = tv2.slice([c0, c0], [32, 32])
+
+ tile_spec = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC")
+ with pto.vector_section():
+ tb0 = tile_spec.alloc()
+ tb1 = tile_spec.alloc()
+ tb2 = tile_spec.alloc()
+ tb0.load_from(sv0)
+ tb1.load_from(sv1)
+ tile.add(tb0, tb1, tb2)
+ tb2.store_to(sv2)
+
+ kernel.__name__ = "vector_add_frontend"
+ return to_ir_module(meta_data=_meta_data)(kernel)
+
+
+def _build_low_level_module():
+ def kernel(
+ src0: "ptr_t",
+ src1: "ptr_t",
+ dst: "ptr_t",
+ rows: "index_t",
+ cols: "index_t",
+ ) -> None:
+ c0 = const(0)
+ c1 = const(1)
+ c32 = const(32)
+
+ tv0 = pto.as_tensor(tensor_t, ptr=src0, shape=[rows, cols], strides=[cols, c1])
+ tv1 = pto.as_tensor(tensor_t, ptr=src1, shape=[rows, cols], strides=[cols, c1])
+ tv2 = pto.as_tensor(tensor_t, ptr=dst, shape=[rows, cols], strides=[cols, c1])
+
+ sv0 = pto.slice_view(subtensor_t, source=tv0, offsets=[c0, c0], sizes=[c32, c32])
+ sv1 = pto.slice_view(subtensor_t, source=tv1, offsets=[c0, c0], sizes=[c32, c32])
+ sv2 = pto.slice_view(subtensor_t, source=tv2, offsets=[c0, c0], sizes=[c32, c32])
+
+ with pto.vector_section():
+ tb0 = pto.alloc_tile(tile_t)
+ tb1 = pto.alloc_tile(tile_t)
+ tb2 = pto.alloc_tile(tile_t)
+ pto.load(sv0, tb0)
+ pto.load(sv1, tb1)
+ tile.add(tb0, tb1, tb2)
+ pto.store(tb2, sv2)
+
+ kernel.__name__ = "vector_add_frontend"
+ return to_ir_module(meta_data=_meta_data)(kernel)
+
+
+def test_pythonic_tensor_and_tile_flow_matches_low_level_ir():
+ pythonic_module = _build_pythonic_module()
+ low_level_module = _build_low_level_module()
+
+ pythonic_text = str(pythonic_module)
+ low_level_text = str(low_level_module)
+
+ assert _normalized_pto_lines(pythonic_text) == _normalized_pto_lines(low_level_text)
+ text = pythonic_text
+ assert "pto.make_tensor_view" in text
+ assert "pto.partition_view" in text
+ assert "strides = [%arg4, %c1]" in text
+ assert text.count("pto.alloc_tile") == 3
+ assert "pto.tload" in text
+ assert "pto.tstore" in text
+
+
+def test_tensor_view_slice_requires_static_shape_for_dynamic_sizes():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ def kernel(src: "ptr_t", total_elements: "index_t") -> None:
+ tv = pto.make_tensor(src, shape=[total_elements], dtype=pto.float32)
+ tv.slice([0], [total_elements])
+
+ with pytest.raises(TypeError, match="requires `static_shape=`"):
+ to_ir_module(meta_data=meta_data)(kernel)
+
+
+def test_tensor_view_slice_accepts_dynamic_sizes_with_static_shape():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ def kernel(src: "ptr_t", total_elements: "index_t", tile_extent: "index_t") -> None:
+ tv = pto.make_tensor(src, shape=[total_elements], dtype=pto.float32)
+ tv.slice([0], [tile_extent], static_shape=[128])
+
+ module = to_ir_module(meta_data=meta_data)(kernel)
+ text = str(module)
+
+ assert "pto.partition_view" in text
+ assert "!pto.partition_tensor_view<128xf32>" in text
+
+
+def test_micro_ops_accept_objects_with_raw_values():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32, space="VEC"),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ mask = micro.pset_b32(pto.MaskType(), "PAT_ALL")
+ vec = micro.vlds(pto.VRegType(64, pto.float32), _Box(src), offset)
+ micro.vsts(vec, _Box(dst), offset, _Box(mask))
+
+ text = str(kernel)
+
+ assert "pto.pset_b32" in text
+ assert "pto.vlds" in text
+ assert "pto.vsts" in text
diff --git a/tests/frontend/test_sync_frontend.py b/tests/frontend/test_sync_frontend.py
new file mode 100644
index 00000000..9e2501df
--- /dev/null
+++ b/tests/frontend/test_sync_frontend.py
@@ -0,0 +1,24 @@
+from mlir.ir import Context, Location
+from mlir.dialects import pto as mlir_pto
+
+from ptodsl import pto, to_ir_module
+
+
+def test_barrier_sync_emits_pto_barrier_sync_op():
+ @to_ir_module(meta_data=lambda: {})
+ def sync_kernel() -> None:
+ pto.barrier_sync("vec")
+ pto.barrier_sync(mlir_pto.SyncOpTypeAttr.get(mlir_pto.SyncOpType.TMATMUL))
+
+ text = str(sync_kernel)
+
+ assert "pto.barrier_sync[]" in text
+ assert "pto.barrier_sync[]" in text
+
+
+def test_barrier_sync_helper_accepts_string_and_attr():
+ with Context() as ctx, Location.unknown():
+ mlir_pto.register_dialect(ctx, load=True)
+
+ attr = mlir_pto.SyncOpTypeAttr.get(mlir_pto.SyncOpType.TVEC)
+ assert str(attr) == "#pto.sync_op_type"
diff --git a/tests/regression/test_api_regression.py b/tests/regression/test_api_regression.py
new file mode 100644
index 00000000..cf24859e
--- /dev/null
+++ b/tests/regression/test_api_regression.py
@@ -0,0 +1,147 @@
+from types import SimpleNamespace
+
+from mlir.ir import IndexType
+
+from ptodsl import Constexpr, JitWrapper, const_expr, pto, tile, to_ir_module
+from ptodsl import scalar as s
+
+
+const = s.const
+
+
+class _FakeType:
+ def __init__(self, text):
+ self._text = text
+
+ def __str__(self):
+ return self._text
+
+
+def test_low_level_tensor_tile_flow_regression():
+ def meta_data():
+ dtype = pto.float32
+ return {
+ "ptr_t": pto.PtrType(dtype),
+ "tensor_t": pto.TensorType(rank=2, dtype=dtype),
+ "subtensor_t": pto.SubTensorType(shape=[32, 32], dtype=dtype),
+ "tile_t": pto.TileBufType(
+ shape=[32, 32],
+ valid_shape=[32, 32],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", rows: "index_t", cols: "index_t") -> None:
+ c0 = const(0)
+ c1 = const(1)
+ c32 = const(32)
+
+ src_view = pto.as_tensor(tensor_t, ptr=src, shape=[rows, cols], strides=[cols, c1])
+ dst_view = pto.as_tensor(tensor_t, ptr=dst, shape=[rows, cols], strides=[cols, c1])
+ src_tile = pto.slice_view(subtensor_t, source=src_view, offsets=[c0, c0], sizes=[c32, c32])
+ dst_tile = pto.slice_view(subtensor_t, source=dst_view, offsets=[c0, c0], sizes=[c32, c32])
+
+ with pto.vector_section():
+ tb = pto.alloc_tile(tile_t)
+ pto.load(src_tile, tb)
+ tile.add(tb, tb, tb)
+ pto.store(tb, dst_tile)
+
+ text = str(kernel)
+
+ assert "pto.make_tensor_view" in text
+ assert "pto.partition_view" in text
+ assert "pto.tadd" in text
+ assert "pto.tstore" in text
+
+
+def test_pythonic_tensor_tile_flow_regression():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", rows: "index_t", cols: "index_t") -> None:
+ src_view = pto.make_tensor(src, shape=[rows, cols], dtype=pto.float32)
+ dst_view = pto.make_tensor(dst, shape=[rows, cols], dtype=pto.float32)
+
+ with pto.vector_section():
+ buf = pto.make_tile_buffer(pto.float32, [32, 32], space="VEC").alloc()
+ buf.load_from(src_view.slice([0, 0], [32, 32]))
+ buf.store_to(dst_view.slice([0, 0], [32, 32]))
+
+ text = str(kernel)
+
+ assert "pto.make_tensor_view" in text
+ assert "strides = [%arg3, %c1]" in text
+ assert "pto.partition_view" in text
+ assert "pto.tload" in text
+ assert "pto.tstore" in text
+
+
+def test_mixed_tile_and_micro_regression():
+ def meta_data():
+ dtype = pto.float32
+ return {
+ "ptr_t": pto.ptr(dtype, space="VEC"),
+ "index_t": IndexType.get(),
+ "tile_t": pto.TileBufType(
+ shape=[1, 64],
+ valid_shape=[1, 64],
+ dtype=dtype,
+ memory_space="VEC",
+ ),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ with pto.vector_section():
+ pto.alloc_tile(tile_t)
+ mask = pto.pset_b32(pto.MaskType(), "PAT_ALL")
+ vec = pto.vlds(pto.VRegType(64, pto.float32), src, offset)
+ pto.vsts(vec, dst, offset, mask)
+
+ text = str(kernel)
+
+ assert "pto.alloc_tile" in text
+ assert "pto.pset_b32" in text
+ assert "pto.vlds" in text
+ assert "pto.vsts" in text
+
+
+def test_constexpr_specialization_and_jit_caller_regression():
+ def meta_data(TILE):
+ return {
+ "tile_t": pto.TileBufType(
+ shape=[1, TILE],
+ valid_shape=[1, TILE],
+ dtype=pto.float32,
+ memory_space="VEC",
+ ),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def kernel(TILE: Constexpr[int]) -> None:
+ if const_expr(TILE >= 64):
+ with pto.vector_section():
+ pto.alloc_tile(tile_t)
+
+ module = kernel(TILE=64)
+ text = str(module)
+
+ assert "func.func @kernel()" in text
+ assert "scf.if" not in text
+ assert text.count("pto.alloc_tile") == 1
+
+ wrapper = JitWrapper(kernel.__wrapped__, meta_data=meta_data, block_dim=4)
+ wrapper._arg_types = []
+ wrapper._runtime_params = []
+ caller_cpp = wrapper._generate_caller_cpp("generated.cpp")
+
+ assert "TILE" not in caller_cpp
+ assert 'extern "C" void call_kernel(uint32_t blockDim, void *stream)' in caller_cpp