diff --git a/README.md b/README.md
index bfffce44..245eb973 100644
--- a/README.md
+++ b/README.md
@@ -1,50 +1,194 @@
-
-
# PTO-DSL
-Pythonic interface and JIT compiler for [PTO-ISA](https://github.com/PTO-ISA/pto-isa)
-
-PTO-DSL provides a programming abstraction similar to [cuTile](https://docs.nvidia.com/cuda/cutile-python/), but native to [NPU](https://www.hiascend.com/).
+Python DSL for PTO-ISA kernels, with a public `pto` surface for tensor/tile
+authoring, a raw `micro` surface for direct PTO micro instructions, and an A5
+library layer that rewrites tile-style helpers in terms of those micro ops.
-**Key features:**
-- Automatic software pipelining without [manual synchronization](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/ascendcopapi/atlasascendc_api_07_0179.html)
-- Easily interface with [torch-npu](https://gitcode.com/ascend/pytorch)
-- Lightweight, open-source compiler stack using [PTO Assembler](https://github.com/zhangstevenunity/PTOAS)
+The current repo targets three authoring levels:
-## Installation
+- `ptodsl.pto`: ergonomic tensor, view, tile, sync, and control-flow helpers
+- `ptodsl.micro`: raw PTO micro instruction access such as `vlds`, `vadd`,
+ `vsts`, `pset_b32`, and vector register types
+- `ptodsl.lib.a5`: readable A5 helper implementations that show how tile-style
+ operations are written with PTO micro instructions
-See [docker/README.md](./docker/README.md) for full reproducible dependencies on NPU.
+## Recent Upgrade
-Then, install this lightweight DSL package itself:
+The recent PTODSL upgrade changed the repo in four important ways:
-```bash
-# install latest commit
-pip install git+https://github.com/huawei-csl/pto-dsl.git
+1. `pto.ptr(dtype, space=...)` is now the preferred pointer constructor for
+ explicit memory spaces such as `GM`, `VEC`, `LEFT`, `RIGHT`, and `ACC`.
+2. The public `pto` namespace now includes pythonic builders:
+ `make_tensor(...)`, `TensorView.slice(...)`, `make_tile_buffer(...)`,
+ `TileBufferSpec.alloc()`, and `TileBuffer.load_from()/store_to()`.
+3. The package root now exposes `ptodsl.micro` as the raw micro-op surface.
+4. The A5 library under [`ptodsl/lib/a5`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5)
+ is organized around tile helpers implemented with PTO micro instructions,
+ and selected pure-micro kernels are validated through PTOAS into
+ `llvm.hivm.*` intrinsics.
-# or stable tag
-pip install git+https://github.com/huawei-csl/pto-dsl.git@0.1.0
-```
+Detailed API notes are in
+[`docs/latest_api.md`](/Users/zhoubot/github/pto-org/pto-dsl/docs/latest_api.md).
+
+## Install
-For in-place development:
+PTODSL depends on the PTO dialect Python bindings from PTOAS and an MLIR Python
+environment. For a reproducible setup, start with
+[`docker/README.md`](/Users/zhoubot/github/pto-org/pto-dsl/docker/README.md).
+
+For local development:
```bash
-git clone https://github.com/huawei-csl/pto-dsl.git
+git clone https://github.com/PTO-ISA/pto-dsl.git
cd pto-dsl
pip install -e .
```
-## Usage
+Typical local testing in this repo also needs PTOAS and MLIR on `PYTHONPATH`,
+for example:
+
+```bash
+PYTHONPATH=/path/to/mlir_core:/path/to/PTOAS/install:/path/to/PTOAS/build/python \
+python -m pytest -q tests/frontend tests/regression
+```
+
+## Public API
+
+### 1. Pythonic `pto`
+
+Use `ptodsl.pto` for tensor/view/tile construction:
+
+```python
+from mlir.ir import IndexType
+from ptodsl import pto, tile, to_ir_module
+
+
+def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+
+@to_ir_module(meta_data=meta_data)
+def add_tile(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t", valid_row: "index_t", valid_col: "index_t") -> None:
+ lhs = pto.make_tensor(src0, shape=[8, 64], dtype=pto.float32)
+ rhs = pto.make_tensor(src1, shape=[8, 64], dtype=pto.float32)
+ out = pto.make_tensor(dst, shape=[8, 64], dtype=pto.float32)
+
+ lhs_tile = lhs.slice([0, 0], [8, 64])
+ rhs_tile = rhs.slice([0, 0], [8, 64])
+ out_tile = out.slice([0, 0], [8, 64])
+
+ with pto.vector_section():
+ lhs_buf = pto.make_tile_buffer(
+ pto.float32,
+ [8, 64],
+ space="VEC",
+ valid_shape=[-1, -1],
+ ).alloc(valid_row=valid_row, valid_col=valid_col)
+ rhs_buf = pto.make_tile_buffer(
+ pto.float32,
+ [8, 64],
+ space="VEC",
+ valid_shape=[-1, -1],
+ ).alloc(valid_row=valid_row, valid_col=valid_col)
+ out_buf = pto.make_tile_buffer(
+ pto.float32,
+ [8, 64],
+ space="VEC",
+ valid_shape=[-1, -1],
+ ).alloc(valid_row=valid_row, valid_col=valid_col)
+
+ lhs_buf.load_from(lhs_tile)
+ rhs_buf.load_from(rhs_tile)
+ tile.add(lhs_buf, rhs_buf, out_buf)
+ out_buf.store_to(out_tile)
+```
+
+This still emits native PTO tensor/tile IR such as `pto.make_tensor_view`,
+`pto.partition_view`, `pto.alloc_tile`, `pto.tload`, `pto.tadd`, and
+`pto.tstore`.
+
+### 2. Raw `micro`
+
+Use `ptodsl.micro` when you want to write the micro instruction sequence
+directly:
+
+```python
+from mlir.ir import IndexType
+from ptodsl import micro, pto, to_ir_module
+from ptodsl.api.scalar import _unwrap
+
+
+def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32, space="VEC"),
+ "index_t": IndexType.get(),
+ }
+
+
+@to_ir_module(meta_data=meta_data)
+def vadd_demo(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
+ v64f32 = micro.VRegType.get(64, pto.float32)
+ mask = micro.pset_b32(micro.MaskType.get(), "PAT_ALL")
+ lhs = micro.vlds(v64f32, _unwrap(src0), _unwrap(offset))
+ rhs = micro.vlds(v64f32, _unwrap(src1), _unwrap(offset))
+ out = micro.vadd(v64f32, lhs, rhs, mask)
+ micro.vsts(out, _unwrap(dst), _unwrap(offset), mask)
+```
+
+This is the most direct PTODSL surface for VPTO/PTOAS lowering.
+
+### 3. A5 Library
+
+The A5 layer under [`ptodsl/lib/a5`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5)
+shows how tile-style helpers map to micro instructions:
+
+- `tadd` is written with `pto.vlds`, `pto.vadd`, and `pto.vsts`
+- `trow_sum` is written with `pto.vcadd` plus vector combine/store logic
+- `tcol_expand`, `tgather`, `tmrgsort`, and `tsort32` are expressed directly in
+ terms of PTO micro opcodes where supported
+
+See [`ptodsl/lib/a5/README.md`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/README.md)
+for the file layout and generation flow.
-See [examples](./examples) and [tests](./tests)
+## End-to-End Flow
-## Contribute
+The repo currently tracks two useful flows:
-See [contribute_guide.md](./contribute_guide.md)
+- PTODSL frontend coverage:
+ tensor/view/tile and A5 examples emit correct `.pto`
+- PTODSL -> PTOAS -> HIVM proof path:
+ pure micro kernels such as
+ [`a5_hivm_vadd_demo.pto`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/generated/a5_hivm_vadd_demo.pto)
+ lower through PTOAS into
+ [`a5_hivm_vadd_demo.ll`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/generated/a5_hivm_vadd_demo.ll)
+
+Generated examples live in
+[`ptodsl/lib/a5/generated`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/generated).
+
+## Tests
+
+The repo currently uses:
+
+- [`tests/frontend`](/Users/zhoubot/github/pto-org/pto-dsl/tests/frontend) for
+ frontend IR construction
+- [`tests/regression`](/Users/zhoubot/github/pto-org/pto-dsl/tests/regression)
+ for A5 library coverage, generated artifact expectations, and public-surface
+ regressions
+
+Run them with:
+
+```bash
+PYTHONPATH=/path/to/mlir_core:/path/to/PTOAS/install:/path/to/PTOAS/build/python \
+python -m pytest -q tests/frontend tests/regression
+```
-## Compare to other frameworks
+## Related Files
-PTO-DSL aims for **low-level, explicit, NPU-native primitives** that can match the performance of **programming in [hardware intrinsics](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/cceintrinsicapi/cceapi_0001.html)**. Compared to other (also very good) kernel programming frameworks, it has a bit different scope by design:
-- vs [tilelang-ascend](https://github.com/tile-ai/tilelang-ascend): tilelang can also [use PTO-ISA as codegen backend](https://github.com/tile-ai/tilelang-ascend/blob/76553755da078479a7f60cce9c5f0e9a24d0008b/src/target/codegen_ascend_pto.cc). PTO-DSL intentionally exposes lower-level control, for example L2 swizzling is one-liner `T.use_swizzle` in tilelang, but is a user-defined custom function in PTO-DSL -- see this [matmul optimization example](examples/aot/matmul_optimization_guide/matmul_optim_guide.md). Once PTO-DSL is more stabilized, it might serve as a component like the [CuteDSL backend for tilelang](https://github.com/tile-ai/tilelang/blob/v0.1.8/src/target/codegen_cutedsl.cc).
-- vs [triton-ascend](https://gitcode.com/Ascend/triton-ascend): Both frameworks automate software pipelining based on some MLIR dialects for NPU. PTO-DSL exposes more NPU-native memory hierarchy such as `L0`/`L1`/`UB`. Also, `pto.load`/`pto.store` always maps to native efficient DMA instructions, while `tl.load`/`tl.store` tries to do GPU-style memory coalescing.
-- vs [Catlass](https://gitcode.com/cann/catlass): Catlass provides expert-optimized template collections, while PTO-DSL is more like the [CuteDSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html) layer of Cutlass, offering explicit low-level primitives.
-- vs [PyPTO](https://gitcode.com/cann/pypto): PyPTO is a full [MPMD](https://en.wikipedia.org/wiki/Flynn%27s_taxonomy#Multiple_programs,_multiple_data_streams_(MPMD)) dynamic runtime stack, which also [uses PTO-ISA as lowest-level primitive](https://gitcode.com/cann/pypto/tree/r0.1.1/framework/src/interface/tileop). PyPTO's Tensor API abstraction is closer to PyTorch/JAX level, while a PTO-DSL kernel is still [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) and is closer to CuTile/CuteDSL level.
+- [`docs/latest_api.md`](/Users/zhoubot/github/pto-org/pto-dsl/docs/latest_api.md)
+- [`ptodsl/api/pto.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/api/pto.py)
+- [`ptodsl/api/micro.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/api/micro.py)
+- [`ptodsl/lib/a5/README.md`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/README.md)
+- [`contribute_guide.md`](/Users/zhoubot/github/pto-org/pto-dsl/contribute_guide.md)
diff --git a/docs/latest_api.md b/docs/latest_api.md
new file mode 100644
index 00000000..fc7664c4
--- /dev/null
+++ b/docs/latest_api.md
@@ -0,0 +1,120 @@
+# PTODSL Latest API
+
+This document summarizes the public PTODSL surface after the recent A5/micro
+upgrade and explains which layer to use for each kind of kernel.
+
+## Public Layers
+
+### `ptodsl.pto`
+
+Use this layer for:
+
+- pointer/type construction
+- tensor and partitioned-view authoring
+- tile buffer allocation and `tload`/`tstore`
+- control flow and synchronization
+
+Key entry points:
+
+- `pto.ptr(dtype, space=None)`
+- `pto.TensorType(rank=..., dtype=...)`
+- `pto.SubTensorType(shape=..., dtype=...)`
+- `pto.TileBufType(shape=..., dtype=..., memory_space=..., valid_shape=..., config=...)`
+- `pto.make_tensor(ptr, shape=..., strides=None, dtype=..., type=None, layout=None)`
+- `TensorView.slice(offsets, sizes, static_shape=None)`
+- `pto.make_tile_buffer(dtype, shape, space=..., valid_shape=None, config=None)`
+- `TileBufferSpec.alloc(addr=None, valid_row=None, valid_col=None)`
+- `TileBuffer.load_from(view)` / `TileBuffer.store_to(view)`
+
+Type aliases currently exposed through `pto`:
+
+- `bool`
+- `float16`
+- `float32`
+- `bfloat16`
+- `int8`
+- `int16`
+- `int32`
+- `uint8`
+- `uint16`
+- `uint32`
+
+## `ptodsl.micro`
+
+Use this layer when you want raw PTO micro instructions without going through
+tile helpers.
+
+Examples:
+
+- `micro.vlds`
+- `micro.vadd`
+- `micro.vsts`
+- `micro.vcadd`
+- `micro.vgather2`
+- `micro.vmrgsort4`
+- `micro.vbitsort`
+- `micro.pset_b32`
+- `micro.VRegType`
+- `micro.MaskType`
+
+This layer is a thin pass-through over the PTO dialect Python bindings, filtered
+to the public micro-op surface.
+
+## `ptodsl.lib.a5`
+
+Use the A5 library when you want readable, opcode-focused examples of how an
+existing A5 tile helper is expressed with PTO micro instructions.
+
+Examples:
+
+- `a5.tadd`
+- `a5.tadds`
+- `a5.trow_sum`
+- `a5.tcol_expand`
+- `a5.tgather`
+- `a5.tsort32`
+
+The split modules in [`ptodsl/lib/a5`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5)
+are organized by tile helper family:
+
+- [`tbinary.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/tbinary.py)
+- [`tscalar.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/tscalar.py)
+- [`tunary.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/tunary.py)
+- [`texpand.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/texpand.py)
+- [`treduce.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/treduce.py)
+- [`tsort.py`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/tsort.py)
+
+## Compile-Time vs Runtime Values
+
+PTODSL now follows the same staging model as the PTO C++ tile headers:
+
+- compile-time constants:
+ dtype, memory space, tile capacity, tile layout/config, specialization knobs
+- runtime values:
+ pointers, offsets, valid row/column bounds, problem sizes
+
+In practice:
+
+- `tile_shape=[ROWS, COLS]` describes the fixed tile envelope
+- `valid_row` and `valid_col` describe the runtime active region when the valid
+ box is dynamic
+- `Constexpr[...]` is used in template-style builders such as
+ `build_templated_elementwise_add`
+
+## End-to-End Lowering
+
+The strongest validated path today is:
+
+1. write a pure micro kernel in PTODSL
+2. emit `.pto`
+3. lower with PTOAS VPTO
+4. inspect emitted `llvm.hivm.*` intrinsics
+
+Reference artifacts:
+
+- [`a5_hivm_vadd_demo.pto`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/generated/a5_hivm_vadd_demo.pto)
+- [`a5_hivm_vadd_demo.ll`](/Users/zhoubot/github/pto-org/pto-dsl/ptodsl/lib/a5/generated/a5_hivm_vadd_demo.ll)
+
+The higher-level tensor/tile frontend remains fully useful for PTODSL authoring
+and regression coverage, but the pure micro path is the clearest proof route
+for PTOAS HIVM lowering.
diff --git a/ptodsl/__init__.py b/ptodsl/__init__.py
index 5ed02b28..e0994323 100644
--- a/ptodsl/__init__.py
+++ b/ptodsl/__init__.py
@@ -1,4 +1,4 @@
-from . import pto, scalar, tile
+from . import micro, pto, scalar, tile
from .bench import do_bench
from .compiler.ir import to_ir_module
from .compiler.jit import JitWrapper, jit
@@ -10,6 +10,7 @@
"const_expr",
"do_bench",
"jit",
+ "micro",
"pto",
"range_constexpr",
"scalar",
diff --git a/ptodsl/api/__init__.py b/ptodsl/api/__init__.py
index ca7e01f9..c2a06941 100644
--- a/ptodsl/api/__init__.py
+++ b/ptodsl/api/__init__.py
@@ -1,3 +1,3 @@
-from . import pto, scalar, tile
+from . import micro, pto, scalar, tile
-__all__ = ["pto", "scalar", "tile"]
+__all__ = ["micro", "pto", "scalar", "tile"]
diff --git a/ptodsl/api/micro.py b/ptodsl/api/micro.py
new file mode 100644
index 00000000..f095e47a
--- /dev/null
+++ b/ptodsl/api/micro.py
@@ -0,0 +1,34 @@
+from mlir.dialects import pto as _pto
+
+
+def _is_public_micro_symbol(name):
+ if name.startswith("_"):
+ return False
+ if name in {
+ "AddressSpace",
+ "AddressSpaceAttr",
+ "AlignType",
+ "MaskType",
+ "VRegType",
+ }:
+ return True
+ if name.startswith("v") or name.startswith("p"):
+ return True
+ return False
+
+
+def __getattr__(name):
+ try:
+ value = getattr(_pto, name)
+ except AttributeError as exc:
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'") from exc
+ if not _is_public_micro_symbol(name):
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+ return value
+
+
+def __dir__():
+ return sorted(__all__)
+
+
+__all__ = sorted(name for name in dir(_pto) if _is_public_micro_symbol(name))
diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py
index 158bca41..32c8a9ad 100644
--- a/ptodsl/api/pto.py
+++ b/ptodsl/api/pto.py
@@ -1,6 +1,9 @@
from .control_flow import cond, range, if_context
from .scalar import Value, wrap_value
from .pto_general import (
+ TensorView,
+ TileBuffer,
+ TileBufferSpec,
alloc_tile,
as_tensor,
cube_section,
@@ -9,6 +12,8 @@
get_subblock_idx,
get_subblock_num,
load,
+ make_tensor,
+ make_tile_buffer,
slice_view,
store,
vector_section,
@@ -32,19 +37,28 @@
"bool",
"float16",
"float32",
+ "bfloat16",
"int16",
"int32",
+ "int8",
+ "uint32",
+ "uint16",
+ "uint8",
"PtrType",
"ptr",
"TensorType",
"SubTensorType",
"TileBufConfig",
"TileBufType",
+ "TensorView",
+ "TileBuffer",
+ "TileBufferSpec",
"get_block_idx",
"get_subblock_idx",
"get_subblock_num",
"get_block_num",
"as_tensor",
+ "make_tensor",
"slice_view",
"vector_section",
"cube_section",
@@ -52,6 +66,7 @@
"if_context",
"cond",
"alloc_tile",
+ "make_tile_buffer",
"load",
"store",
"print",
diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py
index c8f649ea..146113fe 100644
--- a/ptodsl/api/pto_general.py
+++ b/ptodsl/api/pto_general.py
@@ -1,9 +1,11 @@
from contextlib import contextmanager
+from mlir.dialects import arith
from mlir.dialects import pto as _pto
-from mlir.ir import InsertionPoint
+from mlir.ir import IndexType, InsertionPoint
from .scalar import Value, _unwrap
+from .type_def import SubTensorType, TensorType, TileBufType
def get_block_idx():
@@ -30,6 +32,101 @@ def _resolve_layout_attr(layout):
return layout
+def _index_value(value):
+ if isinstance(value, int):
+ return arith.ConstantOp(IndexType.get(), value).result
+ return _unwrap(value)
+
+
+def _mul_index(lhs, rhs):
+ if isinstance(lhs, int) and isinstance(rhs, int):
+ return lhs * rhs
+ return Value(arith.MulIOp(_index_value(lhs), _index_value(rhs)).result)
+
+
+def _resolve_tensor_dtype(type_or_value):
+ candidate = type_or_value.type if hasattr(type_or_value, "type") else type_or_value
+ return getattr(candidate, "element_type", None)
+
+
+def _row_major_strides(shape):
+ strides = [None] * len(shape)
+ stride = 1
+ for index in range(len(shape) - 1, -1, -1):
+ strides[index] = stride
+ dim = shape[index]
+ stride = _mul_index(dim, stride)
+ return [_index_value(stride) for stride in strides]
+
+
+class TensorView:
+ def __init__(self, raw, *, dtype=None):
+ self.raw = raw
+ self.dtype = dtype if dtype is not None else _resolve_tensor_dtype(raw)
+
+ def slice(self, offsets, sizes, *, static_shape=None):
+ if static_shape is None:
+ if not all(isinstance(size, int) for size in sizes):
+ raise ValueError(
+ "TensorView.slice(...) requires static_shape when any size is dynamic."
+ )
+ static_shape = list(sizes)
+ return TensorView(
+ slice_view(
+ SubTensorType(shape=static_shape, dtype=self.dtype),
+ source=self.raw,
+ offsets=offsets,
+ sizes=sizes,
+ ),
+ dtype=self.dtype,
+ )
+
+
+class TileBufferSpec:
+ def __init__(self, *, dtype, shape, space, valid_shape=None, config=None):
+ self.dtype = dtype
+ self.shape = list(shape)
+ self.space = space
+ self.valid_shape = list(shape) if valid_shape is None else list(valid_shape)
+ self.config = config
+ self._raw_type = TileBufType(
+ shape=self.shape,
+ dtype=dtype,
+ memory_space=space,
+ valid_shape=self.valid_shape,
+ config=config,
+ )
+
+ @property
+ def raw_type(self):
+ return self._raw_type
+
+ def alloc(self, *, addr=None, valid_row=None, valid_col=None):
+ return TileBuffer(
+ alloc_tile(
+ self.raw_type,
+ addr=addr,
+ valid_row=valid_row,
+ valid_col=valid_col,
+ ),
+ spec=self,
+ )
+
+
+class TileBuffer:
+ def __init__(self, raw, *, spec=None):
+ self.raw = raw
+ self.spec = spec
+
+ def load_from(self, view):
+ load(view, self)
+ return self
+
+ def store_to(self, view):
+ store(self, view)
+ return view
+
+
def as_tensor(tensor_type, *, ptr, shape, strides, layout=None):
shape_vals = [_unwrap(v) for v in shape]
stride_vals = [_unwrap(v) for v in strides]
@@ -43,10 +140,10 @@ def as_tensor(tensor_type, *, ptr, shape, strides, layout=None):
def slice_view(subtensor_type, *, source, offsets, sizes):
- offset_vals = [_unwrap(v) for v in offsets]
- size_vals = [_unwrap(v) for v in sizes]
+ offset_vals = [_index_value(v) for v in offsets]
+ size_vals = [_index_value(v) for v in sizes]
return _pto.PartitionViewOp(
- subtensor_type, source, offsets=offset_vals, sizes=size_vals
+ subtensor_type, _unwrap(source), offsets=offset_vals, sizes=size_vals
).result
@@ -71,18 +168,48 @@ def alloc_tile(tile_type, *, addr=None, valid_row=None, valid_col=None):
if addr is not None:
kwargs["addr"] = _unwrap(addr)
if valid_row is not None:
- kwargs["valid_row"] = _unwrap(valid_row)
+ kwargs["valid_row"] = _index_value(valid_row)
if valid_col is not None:
- kwargs["valid_col"] = _unwrap(valid_col)
+ kwargs["valid_col"] = _index_value(valid_col)
return _pto.AllocTileOp(tile_type, **kwargs).result
def load(source, dest):
- _pto.TLoadOp(None, source, dest)
+ _pto.TLoadOp(None, _unwrap(source), _unwrap(dest))
def store(source, dest):
- _pto.TStoreOp(None, source, dest)
+ _pto.TStoreOp(None, _unwrap(source), _unwrap(dest))
+
+
+def make_tensor(ptr, *, shape, strides=None, dtype=None, type=None, layout=None):
+ if type is None:
+ if dtype is None:
+ raise ValueError("make_tensor(...) requires dtype when type is omitted.")
+ type = TensorType(rank=len(shape), dtype=dtype)
+ resolved_dtype = dtype if dtype is not None else _resolve_tensor_dtype(type)
+ if strides is None:
+ strides = _row_major_strides(shape)
+ return TensorView(
+ as_tensor(
+ type,
+ ptr=ptr,
+ shape=[_index_value(v) for v in shape],
+ strides=[_index_value(v) for v in strides],
+ layout=layout,
+ ),
+ dtype=resolved_dtype,
+ )
+
+
+def make_tile_buffer(dtype, shape, *, space, valid_shape=None, config=None):
+ return TileBufferSpec(
+ dtype=dtype,
+ shape=shape,
+ space=space,
+ valid_shape=valid_shape,
+ config=config,
+ )
def print(format, scalar):
@@ -106,11 +233,16 @@ def print(format, scalar):
"get_subblock_idx",
"get_subblock_num",
"get_block_num",
+ "TensorView",
+ "TileBuffer",
+ "TileBufferSpec",
"as_tensor",
+ "make_tensor",
"slice_view",
"vector_section",
"cube_section",
"alloc_tile",
+ "make_tile_buffer",
"load",
"store",
"print",
diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py
index 93b2770a..b228bc03 100644
--- a/ptodsl/api/scalar.py
+++ b/ptodsl/api/scalar.py
@@ -1,5 +1,6 @@
+from mlir import ir as mlir_ir
from mlir.dialects import arith
-from mlir.ir import F16Type, F32Type, IndexType, IntegerType
+from mlir.ir import IndexType, IntegerType
def _unwrap(value):
@@ -84,21 +85,39 @@ def wrap_value(value):
return Value(value)
+def _get_mlir_float_type(alias_name, *type_names):
+ for type_name in type_names:
+ type_ctor = getattr(mlir_ir, type_name, None)
+ if type_ctor is not None:
+ return type_ctor.get()
+ supported = ", ".join(type_names)
+ raise AttributeError(
+ f"module '{__name__}' has no attribute '{alias_name}' because the active MLIR "
+ f"Python bindings do not expose any of: {supported}"
+ )
+
+
def __getattr__(name):
- # TODO: add more builtin dtype aliases (for example float16/bfloat16/int8/int64)
- # when they are validated against PTO type support.
if name == "bool":
return IntegerType.get_signless(1)
if name == "float32":
- return F32Type.get()
+ return _get_mlir_float_type(name, "F32Type", "Float32Type")
if name == "float16":
- return F16Type.get()
+ return _get_mlir_float_type(name, "F16Type", "Float16Type")
+ if name == "bfloat16":
+ return _get_mlir_float_type(name, "BF16Type")
if name == "int32":
return IntegerType.get_signless(32)
if name == "int16":
return IntegerType.get_signless(16)
+ if name == "int8":
+ return IntegerType.get_signless(8)
if name == "uint32":
return IntegerType.get_unsigned(32)
+ if name == "uint16":
+ return IntegerType.get_unsigned(16)
+ if name == "uint8":
+ return IntegerType.get_unsigned(8)
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -156,6 +175,16 @@ def select(cond, true_val, false_val):
"Value",
"_unwrap",
"wrap_value",
+ "bool",
+ "float16",
+ "float32",
+ "bfloat16",
+ "int32",
+ "int16",
+ "int8",
+ "uint32",
+ "uint16",
+ "uint8",
"const",
"index_cast",
"ceil_div",
diff --git a/ptodsl/api/tile.py b/ptodsl/api/tile.py
index 2cffe513..ccc7c031 100644
--- a/ptodsl/api/tile.py
+++ b/ptodsl/api/tile.py
@@ -6,147 +6,160 @@
def mov(source, dest):
- _pto.TMovOp(None, source, dest)
+ _pto.TMovOp(None, _unwrap(source), _unwrap(dest))
def add(lhs, rhs, out):
- _pto.TAddOp(lhs, rhs, out)
+ _pto.TAddOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def sub(lhs, rhs, out):
- _pto.TSubOp(lhs, rhs, out)
+ _pto.TSubOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def div(lhs, rhs, out):
- _pto.TDivOp(lhs, rhs, out)
+ _pto.TDivOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def mul(lhs, rhs, out):
- _pto.TMulOp(lhs, rhs, out)
+ _pto.TMulOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def or_(lhs, rhs, out):
- _pto.TOrOp(lhs, rhs, out)
+ _pto.TOrOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def min(lhs, rhs, out):
- _pto.TMinOp(lhs, rhs, out)
+ _pto.TMinOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def max(lhs, rhs, out):
- _pto.TMaxOp(lhs, rhs, out)
+ _pto.TMaxOp(_unwrap(lhs), _unwrap(rhs), _unwrap(out))
def gather(src, out, indices=None, *, mask_pattern=None):
if mask_pattern is not None:
mask = _pto.MaskPatternAttr.get(getattr(_pto.MaskPattern, mask_pattern))
- _pto.TGatherOp(src, out, maskPattern=mask)
+ _pto.TGatherOp(_unwrap(src), _unwrap(out), maskPattern=mask)
else:
- _pto.TGatherOp(src, out, indices=indices)
+ _pto.TGatherOp(_unwrap(src), _unwrap(out), indices=_unwrap(indices))
def exp(inp, out):
- _pto.TExpOp(inp, out)
+ _pto.TExpOp(_unwrap(inp), _unwrap(out))
def log(inp, out):
- _pto.TLogOp(inp, out)
+ _pto.TLogOp(_unwrap(inp), _unwrap(out))
def relu(inp, out):
- _pto.TReluOp(inp, out)
+ _pto.TReluOp(_unwrap(inp), _unwrap(out))
def abs(inp, out):
- _pto.TAbsOp(inp, out)
+ _pto.TAbsOp(_unwrap(inp), _unwrap(out))
def sqrt(inp, out):
- _pto.TSqrtOp(inp, out)
+ _pto.TSqrtOp(_unwrap(inp), _unwrap(out))
def rsqrt(inp, out):
- _pto.TRsqrtOp(inp, out)
+ _pto.TRsqrtOp(_unwrap(inp), _unwrap(out))
def reciprocal(inp, out):
- _pto.TRecipOp(inp, out)
+ _pto.TRecipOp(_unwrap(inp), _unwrap(out))
def matmul(lhs, rhs, out):
- _pto.TMatmulOp(None, lhs, rhs, out)
+ _pto.TMatmulOp(None, _unwrap(lhs), _unwrap(rhs), _unwrap(out))
def matmul_bias(lhs, rhs, bias, out):
- _pto.TMatmulBiasOp(None, lhs, rhs, bias, out)
+ _pto.TMatmulBiasOp(None, _unwrap(lhs), _unwrap(rhs), _unwrap(bias), _unwrap(out))
def matmul_acc(acc, lhs, rhs, out):
- _pto.TMatmulAccOp(None, acc, lhs, rhs, out)
+ _pto.TMatmulAccOp(None, _unwrap(acc), _unwrap(lhs), _unwrap(rhs), _unwrap(out))
def extract(source, index_row, index_col, out):
_pto.TExtractOp(
- src=source, indexRow=_unwrap(index_row), indexCol=_unwrap(index_col), dst=out
+ src=_unwrap(source),
+ indexRow=_unwrap(index_row),
+ indexCol=_unwrap(index_col),
+ dst=_unwrap(out),
)
def row_sum(src, tmp, dst):
- _pto.TRowSumOp(src=src, tmp=tmp, dst=dst)
+ _pto.TRowSumOp(src=_unwrap(src), tmp=_unwrap(tmp), dst=_unwrap(dst))
def row_min(src, tmp, dst):
- _pto.TRowMinOp(src=src, tmp=tmp, dst=dst)
+ _pto.TRowMinOp(src=_unwrap(src), tmp=_unwrap(tmp), dst=_unwrap(dst))
def row_max(src, tmp, dst):
- _pto.TRowMaxOp(src=src, tmp=tmp, dst=dst)
+ _pto.TRowMaxOp(src=_unwrap(src), tmp=_unwrap(tmp), dst=_unwrap(dst))
def row_prod(src, tmp, dst):
- _pto.TRowProdOp(src=src, tmp=tmp, dst=dst)
+ _pto.TRowProdOp(src=_unwrap(src), tmp=_unwrap(tmp), dst=_unwrap(dst))
def row_expand(src, dst):
- _pto.TRowExpandOp(src=src, dst=dst)
+ _pto.TRowExpandOp(src=_unwrap(src), dst=_unwrap(dst))
def row_expand_sub(src0, src1, dst):
- _pto.TRowExpandSubOp(src0=src0, src1=src1, dst=dst)
+ _pto.TRowExpandSubOp(src0=_unwrap(src0), src1=_unwrap(src1), dst=_unwrap(dst))
def row_expand_div(src0, src1, dst):
- _pto.TRowExpandDivOp(src0=src0, src1=src1, dst=dst)
+ _pto.TRowExpandDivOp(src0=_unwrap(src0), src1=_unwrap(src1), dst=_unwrap(dst))
def row_expand_mul(src0, src1, dst):
- _pto.TRowExpandMulOp(src0=src0, src1=src1, dst=dst)
+ _pto.TRowExpandMulOp(src0=_unwrap(src0), src1=_unwrap(src1), dst=_unwrap(dst))
def col_sum(src, tmp, dst, is_binary=True):
- _pto.TColSumOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
+ _pto.TColSumOp(
+ src=_unwrap(src),
+ dst=_unwrap(dst),
+ tmp=_unwrap(tmp),
+ isBinary=BoolAttr.get(is_binary),
+ )
def col_min(src, dst):
- _pto.TColMinOp(src=src, dst=dst)
+ _pto.TColMinOp(src=_unwrap(src), dst=_unwrap(dst))
def col_max(src, dst):
- _pto.TColMaxOp(src=src, dst=dst)
+ _pto.TColMaxOp(src=_unwrap(src), dst=_unwrap(dst))
def col_prod(src, tmp, dst, is_binary=True):
- _pto.TColProdOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary))
+ _pto.TColProdOp(
+ src=_unwrap(src),
+ dst=_unwrap(dst),
+ tmp=_unwrap(tmp),
+ isBinary=BoolAttr.get(is_binary),
+ )
def col_expand(src, dst):
- _pto.TColExpandOp(src=src, dst=dst)
+ _pto.TColExpandOp(src=_unwrap(src), dst=_unwrap(dst))
def mrgsort(src, dst, block_len):
i32 = IntegerType.get_signless(32)
block_len_i32 = _arith.IndexCastOp(i32, _unwrap(block_len)).result
- _pto.TMrgSortOp(srcs=[src], dsts=[dst], blockLen=block_len_i32)
+ _pto.TMrgSortOp(srcs=[_unwrap(src)], dsts=[_unwrap(dst)], blockLen=block_len_i32)
def sort32(src, dst, idx):
@@ -154,16 +167,16 @@ def sort32(src, dst, idx):
(score, index) pairs to dst. idx is an input tile of uint32 indices
attached to each src element. For float16 src, dst must have 4x the
columns of src (each element expands to 4 float16 words)."""
- _pto.TSort32Op(src, dst, idx)
+ _pto.TSort32Op(_unwrap(src), _unwrap(dst), _unwrap(idx))
def subset(source, offsets, sizes):
offset_vals = [_unwrap(v) for v in offsets]
- return _pto.subset(source, offset_vals, sizes)
+ return _pto.subset(_unwrap(source), offset_vals, sizes)
def print(source):
- _pto.tprint(source)
+ _pto.tprint(_unwrap(source))
__all__ = [
diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py
index f0ef6872..b4272e28 100644
--- a/ptodsl/api/type_def.py
+++ b/ptodsl/api/type_def.py
@@ -6,7 +6,18 @@
def __getattr__(name):
# MLIR type factories require an active context, so keep dtype aliases lazy
# and resolve them only when user code accesses them inside PTO/MLIR setup.
- if name in {"bool", "float16", "float32", "int16", "int32", "uint32"}:
+ if name in {
+ "bool",
+ "float16",
+ "float32",
+ "bfloat16",
+ "int16",
+ "int32",
+ "int8",
+ "uint32",
+ "uint16",
+ "uint8",
+ }:
return getattr(scalar, name)
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -125,7 +136,11 @@ def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None):
"bool",
"float16",
"float32",
+ "bfloat16",
"int16",
"int32",
+ "int8",
"uint32",
+ "uint16",
+ "uint8",
]
diff --git a/ptodsl/language.py b/ptodsl/language.py
index 5ee76d5e..a7ec3403 100644
--- a/ptodsl/language.py
+++ b/ptodsl/language.py
@@ -175,6 +175,14 @@ def __getattr__(name):
return IntegerType.get_signless(32)
if name == "int16":
return IntegerType.get_signless(16)
+ if name == "int8":
+ return IntegerType.get_signless(8)
+ if name == "uint32":
+ return IntegerType.get_unsigned(32)
+ if name == "uint16":
+ return IntegerType.get_unsigned(16)
+ if name == "uint8":
+ return IntegerType.get_unsigned(8)
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/ptodsl/lib/a5/README.md b/ptodsl/lib/a5/README.md
index 1a824e3b..ed5b600a 100644
--- a/ptodsl/lib/a5/README.md
+++ b/ptodsl/lib/a5/README.md
@@ -13,6 +13,8 @@ The scope of this layout is:
- Example builder kernels that emit `.pto` through PTODSL
- A checked-in generation flow for reproducible `.pto` artifacts and HIVM LLVM
sidecars for pure micro kernels
+- A public PTODSL import surface that now includes `ptodsl.pto` for pythonic
+ tensor/tile construction and `ptodsl.micro` for raw PTO micro instructions
Entry points:
@@ -33,6 +35,13 @@ Entry points:
`llvm.hivm.*` intrinsics
- [`generated`](./generated): emitted `.pto` artifacts from `scripts/generate_a5_pto.py`
+Recommended usage:
+
+- use `ptodsl.pto` to build tensors, views, and tile buffers ergonomically
+- use `ptodsl.micro` when you want raw PTO micro instructions directly
+- use this A5 layer when you want readable examples of how a tile helper such
+ as `tadd` or `trow_sum` is rewritten in terms of those micro instructions
+
Regenerate the current artifacts with:
```bash
diff --git a/ptodsl/micro.py b/ptodsl/micro.py
new file mode 100644
index 00000000..01b500a3
--- /dev/null
+++ b/ptodsl/micro.py
@@ -0,0 +1,6 @@
+from .api import micro as _micro
+from .api.micro import __all__
+
+
+def __getattr__(name):
+ return getattr(_micro, name)
diff --git a/tests/regression/test_public_api_surface.py b/tests/regression/test_public_api_surface.py
new file mode 100644
index 00000000..d26405e4
--- /dev/null
+++ b/tests/regression/test_public_api_surface.py
@@ -0,0 +1,97 @@
+from mlir.ir import IndexType
+
+from ptodsl import micro, pto, tile, to_ir_module
+from ptodsl.api.scalar import _unwrap
+
+
+def test_pythonic_pto_tensor_and_tile_buffer_surface_emits_expected_ir():
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def wrapper_demo(
+ src: "ptr_t", dst: "ptr_t", valid_row: "index_t", valid_col: "index_t"
+ ) -> None:
+ src_view = pto.make_tensor(src, shape=[8, 64], dtype=pto.float32)
+ dst_view = pto.make_tensor(dst, shape=[8, 64], dtype=pto.float32)
+
+ src_tile = src_view.slice([0, 0], [8, 64])
+ dst_tile = dst_view.slice([0, 0], [8, 64])
+
+ with pto.vector_section():
+ buf = pto.make_tile_buffer(
+ pto.float32,
+ [8, 64],
+ space="VEC",
+ valid_shape=[-1, -1],
+ ).alloc(valid_row=valid_row, valid_col=valid_col)
+ buf.load_from(src_tile)
+ buf.store_to(dst_tile)
+
+ text = str(wrapper_demo)
+
+ assert "pto.make_tensor_view" in text
+ assert "pto.partition_view" in text
+ assert "!pto.tile_buf" in text
+ assert "valid_row = %arg2 valid_col = %arg3" in text
+ assert "pto.tload" in text
+ assert "pto.tstore" in text
+
+
+def test_tile_ops_accept_pythonic_tile_buffer_wrappers():
+ def meta_data():
+ return {"ptr_t": pto.ptr(pto.float32)}
+
+ @to_ir_module(meta_data=meta_data)
+ def wrapper_add(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t") -> None:
+ lhs = pto.make_tensor(src0, shape=[8, 64], dtype=pto.float32)
+ rhs = pto.make_tensor(src1, shape=[8, 64], dtype=pto.float32)
+ out = pto.make_tensor(dst, shape=[8, 64], dtype=pto.float32)
+
+ with pto.vector_section():
+ lhs_buf = pto.make_tile_buffer(pto.float32, [8, 64], space="VEC").alloc()
+ rhs_buf = pto.make_tile_buffer(pto.float32, [8, 64], space="VEC").alloc()
+ out_buf = pto.make_tile_buffer(pto.float32, [8, 64], space="VEC").alloc()
+
+ lhs_buf.load_from(lhs.slice([0, 0], [8, 64]))
+ rhs_buf.load_from(rhs.slice([0, 0], [8, 64]))
+ tile.add(lhs_buf, rhs_buf, out_buf)
+ out_buf.store_to(out.slice([0, 0], [8, 64]))
+
+ text = str(wrapper_add)
+
+ assert "pto.tload" in text
+ assert "pto.tadd" in text
+ assert "pto.tstore" in text
+
+
+def test_public_micro_module_exposes_raw_micro_surface():
+ assert hasattr(micro, "vadd")
+ assert hasattr(micro, "pset_b32")
+
+ def meta_data():
+ return {
+ "ptr_t": pto.ptr(pto.float32, space="VEC"),
+ "index_t": IndexType.get(),
+ }
+
+ @to_ir_module(meta_data=meta_data)
+ def vadd_demo(
+ src0: "ptr_t", src1: "ptr_t", dst: "ptr_t", offset: "index_t"
+ ) -> None:
+ v64f32 = micro.VRegType.get(64, pto.float32)
+ mask = micro.pset_b32(micro.MaskType.get(), "PAT_ALL")
+ lhs = micro.vlds(v64f32, _unwrap(src0), _unwrap(offset))
+ rhs = micro.vlds(v64f32, _unwrap(src1), _unwrap(offset))
+ out = micro.vadd(v64f32, lhs, rhs, mask)
+ micro.vsts(out, _unwrap(dst), _unwrap(offset), mask)
+
+ text = str(vadd_demo)
+
+ assert "pto.pset_b32" in text
+ assert "pto.vlds" in text
+ assert "pto.vadd" in text
+ assert "pto.vsts" in text