From 2ae850117e29c31854b2cf7307ec062ebd74783f Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Fri, 5 Jun 2026 10:56:24 +0800 Subject: [PATCH 1/8] feat(ptodsl): add simt launch query APIs --- .../ptodsl-simt-micro-op-api-design.md | 423 ++++++++++++++++++ .../03-kernel-entry-and-subkernels.md | 38 ++ ptodsl/ptodsl/_ops.py | 114 ++++- ptodsl/ptodsl/_tracing/session.py | 59 ++- ptodsl/ptodsl/pto.py | 9 +- ptodsl/tests/test_jit_compile.py | 65 +++ 6 files changed, 699 insertions(+), 9 deletions(-) create mode 100644 docs/designs/ptodsl-simt-micro-op-api-design.md diff --git a/docs/designs/ptodsl-simt-micro-op-api-design.md b/docs/designs/ptodsl-simt-micro-op-api-design.md new file mode 100644 index 000000000..7bb6806c5 --- /dev/null +++ b/docs/designs/ptodsl-simt-micro-op-api-design.md @@ -0,0 +1,423 @@ +# PTO-DSL SIMT Micro-op API Design + +## 1. Scope + +This document records the PTO-DSL frontend design plan for the SIMT micro-op +surface that is already supported by VPTO on `main`. + +The design is intentionally frontend-first: + +- expose Python PTO-DSL wrappers for existing VPTO SIMT operations; +- keep wrapper names and parameters close to VPTO IR; +- avoid backend changes unless the frontend generates valid IR that the + backend incorrectly rejects; +- document open questions before changing lowering, verifiers, or backend + passes. + +The first implementation batch focuses on SIMT launch and query operations. +Later batches are listed for context so the API direction stays consistent. + +## 2. 
References + +- SIMT ISA documentation: `docs/isa/micro-isa/17-simt.md` +- VPTO operation definitions: `include/PTO/IR/VPTOOps.td` +- VPTO verifier behavior: `lib/PTO/IR/VPTO.cpp` +- Existing PTO-DSL operation wrappers: `ptodsl/ptodsl/_ops.py` +- Existing PTO-DSL subkernel lowering: `ptodsl/ptodsl/_subkernels.py` +- Existing PTO-DSL tracing session: `ptodsl/ptodsl/_tracing/session.py` +- Existing PTO-DSL SIMT docs: `ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md` +- Existing scalar docs: `ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md` +- Existing SIMT VPTO lit tests: `test/lit/vpto/simt_*` +- Existing SIMT runtime samples: `test/vpto/cases/micro-op/simt/*` + +## 3. Current PTO-DSL State + +Current PTO-DSL already has a narrow SIMT surface: + +- `@pto.simt` decorator and `with pto.simt():` inline scope. +- `pto.store_vfsimt_info(dim_z, dim_y, dim_x)`. +- `pto.get_tid_x()`, `pto.get_tid_y()`, `pto.get_tid_z()`. +- `scalar.load(...)` and `scalar.store(...)` for plain scalar element access. + +Current `@pto.simt` helper calls lower to: + +```mlir +%dim_z = arith.constant 1 : i32 +%dim_y = arith.constant 1 : i32 +%dim_x = arith.constant 1 : i32 +pto.store_vfsimt_info %dim_z, %dim_y, %dim_x : i32, i32, i32 +func.call @simt_body(...) +``` + +That path emits a reusable helper function marked with `pto.simt_entry`, but it +does not yet expose user-controlled launch dimensions and does not use +`pto.simt_launch`. + +## 4. Full Migration Plan + +The full SIMT micro-op PTO-DSL surface can be migrated in staged batches. 
+ +### Batch 1: Launch and Query Ops + +Expose launch configuration and nullary thread/lane query wrappers: + +- `pto.simt_launch(...)` +- `pto.store_vfsimt_info(...)` +- `pto.get_tid_x/y/z()` +- `pto.get_block_dim_x/y/z()` +- `pto.get_grid_dim_x/y/z()` +- `pto.get_block_idx_x/y/z()` +- `pto.get_veccoreid()` +- `pto.get_clock32()` +- `pto.get_clock64()` +- `pto.get_laneid()` +- `pto.get_lanemask_eq/le/lt/ge/gt()` + +### Batch 2: Lane Collective Ops + +Expose direct wrappers for: + +- `pto.vote_all/any/uni/ballot(pred)` +- `pto.shuffle_idx/up/down/bfly(value, control, *, width=32)` +- `pto.redux_add/max/min(value, *, signedness=None)` + +### Batch 3: SIMT Scalar Memory and Atomics + +Expose direct wrappers for: + +- `pto.ldg(ptr, offset=0, *, l1cache="cache", l2cache="nmfv")` +- `pto.stg(value, ptr, offset=0, *, l1cache="cache", l2cache="nmfv")` +- `pto.atomic_exch/add/sub/min/max/and/or/xor(ptr, value, *, l2cache="nmfv", signedness=None)` +- `pto.atomic_cas(ptr, compare, value, *, l2cache="nmfv", signedness=None)` + +Plain scalar memory remains available through `scalar.load(...)` and +`scalar.store(...)`. + +### Batch 4: SIMT Scalar Math, Convert, Sync, and State + +Expose direct wrappers for: + +- `pto.prmt(...)` +- `pto.mulhi(...)` +- `pto.mul_i32toi64(...)` +- `pto.absf(...)`, `pto.sqrt(...)`, `pto.exp(...)`, `pto.log(...)`, + `pto.pow(...)`, `pto.ceil(...)`, `pto.floor(...)`, `pto.rint(...)`, + `pto.round(...)`, `pto.fmin(...)`, `pto.fmax(...)`, `pto.fma(...)` +- `pto.convert(...)` +- `pto.syncthreads()`, `pto.threadfence()`, `pto.threadfence_block()` +- `pto.keep(...)`, `pto.resume(...)` + +`pto.sqrt/exp/log` are VPTO SIMT micro-ops. They are not the same API layer as +the existing `scalar.sqrt/exp/log` helpers, which currently emit generic +`math.*` operations. + +## 5. Batch 1 Detailed Design + +### 5.1 Goals + +Batch 1 should make SIMT launch dimensions and all nullary SIMT runtime queries +authorable from PTO-DSL. 
+ +The implementation should: + +- keep micro-op names aligned with VPTO op names; +- preserve the low-level `store_vfsimt_info(dim_z, dim_y, dim_x)` order; +- add an ergonomic launch wrapper that uses the launch-site `x, y, z` order; +- preserve current `@pto.simt` helper behavior for existing code; +- avoid backend changes. + +### 5.2 Non-goals + +Batch 1 should not implement lane collectives, atomics, GM scalar cache policy, +scalar math, conversion, keep/resume, or runtime/ST coverage. + +Batch 1 should not change the semantics of `scalar.load/store`. + +### 5.3 Operation Mapping + +| PTO-DSL API | VPTO IR op | Return | +|---|---|---| +| `pto.store_vfsimt_info(dim_z, dim_y, dim_x)` | `pto.store_vfsimt_info` | `None` | +| `pto.simt_launch(body, *args, dims=(dim_x, dim_y, dim_z))` | `pto.simt_launch` or equivalent `store_vfsimt_info + func.call` | `None` | +| `pto.get_tid_x()` | `pto.get_tid_x` | `i32` | +| `pto.get_tid_y()` | `pto.get_tid_y` | `i32` | +| `pto.get_tid_z()` | `pto.get_tid_z` | `i32` | +| `pto.get_block_dim_x()` | `pto.get_block_dim_x` | `i32` | +| `pto.get_block_dim_y()` | `pto.get_block_dim_y` | `i32` | +| `pto.get_block_dim_z()` | `pto.get_block_dim_z` | `i32` | +| `pto.get_grid_dim_x()` | `pto.get_grid_dim_x` | `i32` | +| `pto.get_grid_dim_y()` | `pto.get_grid_dim_y` | `i32` | +| `pto.get_grid_dim_z()` | `pto.get_grid_dim_z` | `i32` | +| `pto.get_block_idx_x()` | `pto.get_block_idx_x` | `i32` | +| `pto.get_block_idx_y()` | `pto.get_block_idx_y` | `i32` | +| `pto.get_block_idx_z()` | `pto.get_block_idx_z` | `i32` | +| `pto.get_veccoreid()` | `pto.get_veccoreid` | `i32` | +| `pto.get_clock32()` | `pto.get_clock32` | `i32` | +| `pto.get_clock64()` | `pto.get_clock64` | `i64` | +| `pto.get_laneid()` | `pto.get_laneid` | `i32` | +| `pto.get_lanemask_eq()` | `pto.get_lanemask_eq` | `i32` | +| `pto.get_lanemask_le()` | `pto.get_lanemask_le` | `i32` | +| `pto.get_lanemask_lt()` | `pto.get_lanemask_lt` | `i32` | +| `pto.get_lanemask_ge()` | 
`pto.get_lanemask_ge` | `i32` | +| `pto.get_lanemask_gt()` | `pto.get_lanemask_gt` | `i32` | + +### 5.4 Launch API + +#### Signature + +```python +pto.simt_launch( + body: pto.SubkernelTemplate, + *args, + dims: tuple[int | Scalar, int | Scalar, int | Scalar] = (1, 1, 1), +) -> None +``` + +`dims` uses `(dim_x, dim_y, dim_z)` order. This matches the textual +`pto.simt_launch @body<<<x, y, z>>>(...)` order and the common launch-site +mental model. + +The existing low-level API keeps its backend order: + +```python +pto.store_vfsimt_info(dim_z, dim_y, dim_x) -> None +``` + +This asymmetry is intentional: + +- `store_vfsimt_info` is a direct wrapper over the backend operation and should + not rename or reorder operands. +- `simt_launch` is launch-site sugar and should match the IR sugar order + `x, y, z`. + +#### Example + +```python +from ptodsl import pto, scalar + + +@pto.simt +def write_tid(dst: pto.ptr(pto.i32, pto.MemorySpace.UB)): + tid = pto.get_tid_x() + idx = scalar.index_cast(tid) + scalar.store(tid, dst, idx) + + +@pto.jit(target="a5") +def kernel(dst: pto.ptr(pto.i32, pto.MemorySpace.UB)): + pto.simt_launch(write_tid, dst, dims=(32, 1, 1)) +``` + +Expected source-level IR shape for Batch 1: + +```mlir +%dim_x = arith.constant 32 : i32 +%dim_y = arith.constant 1 : i32 +%dim_z = arith.constant 1 : i32 +pto.simt_launch @write_tid<<<%dim_x, %dim_y, %dim_z>>>(%dst) + : (!pto.ptr) -> () +``` + +Batch 1 emits VPTO `pto.simt_launch` directly. The existing backend +`vpto-expand-wrapper-ops` pass expands it to `pto.store_vfsimt_info + func.call`. + +### 5.5 `@pto.simt` Decorator Attributes + +SIMT entry functions may carry optional VPTO attributes: + +- `pto.simt_max_threads` +- `pto.simt_max_regs` + +Proposed PTO-DSL decorator extension: + +```python +@pto.simt(max_threads=256, max_regs=48) +def body(...): + ... +``` + +Lowering: + +```mlir +func.func @body(...) 
attributes { + pto.simt_entry, + pto.simt_max_threads = 256 : i32, + pto.simt_max_regs = 48 : i32 +} +``` + +Both decorator arguments should be optional. When omitted, PTO-DSL should emit +no explicit attributes and let backend defaults apply. + +Validation: + +- values must be Python integers known at trace time; +- values must be positive; +- these attributes must only be attached to functions that are already marked + `pto.simt_entry`. + +This extension is useful for launch-envelope documentation and resource +control, but it is not required to expose query ops. It can be implemented in +the same batch or as a small follow-up. + +### 5.6 Query API Behavior + +All query APIs are nullary wrappers and return a wrapped MLIR SSA value. + +Implementation pattern: + +```python +def get_laneid(): + return wrap_surface_value(_pto.GetLaneIdOp().result) +``` + +No Python-side context check is required for the first version. The backend +already knows which operations are legal in `pto.simt_entry` when applicable. +Adding a frontend context check can be considered later if it improves error +messages without hiding backend semantics. + +### 5.7 Type Handling for Launch Dimensions + +Launch dimensions are VPTO `i32` operands. PTO-DSL should accept: + +- Python integer literals; +- PTO scalar values that are already `i32`; +- index-like runtime values when they can be explicitly cast to `i32`. + +Proposed normalization rule: + +- Python `int` is materialized as signless `i32` constant. +- A runtime scalar with type `i32` is accepted unchanged. +- A runtime scalar with type `index` may be cast to `i32` if existing PTO-DSL + scalar casting helpers provide a clear path. +- Other types should raise a clear Python `TypeError`. + +The implementation should not silently accept `i64` or arbitrary integers by +truncation. 
+ +### 5.8 Interaction With Existing `@pto.simt` Calls + +Current code can call a SIMT subkernel directly: + +```python +write_tid(dst) +``` + +Today that direct call lowers to launch dimensions `(1, 1, 1)`. + +To preserve compatibility, direct `SubkernelTemplate.__call__` behavior should +remain valid and keep its current default launch dimensions. `pto.simt_launch` +is the explicit launch-dimension surface for new code. + +Future ergonomic options: + +```python +write_tid.launch(dst, dims=(32, 1, 1)) +``` + +This method is not required for Batch 1. If added, it should call the same +lowering path as `pto.simt_launch(...)` and should not create a second semantic +route. + +### 5.9 Implementation Sketch + +Frontend files likely touched: + +- `ptodsl/ptodsl/_ops.py` + - add nullary query wrappers; + - add `_coerce_i32_dim(...)` helper if existing helpers are not sufficient; + - add `simt_launch(...)` wrapper or delegate to tracing runtime. +- `ptodsl/ptodsl/pto.py` + - export new wrappers. +- `ptodsl/ptodsl/_subkernels.py` + - optionally extend `simt(..., max_threads=None, max_regs=None)`. +- `ptodsl/ptodsl/_tracing/session.py` + - add a reusable lowering method for explicit SIMT launches; + - optionally attach `pto.simt_max_threads` and `pto.simt_max_regs` attrs when + creating helper functions. +- `ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md` + - document explicit `pto.simt_launch(...)` and optional decorator attrs. +- `ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md` or a new SIMT section + - document query ops if we want user-guide coverage in the same PR. +- `ptodsl/tests/support/docs_fragment_fixtures.py` + - update only if new docs snippets are executable docs-as-tests. +- `ptodsl/tests/test_jit_compile.py` + - add compile smoke tests for query wrappers and explicit launch dims. + +Backend files should not be touched for Batch 1 unless frontend-generated IR is +valid but rejected by existing VPTO code. 
+ +### 5.10 Test Plan + +Minimum Python/frontend tests: + +1. Existing direct `@pto.simt` call still emits `pto.store_vfsimt_info` and a + single reusable `pto.simt_entry` function. +2. `pto.simt_launch(body, dst, dims=(32, 1, 1))` emits either: + - `pto.simt_launch @body<<<...>>>`, or + - an equivalent `pto.store_vfsimt_info` with dimensions reordered to + `z, y, x` followed by `func.call @body`. +3. All query wrappers compile inside a SIMT body and emit the expected op names. +4. `get_clock64()` returns an `i64` value; all other query wrappers in Batch 1 + return `i32`. +5. Invalid launch dimensions raise Python errors before backend verification + when the type is clearly unsupported. + +Suggested lit/frontend assertions: + +- `func.func @body(...) attributes {pto.simt_entry}` +- `pto.get_tid_x` +- `pto.get_block_dim_x` +- `pto.get_grid_dim_x` +- `pto.get_block_idx_x` +- `pto.get_veccoreid` +- `pto.get_clock32` +- `pto.get_clock64` +- `pto.get_laneid` +- `pto.get_lanemask_lt` +- explicit launch dimensions are present in the generated IR. + +Runtime/ST validation is not required for the first frontend API PR unless a +later implementation changes runtime behavior. + +### 5.11 Open Questions + +1. Should `pto.simt_launch(...)` directly emit VPTO `SimtLaunchOp`, or should + it lower immediately to `store_vfsimt_info + func.call` in PTO-DSL tracing? + + Batch 1 uses direct `SimtLaunchOp` emission. This matches the ISA and keeps + the frontend surface one-to-one with VPTO. Expansion remains owned by the + existing backend wrapper-expansion pass. + +2. Should direct `@pto.simt` calls remain fixed at `(1, 1, 1)` forever, or + should they accept launch dims later through a method such as + `body.launch(..., dims=(...))`? + + Batch 1 preserves current direct-call behavior. A method can be added later + as pure sugar over `pto.simt_launch(...)`. + +3. Should PTO-DSL enforce "query ops only inside `pto.simt_entry`" at Python + tracing time? 
+ + Batch 1 relies on backend verification. A frontend context check may improve + diagnostics later, but it should not invent semantics different from VPTO. + +4. Should `@pto.simt(max_threads=..., max_regs=...)` be included in Batch 1? + + These attributes are part of the SIMT entry contract and are cheap to expose, + but they are not necessary for query wrappers. Batch 1 leaves them for a + follow-up. + +## 6. Backend Change Guardrail + +Before changing `include/PTO/IR/*`, `lib/PTO/IR/*`, or +`lib/PTO/Transforms/*` for this work, answer: + +- Is PTO-DSL generating IR that matches `docs/isa/micro-isa/17-simt.md` and + `include/PTO/IR/VPTOOps.td`? +- Does the existing backend reject that valid IR? +- Did existing VPTO lit tests already cover the intended backend behavior? +- Can the issue be fixed by wrapper normalization, tracing, docs, or tests? +- If a backend change is still needed, can it be covered by a narrow lit test? + +The default answer for Batch 1 should be no backend changes. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 67b3d51c8..639a77df9 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -580,6 +580,44 @@ it efficient for per-element operations. **Invocation modes**: can be called from `@pto.jit` in either mode, or used inline with `with pto.simt():` (Section 3.4). +#### Explicit SIMT launch dimensions + +Calling a decorated SIMT helper directly uses the default launch descriptor +emitted by the tracer. Use `pto.simt_launch` when the launch dimensions must be +authored explicitly. 
+ +```python +pto.simt_launch(body, *args, dims=(dim_x, dim_y, dim_z)) +``` + +| Parameter | Type | Description | +|-----------|------|-------------| +| `body` | `@pto.simt` function | SIMT entry body to launch | +| `*args` | PTO values | Arguments passed to the SIMT body; types must match the body signature | +| `dims` | tuple of 3 `i32`-compatible values | Launch dimensions in `x, y, z` order | + +`pto.simt_launch` follows the source-level `x, y, z` launch order. The lower +level `pto.store_vfsimt_info(dim_z, dim_y, dim_x)` wrapper is also available +for direct VPTO authoring, but its operand order follows the backend launch +descriptor order. + +#### SIMT query ops + +SIMT query ops are nullary micro-op wrappers. They return PTO scalar values +visible to the current SIMT work-item. + +| API | Return | Description | +|-----|--------|-------------| +| `pto.get_tid_x()` / `pto.get_tid_y()` / `pto.get_tid_z()` | `i32` | Current work-item coordinate | +| `pto.get_block_dim_x()` / `pto.get_block_dim_y()` / `pto.get_block_dim_z()` | `i32` | Block dimension in the selected axis | +| `pto.get_grid_dim_x()` / `pto.get_grid_dim_y()` / `pto.get_grid_dim_z()` | `i32` | Grid dimension in the selected axis | +| `pto.get_block_idx_x()` / `pto.get_block_idx_y()` / `pto.get_block_idx_z()` | `i32` | Block index in the selected axis | +| `pto.get_veccoreid()` | `i32` | Vector-core id visible to the work-item | +| `pto.get_clock32()` | `i32` | 32-bit clock sample | +| `pto.get_clock64()` | `i64` | 64-bit clock sample | +| `pto.get_laneid()` | `i32` | Physical SIMT lane id | +| `pto.get_lanemask_eq()` / `pto.get_lanemask_le()` / `pto.get_lanemask_lt()` / `pto.get_lanemask_ge()` / `pto.get_lanemask_gt()` | `i32` | Lane masks derived from the current lane id | + ## 3.4 Inline context manager syntax In addition to the decorator form, each sub-kernel unit provides a context diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 7a7713dd7..86f23ae8c 100644 --- 
a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -4402,6 +4402,21 @@ def store_vfsimt_info(dim_z, dim_y, dim_x): ) +def simt_launch(body, *args, dims=(1, 1, 1), **kwargs): + """``pto.simt_launch`` – launch a ``@pto.simt`` helper with ``(x, y, z)`` dimensions.""" + spec = getattr(body, "spec", None) + role = getattr(spec, "role", None) + role_value = getattr(role, "value", role) + if role_value != "simt": + raise TypeError("pto.simt_launch(body, ...) expects body to be a @pto.simt-decorated function") + + body._validate_invocation(*args, **kwargs) + + from ._tracing.active import require_active_session + session = require_active_session("pto.simt_launch") + session.lower_simt_launch_subkernel(body, *args, dims=dims, **kwargs) + + def get_tid_x(): """``pto.get_tid_x`` → i32 SIMT lane X coordinate.""" return wrap_surface_value(_pto.GetTidXOp().result) @@ -4417,6 +4432,96 @@ def get_tid_z(): return wrap_surface_value(_pto.GetTidZOp().result) +def get_block_dim_x(): + """``pto.get_block_dim_x`` → i32 SIMT block X dimension.""" + return wrap_surface_value(_pto.GetBlockDimXOp().result) + + +def get_block_dim_y(): + """``pto.get_block_dim_y`` → i32 SIMT block Y dimension.""" + return wrap_surface_value(_pto.GetBlockDimYOp().result) + + +def get_block_dim_z(): + """``pto.get_block_dim_z`` → i32 SIMT block Z dimension.""" + return wrap_surface_value(_pto.GetBlockDimZOp().result) + + +def get_grid_dim_x(): + """``pto.get_grid_dim_x`` → i32 SIMT grid X dimension.""" + return wrap_surface_value(_pto.GetGridDimXOp().result) + + +def get_grid_dim_y(): + """``pto.get_grid_dim_y`` → i32 SIMT grid Y dimension.""" + return wrap_surface_value(_pto.GetGridDimYOp().result) + + +def get_grid_dim_z(): + """``pto.get_grid_dim_z`` → i32 SIMT grid Z dimension.""" + return wrap_surface_value(_pto.GetGridDimZOp().result) + + +def get_block_idx_x(): + """``pto.get_block_idx_x`` → i32 SIMT block X index.""" + return wrap_surface_value(_pto.GetBlockIdxXOp().result) + + +def 
get_block_idx_y(): + """``pto.get_block_idx_y`` → i32 SIMT block Y index.""" + return wrap_surface_value(_pto.GetBlockIdxYOp().result) + + +def get_block_idx_z(): + """``pto.get_block_idx_z`` → i32 SIMT block Z index.""" + return wrap_surface_value(_pto.GetBlockIdxZOp().result) + + +def get_veccoreid(): + """``pto.get_veccoreid`` → i32 SIMT vector-core id.""" + return wrap_surface_value(_pto.GetVecCoreIdOp().result) + + +def get_clock32(): + """``pto.get_clock32`` → i32 SIMT clock sample.""" + return wrap_surface_value(_pto.GetClock32Op().result) + + +def get_clock64(): + """``pto.get_clock64`` → i64 SIMT clock sample.""" + return wrap_surface_value(_pto.GetClock64Op().result) + + +def get_laneid(): + """``pto.get_laneid`` → i32 SIMT lane id.""" + return wrap_surface_value(_pto.GetLaneIdOp().result) + + +def get_lanemask_eq(): + """``pto.get_lanemask_eq`` → i32 SIMT lane equality mask.""" + return wrap_surface_value(_pto.GetLaneMaskEqOp().result) + + +def get_lanemask_le(): + """``pto.get_lanemask_le`` → i32 SIMT lane less-or-equal mask.""" + return wrap_surface_value(_pto.GetLaneMaskLeOp().result) + + +def get_lanemask_lt(): + """``pto.get_lanemask_lt`` → i32 SIMT lane less-than mask.""" + return wrap_surface_value(_pto.GetLaneMaskLtOp().result) + + +def get_lanemask_ge(): + """``pto.get_lanemask_ge`` → i32 SIMT lane greater-or-equal mask.""" + return wrap_surface_value(_pto.GetLaneMaskGeOp().result) + + +def get_lanemask_gt(): + """``pto.get_lanemask_gt`` → i32 SIMT lane greater-than mask.""" + return wrap_surface_value(_pto.GetLaneMaskGtOp().result) + + def pipe_barrier(pipe): """``pto.pipe_barrier(pipe)`` – drain the specified hardware pipeline.""" _pto.BarrierOp(_pipe_attr(pipe)) @@ -4590,7 +4695,14 @@ def import_reserved_buffer(name, *, peer_func): "mte_l0c_l1", "mte_l0c_gm", "mte_l0c_ub", "mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", - "store_vfsimt_info", 
"get_tid_x", "get_tid_y", "get_tid_z", + "store_vfsimt_info", "simt_launch", + "get_tid_x", "get_tid_y", "get_tid_z", + "get_block_dim_x", "get_block_dim_y", "get_block_dim_z", + "get_grid_dim_x", "get_grid_dim_y", "get_grid_dim_z", + "get_block_idx_x", "get_block_idx_y", "get_block_idx_z", + "get_veccoreid", "get_clock32", "get_clock64", + "get_laneid", "get_lanemask_eq", "get_lanemask_le", "get_lanemask_lt", + "get_lanemask_ge", "get_lanemask_gt", "pipe_barrier", "get_buf", "rls_buf", "set_cross_flag", "wait_cross_flag", "set_intra_flag", "wait_intra_flag", "set_flag", "wait_flag", diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index 272dd1003..d00e2d907 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -18,10 +18,11 @@ yield_carry_loop_state, ) from .._surface_values import unwrap_surface_value, wrap_like_surface_value +from .._types import _strip_integer_signedness from mlir.dialects import arith, func from mlir.dialects import pto as _pto -from mlir.ir import InsertionPoint, IntegerType, UnitAttr +from mlir.ir import FlatSymbolRefAttr, IndexType, InsertionPoint, IntegerType, Operation, UnitAttr @dataclass(frozen=True) @@ -157,6 +158,27 @@ def finish_carry_loop(self, frame, exc_type, exc, tb): def lower_simt_helper_subkernel(self, subkernel, *args, **kwargs): """Lower one ``@pto.simt`` call through a dedicated helper function.""" + helper_fn, arg_templates = self._get_or_create_simt_helper_function(subkernel, *args, **kwargs) + + i32 = IntegerType.get_signless(32) + dim_z = arith.ConstantOp(i32, 1).result + dim_y = arith.ConstantOp(i32, 1).result + dim_x = arith.ConstantOp(i32, 1).result + _pto.StoreVfSimtInfoOp(dim_z, dim_y, dim_x) + func.CallOp(helper_fn, [unwrap_surface_value(arg) for arg in arg_templates]) + + def lower_simt_launch_subkernel(self, subkernel, *args, dims, **kwargs): + """Lower one explicit ``pto.simt_launch`` call through a SIMT helper.""" + helper_fn, arg_templates 
= self._get_or_create_simt_helper_function(subkernel, *args, **kwargs) + dim_x, dim_y, dim_z = _coerce_simt_launch_dims(dims) + Operation.create( + "pto.simt_launch", + attributes={"callee": FlatSymbolRefAttr.get(subkernel.spec.symbol_name)}, + operands=[dim_x, dim_y, dim_z, *[unwrap_surface_value(arg) for arg in arg_templates]], + ) + + def _get_or_create_simt_helper_function(self, subkernel, *args, **kwargs): + """Return the reusable ``pto.simt_entry`` helper for *subkernel*.""" outer_frame = self.current_subkernel if outer_frame is not None and outer_frame.role == "simt": raise RuntimeError("@pto.simt helper lowering does not support nested SIMT helper calls") @@ -180,12 +202,7 @@ def lower_simt_helper_subkernel(self, subkernel, *args, **kwargs): subkernel.emit_body(*wrapped_args, **kwargs) func.ReturnOp([]) - i32 = IntegerType.get_signless(32) - dim_z = arith.ConstantOp(i32, 1).result - dim_y = arith.ConstantOp(i32, 1).result - dim_x = arith.ConstantOp(i32, 1).result - _pto.StoreVfSimtInfoOp(dim_z, dim_y, dim_x) - func.CallOp(helper_fn, [unwrap_surface_value(arg) for arg in arg_templates]) + return helper_fn, arg_templates def lookup_helper(self, symbol_name: str): """Return a previously declared helper function, or ``None``.""" @@ -218,6 +235,34 @@ def validate_final_state(self) -> None: raise RuntimeError("PTODSL trace-session exited with an open loop-carry lowering frame") +def _coerce_simt_launch_dims(dims): + if not isinstance(dims, (tuple, list)) or len(dims) != 3: + raise TypeError("pto.simt_launch(..., dims=...) 
expects a 3-item (dim_x, dim_y, dim_z) tuple") + return tuple( + _coerce_i32_dim(dim, context=f"pto.simt_launch(..., dims[{index}])") + for index, dim in enumerate(dims) + ) + + +def _coerce_i32_dim(value, *, context: str): + raw_value = unwrap_surface_value(value) + i32 = IntegerType.get_signless(32) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + if raw_value < 0: + raise ValueError(f"{context} expects a non-negative i32 launch dimension, got {raw_value}") + return arith.ConstantOp(i32, raw_value).result + if IndexType.isinstance(raw_value.type): + return arith.IndexCastOp(i32, raw_value).result + if IntegerType.isinstance(raw_value.type): + width = IntegerType(raw_value.type).width + if width != 32: + raise TypeError(f"{context} expects i32 launch dimension, got {raw_value.type}") + return _strip_integer_signedness(raw_value) + raise TypeError(f"{context} expects i32 launch dimension, got {raw_value.type}") + + __all__ = [ "HelperFunctionSpec", "SubkernelTraceFrame", diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 66e693a90..67d1a3f6b 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -111,7 +111,14 @@ mte_l0c_l1, mte_l0c_gm, mte_l0c_ub, mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, - store_vfsimt_info, get_tid_x, get_tid_y, get_tid_z, + store_vfsimt_info, simt_launch, + get_tid_x, get_tid_y, get_tid_z, + get_block_dim_x, get_block_dim_y, get_block_dim_z, + get_grid_dim_x, get_grid_dim_y, get_grid_dim_z, + get_block_idx_x, get_block_idx_y, get_block_idx_z, + get_veccoreid, get_clock32, get_clock64, + get_laneid, get_lanemask_eq, get_lanemask_le, get_lanemask_lt, + get_lanemask_ge, get_lanemask_gt, pipe_barrier, get_buf, rls_buf, set_cross_flag, wait_cross_flag, set_intra_flag, wait_intra_flag, diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 
4ba5a7ff3..b9ed63dab 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -400,6 +400,31 @@ def simt_tid_probe(): pto.get_tid_z() +@pto.simt +def simt_query_probe(): + pto.get_tid_x() + pto.get_tid_y() + pto.get_tid_z() + pto.get_block_dim_x() + pto.get_block_dim_y() + pto.get_block_dim_z() + pto.get_grid_dim_x() + pto.get_grid_dim_y() + pto.get_grid_dim_z() + pto.get_block_idx_x() + pto.get_block_idx_y() + pto.get_block_idx_z() + pto.get_veccoreid() + pto.get_clock32() + pto.get_clock64() + pto.get_laneid() + pto.get_lanemask_eq() + pto.get_lanemask_le() + pto.get_lanemask_lt() + pto.get_lanemask_ge() + pto.get_lanemask_gt() + + @pto.simd def ast_subkernel_runtime_for_helper(rows: pto.i32): for row in range(0, rows, 1): @@ -413,6 +438,11 @@ def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): simt_tid_probe() +@pto.jit(target="a5") +def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): + pto.simt_launch(simt_query_probe, dims=(32, 2, 1)) + + @pto.jit(target="a5") def ast_subkernel_runtime_for_probe(rows: pto.i32): ast_subkernel_runtime_for_helper(rows) @@ -2469,6 +2499,41 @@ def main() -> None: expect("pto.get_tid_y" in simt_text, "SIMT helper body should contain pto.get_tid_y") expect("pto.get_tid_z" in simt_text, "SIMT helper body should contain pto.get_tid_z") + simt_launch_text = simt_explicit_launch_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_launch_text, "explicit simt launch specialization") + expect( + "pto.simt_launch @simt_query_probe<<<" in simt_launch_text, + "pto.simt_launch(...) 
should emit VPTO simt_launch sugar", + ) + expect( + "func.func @simt_query_probe() attributes {pto.simt_entry}" in simt_launch_text, + "explicit pto.simt_launch should materialize a reusable pto.simt_entry helper", + ) + for op_name in ( + "pto.get_tid_x", + "pto.get_tid_y", + "pto.get_tid_z", + "pto.get_block_dim_x", + "pto.get_block_dim_y", + "pto.get_block_dim_z", + "pto.get_grid_dim_x", + "pto.get_grid_dim_y", + "pto.get_grid_dim_z", + "pto.get_block_idx_x", + "pto.get_block_idx_y", + "pto.get_block_idx_z", + "pto.get_veccoreid", + "pto.get_clock32", + "pto.get_clock64", + "pto.get_laneid", + "pto.get_lanemask_eq", + "pto.get_lanemask_le", + "pto.get_lanemask_lt", + "pto.get_lanemask_ge", + "pto.get_lanemask_gt", + ): + expect(op_name in simt_launch_text, f"SIMT query body should contain {op_name}") + ast_subkernel_runtime_for_text = ast_subkernel_runtime_for_probe.compile().mlir_text() expect_parse_roundtrip_and_verify( ast_subkernel_runtime_for_text, From 0a802ef31d240191823d0645c4565e7e88a03561 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Wed, 10 Jun 2026 12:25:06 +0800 Subject: [PATCH 2/8] feat(ptodsl): add simt micro-op wrappers --- .../ptodsl-simt-micro-op-api-design.md | 15 +- .../03-kernel-entry-and-subkernels.md | 95 ++++ ptodsl/ptodsl/_ops.py | 450 ++++++++++++++++++ ptodsl/ptodsl/pto.py | 10 + ptodsl/tests/test_jit_compile.py | 186 ++++++++ 5 files changed, 755 insertions(+), 1 deletion(-) diff --git a/docs/designs/ptodsl-simt-micro-op-api-design.md b/docs/designs/ptodsl-simt-micro-op-api-design.md index 7bb6806c5..eb74d97ba 100644 --- a/docs/designs/ptodsl-simt-micro-op-api-design.md +++ b/docs/designs/ptodsl-simt-micro-op-api-design.md @@ -55,10 +55,13 @@ does not yet expose user-controlled launch dimensions and does not use ## 4. Full Migration Plan -The full SIMT micro-op PTO-DSL surface can be migrated in staged batches. +The full SIMT micro-op PTO-DSL surface is migrated in staged batches. 
### Batch 1: Launch and Query Ops +Status: implemented in `ptodsl/ptodsl/_ops.py`, exported from +`ptodsl/ptodsl/pto.py`, and covered by `ptodsl/tests/test_jit_compile.py`. + Expose launch configuration and nullary thread/lane query wrappers: - `pto.simt_launch(...)` @@ -75,6 +78,9 @@ Expose launch configuration and nullary thread/lane query wrappers: ### Batch 2: Lane Collective Ops +Status: implemented as direct VPTO wrappers and covered by the full SIMT +surface compile test. + Expose direct wrappers for: - `pto.vote_all/any/uni/ballot(pred)` @@ -83,6 +89,10 @@ Expose direct wrappers for: ### Batch 3: SIMT Scalar Memory and Atomics +Status: implemented as direct VPTO wrappers. `pto.ldg`/`pto.stg` reuse the +same address-access normalization as `scalar.load`/`scalar.store`; atomics +operate on explicit pointer operands. + Expose direct wrappers for: - `pto.ldg(ptr, offset=0, *, l1cache="cache", l2cache="nmfv")` @@ -95,6 +105,9 @@ Plain scalar memory remains available through `scalar.load(...)` and ### Batch 4: SIMT Scalar Math, Convert, Sync, and State +Status: implemented as direct VPTO wrappers. `pto.keep`/`pto.resume` expose +explicit slot attributes and leave placement validation to VPTO. + Expose direct wrappers for: - `pto.prmt(...)` diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 639a77df9..42e3b26c4 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -618,6 +618,101 @@ visible to the current SIMT work-item. | `pto.get_laneid()` | `i32` | Physical SIMT lane id | | `pto.get_lanemask_eq()` / `pto.get_lanemask_le()` / `pto.get_lanemask_lt()` / `pto.get_lanemask_ge()` / `pto.get_lanemask_gt()` | `i32` | Lane masks derived from the current lane id | +#### SIMT lane collective ops + +These wrappers map directly to VPTO SIMT lane collective micro-ops. 
+ +```python +pto.vote_all(pred) +pto.vote_any(pred) +pto.vote_uni(pred) +pto.vote_ballot(pred) + +pto.shuffle_idx(value, index, *, width=32) +pto.shuffle_up(value, offset, *, width=32) +pto.shuffle_down(value, offset, *, width=32) +pto.shuffle_bfly(value, mask, *, width=32) + +pto.redux_add(value, *, signedness=None) +pto.redux_max(value, *, signedness=None) +pto.redux_min(value, *, signedness=None) +``` + +`pred` must be an `i1` predicate. Shuffle control operands are coerced to +`i32`; `width` must be `16` or `32`. Integer `redux_max` and `redux_min` +require `signedness="signed"` or `signedness="unsigned"`; floating-point redux +does not accept signedness. + +#### SIMT scalar GM memory and atomic ops + +```python +pto.ldg(ptr, offset=0, *, l1cache="cache", l2cache="nmfv") +pto.stg(value, ptr, offset=0, *, l1cache="cache", l2cache="nmfv") + +pto.atomic_exch(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_add(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_sub(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_min(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_max(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_and(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_or(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_xor(ptr, value, *, l2cache="nmfv", signedness=None) +pto.atomic_cas(ptr, compare, value, *, l2cache="nmfv", signedness=None) +``` + +`pto.ldg` and `pto.stg` are GM scalar memory micro-ops with cache-control +clauses. Plain scalar memory remains available through `scalar.load(...)` and +`scalar.store(...)`. + +`l1cache` accepts `"cache"` or `"uncache"`. Load `l2cache` accepts the VPTO +load L2 cache tokens; store and atomic `l2cache` accept the VPTO store/atomic +L2 cache tokens. Atomic pointers must point to GM or UB scalar storage accepted +by the VPTO verifier. Integer atomics may pass `signedness`; floating-point and +packed atomics must omit it. 
+ +#### SIMT scalar math, conversion, sync, and state ops + +```python +pto.prmt(lhs, rhs, selector) +pto.mulhi(lhs, rhs, *, signedness) +pto.mul_i32toi64(lhs, rhs, *, signedness) + +pto.absf(value) +pto.sqrt(value) +pto.exp(value) +pto.log(value) +pto.pow(lhs, rhs) +pto.ceil(value) +pto.floor(value) +pto.rint(value) +pto.round(value) +pto.fmin(lhs, rhs) +pto.fmax(lhs, rhs) +pto.fma(lhs, rhs, acc) + +pto.convert(src, dst_type, *, rounding, saturation, signedness=None) + +pto.syncthreads() +pto.threadfence() +pto.threadfence_block() +pto.keep(payload, *, slot) +pto.resume(result_type, *, slot) +``` + +`pto.sqrt`, `pto.exp`, `pto.log`, and related functions are VPTO SIMT +micro-ops. They are distinct from the generic `scalar.sqrt`, `scalar.exp`, and +`scalar.log` helpers in Chapter 6. + +`pto.convert` requires an explicit destination type plus VPTO conversion +controls. `rounding` accepts `"r"`, `"a"`, `"f"`, `"c"`, `"z"`, `"o"`, or +`"h"`. `saturation` accepts `"sat"`/`"nosat"` or `"on"`/`"off"`. +`signedness` is required when converting to or from integer types and omitted +for floating-to-floating or packed floating conversion. Integer-to-integer +conversion is not supported by `pto.convert`. + +`pto.keep` and `pto.resume` use explicit non-negative Python integer slots. +Keep/resume placement constraints are enforced by the VPTO verifier. 
+ ## 3.4 Inline context manager syntax In addition to the decorator form, each sub-kernel unit provides a context diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 86f23ae8c..90ae0e76d 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -46,6 +46,7 @@ emit_as_ptr, infer_tile_element_type, parse_tile_type_metadata, + resolve_address_access, unwrap_surface_value, wrap_surface_value, ) @@ -4522,6 +4523,445 @@ def get_lanemask_gt(): return wrap_surface_value(_pto.GetLaneMaskGtOp().result) +_SIGNEDNESS_TOKENS = {"signed", "unsigned"} +_L1_CACHE_TOKENS = {"cache", "uncache"} +_LD_L2_CACHE_TOKENS = { + "nmfv", "nmlv", "nmprs", "nmpref", + "nakeep", "naclean", "nadrop", + "idsfv", "idslv", "idsprs", "idspref", + "exfv", "exlv", "exprs", "expref", +} +_ST_L2_CACHE_TOKENS = { + "nmfv", "nmlv", "nmprs", "nmred", + "naci", "napw", "napi", "nared", + "wbhfv", "wbhlv", "wbhprs", "wbhred", + "wtsfv", "wtslv", "wtsprs", "wtsred", +} +_ROUNDING_TOKENS = {"r", "a", "f", "c", "z", "o", "h"} +_SATURATION_TOKENS = {"sat", "nosat"} + + +def _optional_signedness_attr(signedness, *, context: str): + if signedness is None: + return None + return _simt_enum_attr("signedness", signedness, supported=_SIGNEDNESS_TOKENS, context=context) + + +def _required_signedness_attr(signedness, *, context: str): + if signedness is None: + raise TypeError(f"{context} requires signedness='signed' or 'unsigned'") + return _optional_signedness_attr(signedness, context=context) + + +def _l1_cache_attr(value, *, context: str): + return _simt_enum_attr("l1cache", value, supported=_L1_CACHE_TOKENS, context=context) + + +def _ld_l2_cache_attr(value, *, context: str): + return _simt_enum_attr("ld_l2cache", value, supported=_LD_L2_CACHE_TOKENS, context=context) + + +def _st_l2_cache_attr(value, *, context: str): + return _simt_enum_attr("st_l2cache", value, supported=_ST_L2_CACHE_TOKENS, context=context) + + +def _rounding_attr(value, *, context: str): + return 
_simt_enum_attr("rounding", value, supported=_ROUNDING_TOKENS, context=context) + + +def _saturation_attr(value, *, context: str): + normalized = _normalize_token(value, context=context) + aliases = {"on": "sat", "off": "nosat", "sat": "sat", "nosat": "nosat"} + token = aliases.get(normalized) + if token is None: + expected = ", ".join(sorted((*_SATURATION_TOKENS, "on", "off"))) + raise ValueError(f"{context} does not support {value!r}; expected one of {expected}") + return _simt_enum_attr("saturation", token, supported=_SATURATION_TOKENS, context=context) + + +def _simt_enum_attr(kind, value, *, supported: set[str], context: str): + normalized = _normalize_token(value, context=context) + if normalized not in supported: + expected = ", ".join(sorted(supported)) + raise ValueError(f"{context} does not support {value!r}; expected one of {expected}") + return Attribute.parse(f"#pto.{kind}<{normalized}>") + + +def _coerce_i32_operand(value, *, context: str): + return coerce_scalar_to_type(value, IntegerType.get_signless(32), context=context) + + +def _same_type_unary(op_cls, value): + return wrap_surface_value(op_cls(unwrap_surface_value(value)).result) + + +def _same_type_binary(op_cls, lhs, rhs, *, context: str): + raw_lhs = unwrap_surface_value(lhs) + raw_rhs = coerce_scalar_to_type(rhs, raw_lhs.type, context=context) + return wrap_surface_value(op_cls(raw_lhs, raw_rhs).result) + + +def _same_type_ternary(op_cls, lhs, rhs, acc, *, context: str): + raw_lhs = unwrap_surface_value(lhs) + raw_rhs = coerce_scalar_to_type(rhs, raw_lhs.type, context=context) + raw_acc = coerce_scalar_to_type(acc, raw_lhs.type, context=context) + return wrap_surface_value(op_cls(raw_lhs, raw_rhs, raw_acc).result) + + +def _validate_redux_signedness(value_type, signedness, *, require_for_integer: bool, context: str): + if IntegerType.isinstance(value_type): + if require_for_integer and signedness is None: + raise TypeError(f"{context} requires signedness='signed' or 'unsigned' for integer 
values") + return + if signedness is not None: + raise TypeError(f"{context} does not accept signedness for floating-point values") + + +def _validate_integer_signedness_only(value_type, signedness, *, context: str): + if signedness is not None and not IntegerType.isinstance(value_type): + raise TypeError(f"{context} does not accept signedness for non-integer values") + + +def _validate_convert_signedness(src_type, dst_type, signedness, *, context: str): + src_int = IntegerType.isinstance(src_type) + dst_int = IntegerType.isinstance(dst_type) + if src_int and dst_int: + raise TypeError(f"{context} does not support integer-to-integer conversion") + if src_int or dst_int: + if signedness is None: + raise TypeError(f"{context} requires signedness='signed' or 'unsigned' when converting to or from integer types") + return + if signedness is not None: + raise TypeError(f"{context} does not accept signedness for floating-point or packed conversion") + + +def vote_all(pred): + """``pto.vote_all`` – SIMT all-lane predicate vote.""" + return wrap_surface_value(_pto.VoteAllOp(unwrap_surface_value(pred)).result) + + +def vote_any(pred): + """``pto.vote_any`` – SIMT any-lane predicate vote.""" + return wrap_surface_value(_pto.VoteAnyOp(unwrap_surface_value(pred)).result) + + +def vote_uni(pred): + """``pto.vote_uni`` – SIMT uniform-predicate vote.""" + return wrap_surface_value(_pto.VoteUniOp(unwrap_surface_value(pred)).result) + + +def vote_ballot(pred): + """``pto.vote_ballot`` – SIMT ballot predicate vote.""" + return wrap_surface_value(_pto.VoteBallotOp(unwrap_surface_value(pred)).result) + + +def _validate_shuffle_width(width, *, context: str): + if width not in (16, 32): + raise ValueError(f"{context} expects width to be 16 or 32, got {width}") + return width + + +def shuffle_idx(value, index, *, width=32): + """``pto.shuffle_idx`` – read a payload from an absolute SIMT lane index.""" + return wrap_surface_value(_pto.ShuffleIdxOp( + unwrap_surface_value(value), + 
_coerce_i32_operand(index, context="shuffle_idx(..., index)"), + width=_validate_shuffle_width(width, context="shuffle_idx(..., width)"), + ).result) + + +def shuffle_up(value, offset, *, width=32): + """``pto.shuffle_up`` – read a payload from a lower-index SIMT lane.""" + return wrap_surface_value(_pto.ShuffleUpOp( + unwrap_surface_value(value), + _coerce_i32_operand(offset, context="shuffle_up(..., offset)"), + width=_validate_shuffle_width(width, context="shuffle_up(..., width)"), + ).result) + + +def shuffle_down(value, offset, *, width=32): + """``pto.shuffle_down`` – read a payload from a higher-index SIMT lane.""" + return wrap_surface_value(_pto.ShuffleDownOp( + unwrap_surface_value(value), + _coerce_i32_operand(offset, context="shuffle_down(..., offset)"), + width=_validate_shuffle_width(width, context="shuffle_down(..., width)"), + ).result) + + +def shuffle_bfly(value, mask, *, width=32): + """``pto.shuffle_bfly`` – read a payload from a butterfly-selected SIMT lane.""" + return wrap_surface_value(_pto.ShuffleBflyOp( + unwrap_surface_value(value), + _coerce_i32_operand(mask, context="shuffle_bfly(..., mask)"), + width=_validate_shuffle_width(width, context="shuffle_bfly(..., width)"), + ).result) + + +def redux_add(value, *, signedness=None): + """``pto.redux_add`` – SIMT lane sum reduction.""" + raw_value = unwrap_surface_value(value) + _validate_redux_signedness(raw_value.type, signedness, require_for_integer=False, context="redux_add(value)") + return wrap_surface_value(_pto.ReduxAddOp( + raw_value, + signedness=_optional_signedness_attr(signedness, context="redux_add(..., signedness)"), + ).result) + + +def redux_max(value, *, signedness=None): + """``pto.redux_max`` – SIMT lane max reduction.""" + raw_value = unwrap_surface_value(value) + _validate_redux_signedness(raw_value.type, signedness, require_for_integer=True, context="redux_max(value)") + return wrap_surface_value(_pto.ReduxMaxOp( + raw_value, + 
signedness=_optional_signedness_attr(signedness, context="redux_max(..., signedness)"), + ).result) + + +def redux_min(value, *, signedness=None): + """``pto.redux_min`` – SIMT lane min reduction.""" + raw_value = unwrap_surface_value(value) + _validate_redux_signedness(raw_value.type, signedness, require_for_integer=True, context="redux_min(value)") + return wrap_surface_value(_pto.ReduxMinOp( + raw_value, + signedness=_optional_signedness_attr(signedness, context="redux_min(..., signedness)"), + ).result) + + +def ldg(ptr_or_ref, offset=None, *, l1cache="cache", l2cache="nmfv"): + """``pto.ldg`` – scalar GM load with cache controls.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + result_type = _pointer_element_type(buffer_value, context="ldg(ptr, offset)") + return wrap_surface_value(_pto.PTOLdgOp( + result_type, + buffer_value, + index_value, + l1cache=_l1_cache_attr(l1cache, context="ldg(..., l1cache)"), + l2cache=_ld_l2_cache_attr(l2cache, context="ldg(..., l2cache)"), + ).value) + + +def stg(value, ptr_or_ref, offset=None, *, l1cache="cache", l2cache="nmfv"): + """``pto.stg`` – scalar GM store with cache controls.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + elem_type = _pointer_element_type(buffer_value, context="stg(value, ptr, offset)") + _pto.PTOStgOp( + buffer_value, + index_value, + coerce_scalar_to_type(value, elem_type, context="stg(value, ...)"), + l1cache=_l1_cache_attr(l1cache, context="stg(..., l1cache)"), + l2cache=_st_l2_cache_attr(l2cache, context="stg(..., l2cache)"), + ) + + +def _atomic_binary(op_cls, ptr, value, *, l2cache, signedness, context: str): + raw_ptr = unwrap_surface_value(ptr) + elem_type = _pointer_element_type(raw_ptr, context=context) + _validate_integer_signedness_only(elem_type, signedness, context=context) + raw_value = coerce_scalar_to_type(value, elem_type, context=context) + return wrap_surface_value(op_cls( + raw_value.type, + raw_ptr, + raw_value, + 
l2cache=_st_l2_cache_attr(l2cache, context=f"{context} l2cache"), + signedness=_optional_signedness_attr(signedness, context=f"{context} signedness"), + ).old) + + +def atomic_exch(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_exch`` – SIMT scalar atomic exchange.""" + return _atomic_binary(_pto.AtomicExchOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_exch(ptr, value)") + + +def atomic_add(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_add`` – SIMT scalar atomic add.""" + return _atomic_binary(_pto.AtomicAddOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_add(ptr, value)") + + +def atomic_sub(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_sub`` – SIMT scalar atomic subtract.""" + return _atomic_binary(_pto.AtomicSubOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_sub(ptr, value)") + + +def atomic_min(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_min`` – SIMT scalar atomic min.""" + return _atomic_binary(_pto.AtomicMinOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_min(ptr, value)") + + +def atomic_max(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_max`` – SIMT scalar atomic max.""" + return _atomic_binary(_pto.AtomicMaxOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_max(ptr, value)") + + +def atomic_and(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_and`` – SIMT scalar atomic bitwise and.""" + return _atomic_binary(_pto.AtomicAndOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_and(ptr, value)") + + +def atomic_or(ptr, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_or`` – SIMT scalar atomic bitwise or.""" + return _atomic_binary(_pto.AtomicOrOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_or(ptr, value)") + + +def atomic_xor(ptr, value, *, l2cache="nmfv", 
signedness=None): + """``pto.atomic_xor`` – SIMT scalar atomic bitwise xor.""" + return _atomic_binary(_pto.AtomicXorOp, ptr, value, l2cache=l2cache, signedness=signedness, context="atomic_xor(ptr, value)") + + +def atomic_cas(ptr, compare, value, *, l2cache="nmfv", signedness=None): + """``pto.atomic_cas`` – SIMT scalar atomic compare-and-swap.""" + raw_ptr = unwrap_surface_value(ptr) + elem_type = _pointer_element_type(raw_ptr, context="atomic_cas(ptr, compare, value)") + _validate_integer_signedness_only(elem_type, signedness, context="atomic_cas(ptr, compare, value)") + raw_compare = coerce_scalar_to_type(compare, elem_type, context="atomic_cas(compare)") + raw_value = coerce_scalar_to_type(value, elem_type, context="atomic_cas(value)") + return wrap_surface_value(_pto.AtomicCasOp( + raw_ptr, + raw_compare, + raw_value, + l2cache=_st_l2_cache_attr(l2cache, context="atomic_cas(..., l2cache)"), + signedness=_optional_signedness_attr(signedness, context="atomic_cas(..., signedness)"), + ).old) + + +def prmt(lhs, rhs, selector): + """``pto.prmt`` – SIMT scalar byte permutation.""" + return wrap_surface_value(_pto.PrmtOp( + _coerce_i32_operand(lhs, context="prmt(lhs, ...)"), + _coerce_i32_operand(rhs, context="prmt(..., rhs, ...)"), + _coerce_i32_operand(selector, context="prmt(..., selector)"), + ).result) + + +def mulhi(lhs, rhs, *, signedness): + """``pto.mulhi`` – high half of an integer product.""" + raw_lhs = unwrap_surface_value(lhs) + raw_rhs = coerce_scalar_to_type(rhs, raw_lhs.type, context="mulhi(lhs, rhs)") + return wrap_surface_value(_pto.MulhiOp( + raw_lhs, + raw_rhs, + _required_signedness_attr(signedness, context="mulhi(..., signedness)"), + ).result) + + +def mul_i32toi64(lhs, rhs, *, signedness): + """``pto.mul_i32toi64`` – widened i32 product.""" + return wrap_surface_value(_pto.MulI32ToI64Op( + _coerce_i32_operand(lhs, context="mul_i32toi64(lhs, ...)"), + _coerce_i32_operand(rhs, context="mul_i32toi64(..., rhs)"), + 
_required_signedness_attr(signedness, context="mul_i32toi64(..., signedness)"), + ).result) + + +def absf(value): + """``pto.absf`` – SIMT floating absolute value.""" + return _same_type_unary(_pto.AbsFOp, value) + + +def sqrt(value): + """``pto.sqrt`` – SIMT floating square root.""" + return _same_type_unary(_pto.SqrtOp, value) + + +def exp(value): + """``pto.exp`` – SIMT floating exponential.""" + return _same_type_unary(_pto.ExpOp, value) + + +def log(value): + """``pto.log`` – SIMT floating natural logarithm.""" + return _same_type_unary(_pto.LogOp, value) + + +def pow(lhs, rhs): + """``pto.pow`` – SIMT floating power.""" + return _same_type_binary(_pto.PowOp, lhs, rhs, context="pow(lhs, rhs)") + + +def ceil(value): + """``pto.ceil`` – SIMT floating ceil.""" + return _same_type_unary(_pto.CeilOp, value) + + +def floor(value): + """``pto.floor`` – SIMT floating floor.""" + return _same_type_unary(_pto.FloorOp, value) + + +def rint(value): + """``pto.rint`` – SIMT floating rint.""" + return _same_type_unary(_pto.RintOp, value) + + +def round(value): + """``pto.round`` – SIMT floating round.""" + return _same_type_unary(_pto.RoundOp, value) + + +def fmin(lhs, rhs): + """``pto.fmin`` – SIMT floating minimum.""" + return _same_type_binary(_pto.FMinOp, lhs, rhs, context="fmin(lhs, rhs)") + + +def fmax(lhs, rhs): + """``pto.fmax`` – SIMT floating maximum.""" + return _same_type_binary(_pto.FMaxOp, lhs, rhs, context="fmax(lhs, rhs)") + + +def fma(lhs, rhs, acc): + """``pto.fma`` – SIMT floating fused multiply-add.""" + return _same_type_ternary(_pto.FmaOp, lhs, rhs, acc, context="fma(lhs, rhs, acc)") + + +def convert(src, dst_type, *, rounding, saturation, signedness=None): + """``pto.convert`` – SIMT scalar or packed conversion.""" + raw_src = unwrap_surface_value(src) + raw_dst_type = _resolve(dst_type) + _validate_convert_signedness(raw_src.type, raw_dst_type, signedness, context="convert(src, dst_type)") + return wrap_surface_value(_pto.ConvertOp( + raw_dst_type, + 
raw_src, + _rounding_attr(rounding, context="convert(..., rounding)"), + _saturation_attr(saturation, context="convert(..., saturation)"), + signedness=_optional_signedness_attr(signedness, context="convert(..., signedness)"), + ).dst) + + +def syncthreads(): + """``pto.syncthreads`` – synchronize SIMT workitems.""" + _pto.SyncthreadsOp() + + +def threadfence(): + """``pto.threadfence`` – issue a SIMT workitem memory fence.""" + _pto.ThreadfenceOp() + + +def threadfence_block(): + """``pto.threadfence_block`` – issue a SIMT block-scoped memory fence.""" + _pto.ThreadfenceBlockOp() + + +def _slot_attr_value(slot, *, context: str): + if not isinstance(slot, int) or isinstance(slot, bool): + raise TypeError(f"{context} expects a non-negative Python int slot") + if slot < 0: + raise ValueError(f"{context} expects a non-negative slot, got {slot}") + return slot + + +def keep(payload, *, slot): + """``pto.keep`` – preserve a SIMT scalar payload in an explicit slot.""" + _pto.KeepOp(unwrap_surface_value(payload), _slot_attr_value(slot, context="keep(..., slot)")) + + +def resume(result_type, *, slot): + """``pto.resume`` – restore a SIMT scalar payload from an explicit slot.""" + return wrap_surface_value(_pto.ResumeOp( + _resolve(result_type), + _slot_attr_value(slot, context="resume(..., slot)"), + ).result) + + def pipe_barrier(pipe): """``pto.pipe_barrier(pipe)`` – drain the specified hardware pipeline.""" _pto.BarrierOp(_pipe_attr(pipe)) @@ -4703,6 +5143,16 @@ def import_reserved_buffer(name, *, peer_func): "get_veccoreid", "get_clock32", "get_clock64", "get_laneid", "get_lanemask_eq", "get_lanemask_le", "get_lanemask_lt", "get_lanemask_ge", "get_lanemask_gt", + "vote_all", "vote_any", "vote_uni", "vote_ballot", + "shuffle_idx", "shuffle_up", "shuffle_down", "shuffle_bfly", + "redux_add", "redux_max", "redux_min", + "ldg", "stg", + "atomic_exch", "atomic_add", "atomic_sub", "atomic_min", "atomic_max", + "atomic_and", "atomic_or", "atomic_xor", "atomic_cas", + "prmt", 
"mulhi", "mul_i32toi64", + "absf", "sqrt", "exp", "log", "pow", "ceil", "floor", "rint", "round", + "fmin", "fmax", "fma", "convert", + "syncthreads", "threadfence", "threadfence_block", "keep", "resume", "pipe_barrier", "get_buf", "rls_buf", "set_cross_flag", "wait_cross_flag", "set_intra_flag", "wait_intra_flag", "set_flag", "wait_flag", diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index 67d1a3f6b..d6753e7f9 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -119,6 +119,16 @@ get_veccoreid, get_clock32, get_clock64, get_laneid, get_lanemask_eq, get_lanemask_le, get_lanemask_lt, get_lanemask_ge, get_lanemask_gt, + vote_all, vote_any, vote_uni, vote_ballot, + shuffle_idx, shuffle_up, shuffle_down, shuffle_bfly, + redux_add, redux_max, redux_min, + ldg, stg, + atomic_exch, atomic_add, atomic_sub, atomic_min, atomic_max, + atomic_and, atomic_or, atomic_xor, atomic_cas, + prmt, mulhi, mul_i32toi64, + absf, sqrt, exp, log, pow, ceil, floor, rint, round, + fmin, fmax, fma, convert, + syncthreads, threadfence, threadfence_block, keep, resume, pipe_barrier, get_buf, rls_buf, set_cross_flag, wait_cross_flag, set_intra_flag, wait_intra_flag, diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index b9ed63dab..192c846e2 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -425,6 +425,96 @@ def simt_query_probe(): pto.get_lanemask_gt() +@pto.simt +def simt_collective_math_probe(): + lane = pto.get_laneid() + pred = pto.const(1, dtype=pto.i1) + + pto.vote_all(pred) + pto.vote_any(pred) + pto.vote_uni(pred) + pto.vote_ballot(pred) + + pto.shuffle_idx(lane, lane, width=32) + pto.shuffle_up(lane, 1, width=32) + pto.shuffle_down(lane, 1, width=32) + pto.shuffle_bfly(lane, 1, width=32) + + pto.redux_add(lane, signedness="signed") + pto.redux_max(lane, signedness="signed") + pto.redux_min(lane, signedness="signed") + + pto.prmt(lane, lane, lane) + pto.mulhi(lane, lane, signedness="signed") + 
pto.mul_i32toi64(lane, lane, signedness="unsigned") + + as_f32 = pto.convert(lane, pto.f32, rounding="r", saturation="nosat", signedness="signed") + pto.convert(as_f32, pto.i32, rounding="z", saturation="sat", signedness="signed") + pto.absf(as_f32) + pto.sqrt(as_f32) + pto.exp(as_f32) + pto.log(as_f32) + pto.pow(as_f32, as_f32) + pto.ceil(as_f32) + pto.floor(as_f32) + pto.rint(as_f32) + pto.round(as_f32) + pto.fmin(as_f32, as_f32) + pto.fmax(as_f32, as_f32) + pto.fma(as_f32, as_f32, as_f32) + + +@pto.simt +def simt_memory_atomic_probe( + gm: pto.ptr(pto.i32, "gm"), +): + idx = scalar.index_cast(pto.get_tid_x()) + value = pto.ldg(gm, idx, l1cache="cache", l2cache="nmfv") + pto.stg(value, gm, idx, l1cache="uncache", l2cache="wtsred") + + old = pto.atomic_add(gm, value, l2cache="nmfv", signedness="signed") + pto.atomic_exch(gm, value, signedness="signed") + pto.atomic_sub(gm, value, signedness="signed") + pto.atomic_min(gm, value, signedness="signed") + pto.atomic_max(gm, value, signedness="signed") + pto.atomic_and(gm, value, signedness="unsigned") + pto.atomic_or(gm, value, signedness="unsigned") + pto.atomic_xor(gm, value, signedness="unsigned") + pto.atomic_cas(gm, old, value, signedness="signed") + + pto.syncthreads() + pto.threadfence() + pto.threadfence_block() + + +@pto.simt +def simt_keep_stage(): + pto.keep(pto.get_tid_x(), slot=0) + + +@pto.simt +def simt_resume_stage(gm: pto.ptr(pto.i32, "gm")): + resumed = pto.resume(pto.i32, slot=0) + idx = scalar.index_cast(pto.get_tid_x()) + scalar.store(resumed, gm, idx) + + +@pto.simt +def simt_invalid_redux_signedness_probe(): + pto.redux_max(pto.get_laneid()) + + +@pto.simt +def simt_invalid_convert_signedness_probe(): + pto.convert(pto.get_laneid(), pto.f32, rounding="r", saturation="nosat") + + +@pto.simt +def simt_invalid_atomic_signedness_probe(gm: pto.ptr(pto.f32, "gm")): + value = pto.ldg(gm, 0) + pto.atomic_add(gm, value, signedness="signed") + + @pto.simd def ast_subkernel_runtime_for_helper(rows: 
pto.i32): for row in range(0, rows, 1): @@ -443,6 +533,37 @@ def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): pto.simt_launch(simt_query_probe, dims=(32, 2, 1)) +@pto.jit(target="a5") +def simt_full_surface_probe( + gm: pto.ptr(pto.i32, "gm"), + *, + TRACE_TOKEN: pto.constexpr = 0, +): + pto.simt_launch(simt_collective_math_probe, dims=(32, 1, 1)) + pto.simt_launch(simt_memory_atomic_probe, gm, dims=(32, 1, 1)) + pto.simt_launch(simt_keep_stage, dims=(32, 1, 1)) + pto.simt_launch(simt_resume_stage, gm, dims=(32, 1, 1)) + + +@pto.jit(target="a5") +def simt_invalid_redux_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): + pto.simt_launch(simt_invalid_redux_signedness_probe, dims=(32, 1, 1)) + + +@pto.jit(target="a5") +def simt_invalid_convert_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): + pto.simt_launch(simt_invalid_convert_signedness_probe, dims=(32, 1, 1)) + + +@pto.jit(target="a5") +def simt_invalid_atomic_signedness_launch( + gm: pto.ptr(pto.f32, "gm"), + *, + TRACE_TOKEN: pto.constexpr = 0, +): + pto.simt_launch(simt_invalid_atomic_signedness_probe, gm, dims=(32, 1, 1)) + + @pto.jit(target="a5") def ast_subkernel_runtime_for_probe(rows: pto.i32): ast_subkernel_runtime_for_helper(rows) @@ -2534,6 +2655,71 @@ def main() -> None: ): expect(op_name in simt_launch_text, f"SIMT query body should contain {op_name}") + simt_full_text = simt_full_surface_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_full_text, "full simt surface specialization") + for op_name in ( + "pto.vote_all", + "pto.vote_any", + "pto.vote_uni", + "pto.vote_ballot", + "pto.shuffle_idx", + "pto.shuffle_up", + "pto.shuffle_down", + "pto.shuffle_bfly", + "pto.redux_add", + "pto.redux_max", + "pto.redux_min", + "pto.ldg", + "pto.stg", + "pto.atomic_exch", + "pto.atomic_add", + "pto.atomic_sub", + "pto.atomic_min", + "pto.atomic_max", + "pto.atomic_and", + "pto.atomic_or", + "pto.atomic_xor", + "pto.atomic_cas", + "pto.prmt", + 
"pto.mulhi", + "pto.mul_i32toi64", + "pto.absf", + "pto.sqrt", + "pto.exp", + "pto.log", + "pto.pow", + "pto.ceil", + "pto.floor", + "pto.rint", + "pto.round", + "pto.fmin", + "pto.fmax", + "pto.fma", + "pto.convert", + "pto.syncthreads", + "pto.threadfence", + "pto.threadfence_block", + "pto.keep", + "pto.resume", + ): + expect(op_name in simt_full_text, f"full SIMT surface should contain {op_name}") + + expect_raises( + TypeError, + lambda: simt_invalid_redux_signedness_launch.compile(TRACE_TOKEN=1).mlir_text(), + "requires signedness", + ) + expect_raises( + TypeError, + lambda: simt_invalid_convert_signedness_launch.compile(TRACE_TOKEN=1).mlir_text(), + "requires signedness", + ) + expect_raises( + TypeError, + lambda: simt_invalid_atomic_signedness_launch.compile(TRACE_TOKEN=1).mlir_text(), + "does not accept signedness", + ) + ast_subkernel_runtime_for_text = ast_subkernel_runtime_for_probe.compile().mlir_text() expect_parse_roundtrip_and_verify( ast_subkernel_runtime_for_text, From bf02b2e98c48dccf3d31b02f924bca76552789c8 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:06:03 +0800 Subject: [PATCH 3/8] fix(ptodsl): specialize simt helpers by signature --- ptodsl/ptodsl/_tracing/session.py | 93 ++++++++++++++++++++++++++++++- ptodsl/tests/test_jit_compile.py | 65 +++++++++++++++++++-- 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index d00e2d907..33ca5a37f 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -35,6 +35,15 @@ class HelperFunctionSpec: attributes: tuple[tuple[str, object], ...] = () +@dataclass(frozen=True) +class SimtHelperSpecializationKey: + """Cache key for one specialized ``@pto.simt`` helper body.""" + + symbol_name: str + arg_types: tuple + static_kwargs: tuple[tuple[str, object], ...] 
+ + @dataclass(frozen=True) class SubkernelTraceFrame: """Active inline-lowering frame for one PTODSL subkernel call.""" @@ -55,6 +64,8 @@ def __init__(self, module_spec, module, entry_function): self._function_stack = [entry_function] self._function_symbol_table = entry_function.operation.parent.regions[0].blocks[0] self._helpers: dict[str, object] = {} + self._simt_helper_specializations: dict[SimtHelperSpecializationKey, object] = {} + self._simt_helper_symbol_counters: dict[str, int] = {} self._subkernel_stack: list[SubkernelTraceFrame] = [] self._carry_loop_stack = [] @@ -173,7 +184,7 @@ def lower_simt_launch_subkernel(self, subkernel, *args, dims, **kwargs): dim_x, dim_y, dim_z = _coerce_simt_launch_dims(dims) Operation.create( "pto.simt_launch", - attributes={"callee": FlatSymbolRefAttr.get(subkernel.spec.symbol_name)}, + attributes={"callee": FlatSymbolRefAttr.get(_symbol_name(helper_fn))}, operands=[dim_x, dim_y, dim_z, *[unwrap_surface_value(arg) for arg in arg_templates]], ) @@ -185,12 +196,24 @@ def _get_or_create_simt_helper_function(self, subkernel, *args, **kwargs): arg_templates = tuple(args) arg_types = tuple(unwrap_surface_value(arg).type for arg in arg_templates) - helper_spec = HelperFunctionSpec( + static_kwargs = _simt_static_kwargs_signature(kwargs) + specialization_key = SimtHelperSpecializationKey( symbol_name=subkernel.spec.symbol_name, arg_types=arg_types, + static_kwargs=static_kwargs, + ) + helper_fn = self._simt_helper_specializations.get(specialization_key) + if helper_fn is not None: + return helper_fn, arg_templates + + helper_symbol = self._next_simt_helper_symbol(subkernel.spec.symbol_name) + helper_spec = HelperFunctionSpec( + symbol_name=helper_symbol, + arg_types=arg_types, attributes=(("pto.simt_entry", UnitAttr.get()),), ) helper_fn, created = self.get_or_create_helper_function(helper_spec) + self._simt_helper_specializations[specialization_key] = helper_fn if created: entry_block = helper_fn.add_entry_block() @@ -204,6 
+227,15 @@ def _get_or_create_simt_helper_function(self, subkernel, *args, **kwargs): return helper_fn, arg_templates + def _next_simt_helper_symbol(self, base_symbol: str) -> str: + index = self._simt_helper_symbol_counters.get(base_symbol, 0) + while True: + symbol = f"{base_symbol}__simt_{index}" + index += 1 + if symbol not in self._helpers: + self._simt_helper_symbol_counters[base_symbol] = index + return symbol + def lookup_helper(self, symbol_name: str): """Return a previously declared helper function, or ``None``.""" return self._helpers.get(symbol_name) @@ -263,6 +295,63 @@ def _coerce_i32_dim(value, *, context: str): raise TypeError(f"{context} expects i32 launch dimension, got {raw_value.type}") +def _symbol_name(ir_fn) -> str: + try: + name_attr = ir_fn.attributes["sym_name"] + except KeyError as exc: + raise RuntimeError("PTODSL helper function is missing sym_name") from exc + if name_attr is None: + raise RuntimeError("PTODSL helper function has empty sym_name") + return str(name_attr.value) + + +def _simt_static_kwargs_signature(kwargs): + return tuple( + (name, _simt_static_signature_atom(value)) + for name, value in sorted(kwargs.items()) + ) + + +def _simt_static_signature_atom(value): + raw_value = unwrap_surface_value(value) + if hasattr(raw_value, "type"): + raise TypeError( + "pto.simt_launch keyword arguments must be static hashable values; " + "pass runtime SSA arguments positionally" + ) + try: + hash(value) + except TypeError: + if isinstance(value, dict): + return ( + "dict", + tuple( + sorted( + tuple( + ( + _simt_static_signature_atom(key), + _simt_static_signature_atom(item), + ) + for key, item in value.items() + ), + key=repr, + ) + ), + ) + if isinstance(value, (list, tuple)): + return ( + type(value).__name__, + tuple(_simt_static_signature_atom(item) for item in value), + ) + if isinstance(value, set): + return ( + "set", + tuple(sorted((_simt_static_signature_atom(item) for item in value), key=repr)), + ) + return (type(value).__name__,
repr(value)) + return value + + __all__ = [ "HelperFunctionSpec", "SubkernelTraceFrame", diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 192c846e2..154ba04f8 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -487,6 +487,20 @@ def simt_memory_atomic_probe( pto.threadfence_block() +@pto.simt +def simt_specialized_ptr_probe(ptr): + value = scalar.load(ptr) + _ = value + + +@pto.simt +def simt_specialized_flag_probe(*, FLAG): + if FLAG: + pto.get_tid_x() + else: + pto.get_tid_y() + + @pto.simt def simt_keep_stage(): pto.keep(pto.get_tid_x(), slot=0) @@ -545,6 +559,23 @@ def simt_full_surface_probe( pto.simt_launch(simt_resume_stage, gm, dims=(32, 1, 1)) +@pto.jit(target="a5") +def simt_specialized_arg_type_probe( + gm_i32: pto.ptr(pto.i32, "gm"), + gm_f32: pto.ptr(pto.f32, "gm"), + *, + TRACE_TOKEN: pto.constexpr = 0, +): + pto.simt_launch(simt_specialized_ptr_probe, gm_i32, dims=(32, 1, 1)) + pto.simt_launch(simt_specialized_ptr_probe, gm_f32, dims=(32, 1, 1)) + + +@pto.jit(target="a5") +def simt_specialized_static_kwarg_probe(*, TRACE_TOKEN: pto.constexpr = 0): + pto.simt_launch(simt_specialized_flag_probe, dims=(32, 1, 1), FLAG=False) + pto.simt_launch(simt_specialized_flag_probe, dims=(32, 1, 1), FLAG=True) + + @pto.jit(target="a5") def simt_invalid_redux_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): pto.simt_launch(simt_invalid_redux_signedness_probe, dims=(32, 1, 1)) @@ -2609,11 +2640,15 @@ def main() -> None: "each @pto.simt callsite should materialize a caller-side store_vfsimt_info", ) expect( - simt_text.count("call @simt_tid_probe()") == 2, + re.search(r"call @simt_tid_probe__simt_\d+\(\)", simt_text) is not None, "each @pto.simt callsite should lower to a func.call of the helper symbol", ) expect( - simt_text.count("func.func @simt_tid_probe() attributes {pto.simt_entry}") == 1, + len(re.findall(r"call @simt_tid_probe__simt_\d+\(\)", simt_text)) == 2, + "both @pto.simt 
callsites should call the same helper specialization", + ) + expect( + len(re.findall(r"func\.func @simt_tid_probe__simt_\d+\(\) attributes \{pto\.simt_entry\}", simt_text)) == 1, "@pto.simt helper should materialize exactly one reusable pto.simt_entry function", ) expect("pto.get_tid_x" in simt_text, "SIMT helper body should contain pto.get_tid_x") @@ -2623,11 +2658,11 @@ def main() -> None: simt_launch_text = simt_explicit_launch_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(simt_launch_text, "explicit simt launch specialization") expect( - "pto.simt_launch @simt_query_probe<<<" in simt_launch_text, + re.search(r"pto\.simt_launch @simt_query_probe__simt_\d+<<<", simt_launch_text) is not None, "pto.simt_launch(...) should emit VPTO simt_launch sugar", ) expect( - "func.func @simt_query_probe() attributes {pto.simt_entry}" in simt_launch_text, + re.search(r"func\.func @simt_query_probe__simt_\d+\(\) attributes \{pto\.simt_entry\}", simt_launch_text) is not None, "explicit pto.simt_launch should materialize a reusable pto.simt_entry helper", ) for op_name in ( @@ -2655,6 +2690,26 @@ def main() -> None: ): expect(op_name in simt_launch_text, f"SIMT query body should contain {op_name}") + simt_arg_type_text = simt_specialized_arg_type_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_arg_type_text, "simt arg-type specialization") + expect( + len(re.findall(r"func\.func @simt_specialized_ptr_probe__simt_\d+\(", simt_arg_type_text)) == 2, + "same @pto.simt body launched with different argument types should materialize two helpers", + ) + expect( + "!pto.ptr" in simt_arg_type_text and "!pto.ptr" in simt_arg_type_text, + "SIMT argument-type specializations should preserve distinct helper pointer types", + ) + + simt_static_kwarg_text = simt_specialized_static_kwarg_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_static_kwarg_text, "simt static kwarg specialization") + expect( + 
len(re.findall(r"func\.func @simt_specialized_flag_probe__simt_\d+\(", simt_static_kwarg_text)) == 2, + "same @pto.simt body launched with different static kwargs should materialize two helpers", + ) + expect("pto.get_tid_x" in simt_static_kwarg_text, "FLAG=True SIMT specialization should emit get_tid_x") + expect("pto.get_tid_y" in simt_static_kwarg_text, "FLAG=False SIMT specialization should emit get_tid_y") + simt_full_text = simt_full_surface_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(simt_full_text, "full simt surface specialization") for op_name in ( @@ -3228,7 +3283,7 @@ def main() -> None: simt_pointer_offset_text = simt_pointer_offset_probe.compile().mlir_text() expect_parse_roundtrip_and_verify(simt_pointer_offset_text, "simt pointer offset specialization") expect( - "call @simt_pointer_offset_helper" in simt_pointer_offset_text, + re.search(r"call @simt_pointer_offset_helper__simt_\d+", simt_pointer_offset_text) is not None, "@pto.simt pointer helper should lower to a helper func.call", ) expect( From b38f5fd773e2e4fc38511166c246cce5723b45e6 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:18:46 +0800 Subject: [PATCH 4/8] test(ptodsl): declare pipe peer fixture explicitly --- .../tests/support/docs_fragment_fixtures.py | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ptodsl/tests/support/docs_fragment_fixtures.py b/ptodsl/tests/support/docs_fragment_fixtures.py index e4a7edd07..92480a7b2 100644 --- a/ptodsl/tests/support/docs_fragment_fixtures.py +++ b/ptodsl/tests/support/docs_fragment_fixtures.py @@ -1978,29 +1978,51 @@ def pipe_communication_c2v_local_declaration_probe(): ), "pipe_communication.c2v_local_import": _fixture( f""" - @pto.simt - def vector_kernel(): - c2v_buf = pto.reserve_buffer("c2v_fifo", size=8192, location="vec") + from mlir.dialects import func + from mlir.ir import InsertionPoint + from 
ptodsl._tracing import current_session + + + def declare_vector_kernel_peer(): + session = current_session() + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(session._function_symbol_table): + peer = func.FuncOp("vector_kernel", fn_ty) + entry = peer.add_entry_block() + with session.enter_function(peer), InsertionPoint(entry): + pto.reserve_buffer("c2v_fifo", size=8192, location="vec") + func.ReturnOp([]) @pto.jit(target="a5") def pipe_communication_c2v_local_import_probe(): - vector_kernel() + declare_vector_kernel_peer() {SNIPPET_PLACEHOLDER} """ ), "pipe_communication.c2v_local_producer": _fixture( f""" - @pto.simt - def vector_kernel(): - c2v_buf = pto.reserve_buffer("c2v_fifo", size=8192, location="vec") + from mlir.dialects import func + from mlir.ir import InsertionPoint + from ptodsl._tracing import current_session + + + def declare_vector_kernel_peer(): + session = current_session() + fn_ty = func.FunctionType.get([], []) + with InsertionPoint(session._function_symbol_table): + peer = func.FuncOp("vector_kernel", fn_ty) + entry = peer.add_entry_block() + with session.enter_function(peer), InsertionPoint(entry): + pto.reserve_buffer("c2v_fifo", size=8192, location="vec") + func.ReturnOp([]) @pto.jit(target="a5") def pipe_communication_c2v_local_producer_probe( src: pto.gm_ptr(pto.f32), ): - vector_kernel() + declare_vector_kernel_peer() c2v_buf = pto.import_reserved_buffer("c2v_fifo", peer_func="vector_kernel") c2v_peer = pto.pipe.c2v( slot_size=1024, From ec88c062625b320484296af29244ca5b9be1d745 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:35:24 +0800 Subject: [PATCH 5/8] docs(ptodsl): update simt helper design --- .../ptodsl-simt-micro-op-api-design.md | 130 ++++++++++++------ 1 file changed, 89 insertions(+), 41 deletions(-) diff --git a/docs/designs/ptodsl-simt-micro-op-api-design.md b/docs/designs/ptodsl-simt-micro-op-api-design.md index eb74d97ba..727394e85 100644 --- 
a/docs/designs/ptodsl-simt-micro-op-api-design.md +++ b/docs/designs/ptodsl-simt-micro-op-api-design.md @@ -124,27 +124,32 @@ Expose direct wrappers for: the existing `scalar.sqrt/exp/log` helpers, which currently emit generic `math.*` operations. -## 5. Batch 1 Detailed Design +## 5. Implemented Launch and Helper Design ### 5.1 Goals -Batch 1 should make SIMT launch dimensions and all nullary SIMT runtime queries -authorable from PTO-DSL. +The implemented SIMT launch layer makes launch dimensions, SIMT helper +materialization, and SIMT runtime queries authorable from PTO-DSL. Later +micro-op batches build on the same helper-lowering path. -The implementation should: +The implementation: - keep micro-op names aligned with VPTO op names; - preserve the low-level `store_vfsimt_info(dim_z, dim_y, dim_x)` order; - add an ergonomic launch wrapper that uses the launch-site `x, y, z` order; -- preserve current `@pto.simt` helper behavior for existing code; +- preserve direct `@pto.simt` calls with default launch dimensions; +- specialize reusable SIMT helper functions by argument types and static + keyword arguments; - avoid backend changes. ### 5.2 Non-goals -Batch 1 should not implement lane collectives, atomics, GM scalar cache policy, -scalar math, conversion, keep/resume, or runtime/ST coverage. +The launch/helper layer should not implement operation-specific SIMT semantics +itself. Lane collectives, atomics, GM scalar cache policy, scalar math, +conversion, keep/resume, and validation rules are exposed as direct VPTO +wrappers in Batches 2-4. -Batch 1 should not change the semantics of `scalar.load/store`. +The launch/helper layer should not change the semantics of `scalar.load/store`. ### 5.3 Operation Mapping @@ -231,10 +236,49 @@ pto.simt_launch @write_tid<<<%dim_x, %dim_y, %dim_z>>>(%dst) : (!pto.ptr) -> () ``` -Batch 1 emits VPTO `pto.simt_launch` directly. The existing backend +PTO-DSL emits VPTO `pto.simt_launch` directly. 
The existing backend `vpto-expand-wrapper-ops` pass expands it to `pto.store_vfsimt_info + func.call`. -### 5.5 `@pto.simt` Decorator Attributes +### 5.5 Helper Specialization and Symbol Naming + +Each `@pto.simt` body is lowered through a generated `func.func` marked with +`pto.simt_entry`. The generated helper symbol is an implementation detail, not +the public subkernel name. PTO-DSL currently uses symbols of the form: + +```text +__simt_ +``` + +The helper specialization key includes: + +- the authored subkernel symbol name; +- positional argument MLIR types; +- static keyword argument values. + +This prevents two invalid reuse cases: + +- the same SIMT body launched with different pointer or scalar argument types; +- the same SIMT body launched with different static keyword arguments that + change the traced body. + +`pto.simt_launch(...)` must reference the actual generated helper symbol, not +the authored subkernel symbol. Direct `@pto.simt` calls also reuse the same +specialized helper path, with default launch dimensions `(1, 1, 1)`. + +Keyword arguments passed to `pto.simt_launch` are treated as static values and +must be hashable or structurally representable for the specialization key. +Runtime SSA values must be passed positionally so they become helper function +arguments. This avoids capturing values from the enclosing entry function into +the generated SIMT helper body. + +Because generated SIMT helper symbols are internal specialization names, other +APIs that require stable `func.func` symbols must not reference authored +`@pto.simt` helper names. In particular, `pto.import_reserved_buffer(peer_func=...)` +must refer to a real peer `func.func` containing the matching +`pto.reserve_buffer`, not to an authored SIMT helper whose generated symbol may +be specialized. + +### 5.6 `@pto.simt` Decorator Attributes SIMT entry functions may carry optional VPTO attributes: @@ -270,10 +314,9 @@ Validation: `pto.simt_entry`. 
This extension is useful for launch-envelope documentation and resource -control, but it is not required to expose query ops. It can be implemented in -the same batch or as a small follow-up. +control, but it is not currently part of the implemented surface. -### 5.6 Query API Behavior +### 5.7 Query API Behavior All query APIs are nullary wrappers and return a wrapped MLIR SSA value. @@ -289,7 +332,7 @@ already knows which operations are legal in `pto.simt_entry` when applicable. Adding a frontend context check can be considered later if it improves error messages without hiding backend semantics. -### 5.7 Type Handling for Launch Dimensions +### 5.8 Type Handling for Launch Dimensions Launch dimensions are VPTO `i32` operands. PTO-DSL should accept: @@ -308,7 +351,7 @@ Proposed normalization rule: The implementation should not silently accept `i64` or arbitrary integers by truncation. -### 5.8 Interaction With Existing `@pto.simt` Calls +### 5.9 Interaction With Existing `@pto.simt` Calls Current code can call a SIMT subkernel directly: @@ -332,40 +375,39 @@ This method is not required for Batch 1. If added, it should call the same lowering path as `pto.simt_launch(...)` and should not create a second semantic route. -### 5.9 Implementation Sketch +### 5.10 Implementation Notes -Frontend files likely touched: +Frontend files touched by the implemented surface: - `ptodsl/ptodsl/_ops.py` - - add nullary query wrappers; - - add `_coerce_i32_dim(...)` helper if existing helpers are not sufficient; - - add `simt_launch(...)` wrapper or delegate to tracing runtime. + - SIMT query, launch, collective, memory, atomic, math, convert, sync, and + state wrappers; + - enum/cache/rounding/saturation normalization for SIMT attrs. - `ptodsl/ptodsl/pto.py` - - export new wrappers. -- `ptodsl/ptodsl/_subkernels.py` - - optionally extend `simt(..., max_threads=None, max_regs=None)`. + - exported SIMT wrappers. 
- `ptodsl/ptodsl/_tracing/session.py` - - add a reusable lowering method for explicit SIMT launches; - - optionally attach `pto.simt_max_threads` and `pto.simt_max_regs` attrs when - creating helper functions. + - reusable helper lowering for direct SIMT calls and explicit + `pto.simt_launch`; + - SIMT helper specialization by argument types and static kwargs; + - actual helper-symbol targeting for `pto.simt_launch`. - `ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md` - - document explicit `pto.simt_launch(...)` and optional decorator attrs. -- `ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md` or a new SIMT section - - document query ops if we want user-guide coverage in the same PR. + - documented the SIMT API surface. - `ptodsl/tests/support/docs_fragment_fixtures.py` - - update only if new docs snippets are executable docs-as-tests. + - declares stable peer functions for docs snippets that use + `pto.import_reserved_buffer(peer_func=...)`, instead of relying on + generated SIMT helper symbols. - `ptodsl/tests/test_jit_compile.py` - - add compile smoke tests for query wrappers and explicit launch dims. + - compile smoke tests for launch/query wrappers, full SIMT micro-op surface, + invalid frontend argument combinations, and SIMT helper specialization. -Backend files should not be touched for Batch 1 unless frontend-generated IR is -valid but rejected by existing VPTO code. +Backend files are not touched by this PTO-DSL frontend surface. -### 5.10 Test Plan +### 5.11 Test Plan Minimum Python/frontend tests: 1. Existing direct `@pto.simt` call still emits `pto.store_vfsimt_info` and a - single reusable `pto.simt_entry` function. + reusable `pto.simt_entry` helper specialization. 2. `pto.simt_launch(body, dst, dims=(32, 1, 1))` emits either: - `pto.simt_launch @body<<<...>>>`, or - an equivalent `pto.store_vfsimt_info` with dimensions reordered to @@ -375,6 +417,12 @@ Minimum Python/frontend tests: return `i32`. 5. 
Invalid launch dimensions raise Python errors before backend verification when the type is clearly unsupported. +6. The same `@pto.simt` body launched with different argument types produces + distinct helper functions. +7. The same `@pto.simt` body launched with different static keyword arguments + produces distinct helper functions and distinct traced bodies. +8. `pto.simt_launch` callee attributes reference the actual generated helper + symbols. Suggested lit/frontend assertions: @@ -393,12 +441,12 @@ Suggested lit/frontend assertions: Runtime/ST validation is not required for the first frontend API PR unless a later implementation changes runtime behavior. -### 5.11 Open Questions +### 5.12 Open Questions 1. Should `pto.simt_launch(...)` directly emit VPTO `SimtLaunchOp`, or should it lower immediately to `store_vfsimt_info + func.call` in PTO-DSL tracing? - Batch 1 uses direct `SimtLaunchOp` emission. This matches the ISA and keeps + PTO-DSL uses direct `SimtLaunchOp` emission. This matches the ISA and keeps the frontend surface one-to-one with VPTO. Expansion remains owned by the existing backend wrapper-expansion pass. @@ -406,20 +454,20 @@ later implementation changes runtime behavior. should they accept launch dims later through a method such as `body.launch(..., dims=(...))`? - Batch 1 preserves current direct-call behavior. A method can be added later + PTO-DSL preserves current direct-call behavior. A method can be added later as pure sugar over `pto.simt_launch(...)`. 3. Should PTO-DSL enforce "query ops only inside `pto.simt_entry`" at Python tracing time? - Batch 1 relies on backend verification. A frontend context check may improve + PTO-DSL relies on backend verification. A frontend context check may improve diagnostics later, but it should not invent semantics different from VPTO. 4. Should `@pto.simt(max_threads=..., max_regs=...)` be included in Batch 1? 
These attributes are part of the SIMT entry contract and are cheap to expose, - but they are not necessary for query wrappers. Batch 1 leaves them for a - follow-up. + but they are not necessary for the current SIMT micro-op API surface. They + remain a follow-up. ## 6. Backend Change Guardrail From 6a06a8b7d005b10f45a3e743bcc640764e5751d7 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 11 Jun 2026 16:00:59 +0800 Subject: [PATCH 6/8] feat(ptodsl): document and expose simt entry resources --- .../ptodsl-simt-micro-op-api-design.md | 71 ++----- .../03-kernel-entry-and-subkernels.md | 175 +++++++++++++++++- ptodsl/ptodsl/_subkernels.py | 77 +++++++- ptodsl/ptodsl/_tracing/session.py | 20 +- ptodsl/tests/test_jit_compile.py | 39 ++++ 5 files changed, 315 insertions(+), 67 deletions(-) diff --git a/docs/designs/ptodsl-simt-micro-op-api-design.md b/docs/designs/ptodsl-simt-micro-op-api-design.md index 727394e85..466eedfba 100644 --- a/docs/designs/ptodsl-simt-micro-op-api-design.md +++ b/docs/designs/ptodsl-simt-micro-op-api-design.md @@ -10,12 +10,11 @@ The design is intentionally frontend-first: - expose Python PTO-DSL wrappers for existing VPTO SIMT operations; - keep wrapper names and parameters close to VPTO IR; - avoid backend changes unless the frontend generates valid IR that the - backend incorrectly rejects; -- document open questions before changing lowering, verifiers, or backend - passes. + backend incorrectly rejects. -The first implementation batch focuses on SIMT launch and query operations. -Later batches are listed for context so the API direction stays consistent. +The implementation was staged in batches so the API direction stays consistent +across launch helpers, query ops, lane collectives, scalar memory, atomics, +math, conversion, sync, and state preservation. ## 2. References @@ -35,6 +34,8 @@ Later batches are listed for context so the API direction stays consistent. 
Current PTO-DSL already has a narrow SIMT surface: - `@pto.simt` decorator and `with pto.simt():` inline scope. +- `@pto.simt(max_threads=..., max_regs=...)` optional entry resource + attributes. - `pto.store_vfsimt_info(dim_z, dim_y, dim_x)`. - `pto.get_tid_x()`, `pto.get_tid_y()`, `pto.get_tid_z()`. - `scalar.load(...)` and `scalar.store(...)` for plain scalar element access. @@ -140,6 +141,7 @@ The implementation: - preserve direct `@pto.simt` calls with default launch dimensions; - specialize reusable SIMT helper functions by argument types and static keyword arguments; +- expose optional SIMT entry resource attributes on generated helper functions; - avoid backend changes. ### 5.2 Non-goals @@ -285,7 +287,7 @@ SIMT entry functions may carry optional VPTO attributes: - `pto.simt_max_threads` - `pto.simt_max_regs` -Proposed PTO-DSL decorator extension: +PTO-DSL exposes them through `@pto.simt`: ```python @pto.simt(max_threads=256, max_regs=48) @@ -303,8 +305,9 @@ func.func @body(...) attributes { } ``` -Both decorator arguments should be optional. When omitted, PTO-DSL should emit -no explicit attributes and let backend defaults apply. +Lowering attaches these attributes to the generated specialized helper function, +not to the authored Python symbol. Omitting either argument emits no explicit +attribute and lets backend defaults apply. Validation: @@ -312,9 +315,12 @@ Validation: - values must be positive; - these attributes must only be attached to functions that are already marked `pto.simt_entry`. +- inline `with pto.simt():` scopes do not generate `pto.simt_entry` helper + functions, so they do not accept these attributes. -This extension is useful for launch-envelope documentation and resource -control, but it is not currently part of the implemented surface. +These attributes are part of the implemented SIMT entry surface. They only +describe the resource envelope; the actual workitem count still comes from +`pto.store_vfsimt_info` or `pto.simt_launch`. 
### 5.7 Query API Behavior @@ -389,6 +395,7 @@ Frontend files touched by the implemented surface: - reusable helper lowering for direct SIMT calls and explicit `pto.simt_launch`; - SIMT helper specialization by argument types and static kwargs; + - SIMT entry resource attribute emission; - actual helper-symbol targeting for `pto.simt_launch`. - `ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md` - documented the SIMT API surface. @@ -423,6 +430,8 @@ Minimum Python/frontend tests: produces distinct helper functions and distinct traced bodies. 8. `pto.simt_launch` callee attributes reference the actual generated helper symbols. +9. `@pto.simt(max_threads=..., max_regs=...)` emits `pto.simt_max_threads` and + `pto.simt_max_regs` on the generated helper function. Suggested lit/frontend assertions: @@ -440,45 +449,3 @@ Suggested lit/frontend assertions: Runtime/ST validation is not required for the first frontend API PR unless a later implementation changes runtime behavior. - -### 5.12 Open Questions - -1. Should `pto.simt_launch(...)` directly emit VPTO `SimtLaunchOp`, or should - it lower immediately to `store_vfsimt_info + func.call` in PTO-DSL tracing? - - PTO-DSL uses direct `SimtLaunchOp` emission. This matches the ISA and keeps - the frontend surface one-to-one with VPTO. Expansion remains owned by the - existing backend wrapper-expansion pass. - -2. Should direct `@pto.simt` calls remain fixed at `(1, 1, 1)` forever, or - should they accept launch dims later through a method such as - `body.launch(..., dims=(...))`? - - PTO-DSL preserves current direct-call behavior. A method can be added later - as pure sugar over `pto.simt_launch(...)`. - -3. Should PTO-DSL enforce "query ops only inside `pto.simt_entry`" at Python - tracing time? - - PTO-DSL relies on backend verification. A frontend context check may improve - diagnostics later, but it should not invent semantics different from VPTO. - -4. 
Should `@pto.simt(max_threads=..., max_regs=...)` be included in Batch 1? - - These attributes are part of the SIMT entry contract and are cheap to expose, - but they are not necessary for the current SIMT micro-op API surface. They - remain a follow-up. - -## 6. Backend Change Guardrail - -Before changing `include/PTO/IR/*`, `lib/PTO/IR/*`, or -`lib/PTO/Transforms/*` for this work, answer: - -- Is PTO-DSL generating IR that matches `docs/isa/micro-isa/17-simt.md` and - `include/PTO/IR/VPTOOps.td`? -- Does the existing backend reject that valid IR? -- Did existing VPTO lit tests already cover the intended backend behavior? -- Can the issue be fixed by wrapper normalization, tracing, docs, or tests? -- If a backend change is still needed, can it be covered by a narrow lit test? - -The default answer for Batch 1 should be no backend changes. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 42e3b26c4..8cdf013c1 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -539,7 +539,7 @@ instruction appears to operate on a single element (`lds`, `sts`, `a + b`), but the same instruction is issued across a large number of work-items simultaneously. -**Signature**: `@pto.simt(fn=None, *, name=None, target="a5")` +**Signature**: `@pto.simt(fn=None, *, name=None, target="a5", max_threads=None, max_regs=None)` ```python @@ -573,9 +573,40 @@ def blend_output_rows( scalar.store(o_next, o_next_tile[row, col]) ``` -SIMT kernels read and write individual scalar elements from tiles. The unit -executes the same scalar instruction across many work-items in parallel, making -it efficient for per-element operations. +SIMT kernels read and write individual scalar elements from tiles or typed +pointers. 
The unit executes the same scalar instruction across many work-items +in parallel, making it efficient for per-element operations. + +#### SIMT resource attributes + +Optional `max_threads` and `max_regs` arguments attach VPTO resource attributes +to the generated `pto.simt_entry` helper. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `max_threads` | positive Python `int` | backend default `1024` | Compile-time launch envelope for this SIMT helper | +| `max_regs` | positive Python `int` | backend default `32` | Scalar register budget per work-item | + +`max_threads` is not the launch size. The actual work-item count comes from +`pto.simt_launch(..., dims=(dim_x, dim_y, dim_z))`; `max_threads` should cover +the largest `dim_x * dim_y * dim_z` used for that helper. Both arguments must +be Python integers known at trace time, must be greater than zero, and must fit +in signless `i32`. `bool` values are rejected. These arguments are only valid +on decorated SIMT helper functions, not inline `with pto.simt():` scopes. + + +```python +@pto.simt(max_threads=256, max_regs=48) +def write_tid(dst: pto.ptr(pto.i32, "gm")): + tid = pto.get_tid_x() + idx = scalar.index_cast(tid) + pto.stg(tid, dst, idx) + + +@pto.jit(target="a5") +def kernel_entry_simt_resource_probe(dst: pto.ptr(pto.i32, "gm")): + pto.simt_launch(write_tid, dst, dims=(128, 1, 1)) +``` **Invocation modes**: can be called from `@pto.jit` in either mode, or used inline with `with pto.simt():` (Section 3.4). @@ -601,6 +632,16 @@ level `pto.store_vfsimt_info(dim_z, dim_y, dim_x)` wrapper is also available for direct VPTO authoring, but its operand order follows the backend launch descriptor order. 
+ +```python +@pto.jit(target="a5") +def kernel_entry_simt_store_info_probe(): + dim_z = pto.const(1, dtype=pto.i32) + dim_y = pto.const(1, dtype=pto.i32) + dim_x = pto.const(32, dtype=pto.i32) + pto.store_vfsimt_info(dim_z, dim_y, dim_x) +``` + #### SIMT query ops SIMT query ops are nullary micro-op wrappers. They return PTO scalar values @@ -618,6 +659,39 @@ visible to the current SIMT work-item. | `pto.get_laneid()` | `i32` | Physical SIMT lane id | | `pto.get_lanemask_eq()` / `pto.get_lanemask_le()` / `pto.get_lanemask_lt()` / `pto.get_lanemask_ge()` / `pto.get_lanemask_gt()` | `i32` | Lane masks derived from the current lane id | + +```python +@pto.simt +def capture_query_state(dst: pto.ptr(pto.i32, "gm")): + tid_x = pto.get_tid_x() + pto.get_tid_y() + pto.get_tid_z() + pto.get_block_dim_x() + pto.get_block_dim_y() + pto.get_block_dim_z() + pto.get_grid_dim_x() + pto.get_grid_dim_y() + pto.get_grid_dim_z() + pto.get_block_idx_x() + pto.get_block_idx_y() + pto.get_block_idx_z() + pto.get_veccoreid() + pto.get_clock32() + pto.get_clock64() + lane = pto.get_laneid() + pto.get_lanemask_eq() + pto.get_lanemask_le() + pto.get_lanemask_lt() + pto.get_lanemask_ge() + pto.get_lanemask_gt() + pto.stg(tid_x, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def kernel_entry_simt_query_probe(dst: pto.ptr(pto.i32, "gm")): + pto.simt_launch(capture_query_state, dst, dims=(32, 1, 1)) +``` + #### SIMT lane collective ops These wrappers map directly to VPTO SIMT lane collective micro-ops. @@ -643,6 +717,30 @@ pto.redux_min(value, *, signedness=None) require `signedness="signed"` or `signedness="unsigned"`; floating-point redux does not accept signedness. 
+ +```python +@pto.simt +def reduce_lane_value(dst: pto.ptr(pto.i32, "gm")): + pred = pto.const(1, dtype=pto.i1) + lane = pto.get_laneid() + + pto.vote_all(pred) + pto.vote_any(pred) + pto.vote_uni(pred) + pto.vote_ballot(pred) + + value = pto.shuffle_bfly(lane, 1, width=32) + total = pto.redux_add(value, signedness="signed") + maximum = pto.redux_max(total, signedness="signed") + minimum = pto.redux_min(maximum, signedness="signed") + pto.stg(minimum, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def kernel_entry_simt_collective_probe(dst: pto.ptr(pto.i32, "gm")): + pto.simt_launch(reduce_lane_value, dst, dims=(32, 1, 1)) +``` + #### SIMT scalar GM memory and atomic ops ```python @@ -664,12 +762,32 @@ pto.atomic_cas(ptr, compare, value, *, l2cache="nmfv", signedness=None) clauses. Plain scalar memory remains available through `scalar.load(...)` and `scalar.store(...)`. -`l1cache` accepts `"cache"` or `"uncache"`. Load `l2cache` accepts the VPTO -load L2 cache tokens; store and atomic `l2cache` accept the VPTO store/atomic -L2 cache tokens. Atomic pointers must point to GM or UB scalar storage accepted -by the VPTO verifier. Integer atomics may pass `signedness`; floating-point and +`l1cache` accepts `"cache"` or `"uncache"`. Load `l2cache` accepts `"nmfv"`, +`"nmlv"`, `"nmprs"`, `"nmpref"`, `"nakeep"`, `"naclean"`, `"nadrop"`, +`"idsfv"`, `"idslv"`, `"idsprs"`, `"idspref"`, `"exfv"`, `"exlv"`, `"exprs"`, +or `"expref"`. Store and atomic `l2cache` accepts `"nmfv"`, `"nmlv"`, +`"nmprs"`, `"nmred"`, `"naci"`, `"napw"`, `"napi"`, `"nared"`, `"wbhfv"`, +`"wbhlv"`, `"wbhprs"`, `"wbhred"`, `"wtsfv"`, `"wtslv"`, `"wtsprs"`, or +`"wtsred"`. Atomic pointers must point to GM or UB scalar storage accepted by +the VPTO verifier. Integer atomics may pass `signedness`; floating-point and packed atomics must omit it. 
+ +```python +@pto.simt +def update_counter(counter: pto.ptr(pto.i32, "gm")): + tid = pto.get_tid_x() + idx = scalar.index_cast(tid) + value = pto.ldg(counter, idx, l1cache="cache", l2cache="nmfv") + old = pto.atomic_add(counter, value, l2cache="nmfv", signedness="signed") + pto.stg(old, counter, idx, l1cache="uncache", l2cache="wtsred") + + +@pto.jit(target="a5") +def kernel_entry_simt_memory_atomic_probe(counter: pto.ptr(pto.i32, "gm")): + pto.simt_launch(update_counter, counter, dims=(32, 1, 1)) +``` + #### SIMT scalar math, conversion, sync, and state ops ```python @@ -713,6 +831,47 @@ conversion is not supported by `pto.convert`. `pto.keep` and `pto.resume` use explicit non-negative Python integer slots. Keep/resume placement constraints are enforced by the VPTO verifier. + +```python +@pto.simt +def save_lane_state(): + pto.keep(pto.get_tid_x(), slot=0) + + +@pto.simt +def transform_lane_state(dst: pto.ptr(pto.f32, "gm")): + lane = pto.resume(pto.i32, slot=0) + permuted = pto.prmt(lane, lane, lane) + high = pto.mulhi(permuted, lane, signedness="unsigned") + product = pto.mul_i32toi64(lane, lane, signedness="unsigned") + _ = high + _ = product + + value = pto.convert( + lane, + pto.f32, + rounding="r", + saturation="nosat", + signedness="unsigned", + ) + root = pto.sqrt(pto.absf(value)) + powered = pto.pow(root, root) + rounded = pto.round(pto.rint(pto.floor(pto.ceil(powered)))) + bounded = pto.fmin(pto.fmax(value, root), rounded) + accum = pto.fma(bounded, pto.exp(value), pto.log(pto.fmax(value, root))) + + pto.syncthreads() + pto.threadfence() + pto.threadfence_block() + pto.stg(accum, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def kernel_entry_simt_math_state_probe(dst: pto.ptr(pto.f32, "gm")): + pto.simt_launch(save_lane_state, dims=(32, 1, 1)) + pto.simt_launch(transform_lane_state, dst, dims=(32, 1, 1)) +``` + ## 3.4 Inline context manager syntax In addition to the decorator form, each sub-kernel unit provides a context diff --git 
a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py index 58fd804dc..1b89547e9 100644 --- a/ptodsl/ptodsl/_subkernels.py +++ b/ptodsl/ptodsl/_subkernels.py @@ -40,6 +40,8 @@ class SubkernelSpec: role: KernelRole symbol_name: str target: str = "a5" + simt_max_threads: int | None = None + simt_max_regs: int | None = None class SubkernelTemplate: @@ -140,11 +142,15 @@ def __init__( name: str | None = None, target: str = "a5", ast_rewrite: bool = True, + simt_max_threads: int | None = None, + simt_max_regs: int | None = None, ): self._role = role self._name = name self._target = target self._ast_rewrite = ast_rewrite + self._simt_max_threads = simt_max_threads + self._simt_max_regs = simt_max_regs self._session_cm = None def __call__(self, fn): @@ -153,12 +159,18 @@ def __call__(self, fn): role=self._role, symbol_name=self._name or fn.__name__, target=self._target, + simt_max_threads=self._simt_max_threads, + simt_max_regs=self._simt_max_regs, ), fn, ast_rewrite=self._ast_rewrite, ) def __enter__(self): + if self._role == KernelRole.SIMT and ( + self._simt_max_threads is not None or self._simt_max_regs is not None + ): + raise TypeError("@pto.simt(max_threads=..., max_regs=...) 
is only supported as a function decorator") runtime = current_runtime() if runtime is None: raise RuntimeError( @@ -190,8 +202,17 @@ def _subkernel_decorator( name: str | None = None, target: str = "a5", ast_rewrite: bool = True, + simt_max_threads: int | None = None, + simt_max_regs: int | None = None, ): - return _SubkernelSurface(role, name=name, target=target, ast_rewrite=ast_rewrite) + return _SubkernelSurface( + role, + name=name, + target=target, + ast_rewrite=ast_rewrite, + simt_max_threads=simt_max_threads, + simt_max_regs=simt_max_regs, + ) def _decorate_subkernel( @@ -201,10 +222,26 @@ def _decorate_subkernel( name: str | None = None, target: str = "a5", ast_rewrite: bool = True, + simt_max_threads: int | None = None, + simt_max_regs: int | None = None, ): if fn is not None: - return _subkernel_decorator(role, name=name, target=target, ast_rewrite=ast_rewrite)(fn) - return _subkernel_decorator(role, name=name, target=target, ast_rewrite=ast_rewrite) + return _subkernel_decorator( + role, + name=name, + target=target, + ast_rewrite=ast_rewrite, + simt_max_threads=simt_max_threads, + simt_max_regs=simt_max_regs, + )(fn) + return _subkernel_decorator( + role, + name=name, + target=target, + ast_rewrite=ast_rewrite, + simt_max_threads=simt_max_threads, + simt_max_regs=simt_max_regs, + ) def cube(fn=None, *, name: str | None = None, target: str = "a5", ast_rewrite: bool = True): @@ -215,8 +252,38 @@ def simd(fn=None, *, name: str | None = None, target: str = "a5", ast_rewrite: b return _decorate_subkernel(KernelRole.SIMD, fn, name=name, target=target, ast_rewrite=ast_rewrite) -def simt(fn=None, *, name: str | None = None, target: str = "a5", ast_rewrite: bool = True): - return _decorate_subkernel(KernelRole.SIMT, fn, name=name, target=target, ast_rewrite=ast_rewrite) +def _validate_simt_resource_attr(name: str, value: int | None) -> int | None: + if value is None: + return None + if isinstance(value, bool) or not isinstance(value, int): + raise 
TypeError(f"@pto.simt(..., {name}=...) expects a positive Python int") + if value <= 0: + raise ValueError(f"@pto.simt(..., {name}=...) expects a positive Python int") + if value > 2**31 - 1: + raise ValueError(f"@pto.simt(..., {name}=...) must fit in signless i32") + return value + + +def simt( + fn=None, + *, + name: str | None = None, + target: str = "a5", + ast_rewrite: bool = True, + max_threads: int | None = None, + max_regs: int | None = None, +): + max_threads = _validate_simt_resource_attr("max_threads", max_threads) + max_regs = _validate_simt_resource_attr("max_regs", max_regs) + return _decorate_subkernel( + KernelRole.SIMT, + fn, + name=name, + target=target, + ast_rewrite=ast_rewrite, + simt_max_threads=max_threads, + simt_max_regs=max_regs, + ) __all__ = [ diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index 33ca5a37f..f8344af24 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -22,7 +22,7 @@ from mlir.dialects import arith, func from mlir.dialects import pto as _pto -from mlir.ir import FlatSymbolRefAttr, IndexType, InsertionPoint, IntegerType, Operation, UnitAttr +from mlir.ir import FlatSymbolRefAttr, IndexType, InsertionPoint, IntegerAttr, IntegerType, Operation, UnitAttr @dataclass(frozen=True) @@ -207,10 +207,26 @@ def _get_or_create_simt_helper_function(self, subkernel, *args, **kwargs): return helper_fn, arg_templates helper_symbol = self._next_simt_helper_symbol(subkernel.spec.symbol_name) + helper_attributes = [("pto.simt_entry", UnitAttr.get())] + i32_attr_type = IntegerType.get_signless(32) + if subkernel.spec.simt_max_threads is not None: + helper_attributes.append( + ( + "pto.simt_max_threads", + IntegerAttr.get(i32_attr_type, subkernel.spec.simt_max_threads), + ) + ) + if subkernel.spec.simt_max_regs is not None: + helper_attributes.append( + ( + "pto.simt_max_regs", + IntegerAttr.get(i32_attr_type, subkernel.spec.simt_max_regs), + ) + ) helper_spec = 
HelperFunctionSpec( symbol_name=helper_symbol, arg_types=arg_types, - attributes=(("pto.simt_entry", UnitAttr.get()),), + attributes=tuple(helper_attributes), ) helper_fn, created = self.get_or_create_helper_function(helper_spec) self._simt_helper_specializations[specialization_key] = helper_fn diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 154ba04f8..ff53a35c7 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -425,6 +425,11 @@ def simt_query_probe(): pto.get_lanemask_gt() +@pto.simt(max_threads=256, max_regs=48) +def simt_resource_attr_probe(): + pto.get_tid_x() + + @pto.simt def simt_collective_math_probe(): lane = pto.get_laneid() @@ -547,6 +552,11 @@ def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): pto.simt_launch(simt_query_probe, dims=(32, 2, 1)) +@pto.jit(target="a5") +def simt_resource_attr_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): + pto.simt_launch(simt_resource_attr_probe, dims=(128, 1, 1)) + + @pto.jit(target="a5") def simt_full_surface_probe( gm: pto.ptr(pto.i32, "gm"), @@ -2665,6 +2675,35 @@ def main() -> None: re.search(r"func\.func @simt_query_probe__simt_\d+\(\) attributes \{pto\.simt_entry\}", simt_launch_text) is not None, "explicit pto.simt_launch should materialize a reusable pto.simt_entry helper", ) + simt_resource_attr_text = simt_resource_attr_launch_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_resource_attr_text, "simt resource attr launch specialization") + expect( + re.search( + r"func\.func @simt_resource_attr_probe__simt_\d+\(\) attributes \{pto\.simt_entry, pto\.simt_max_regs = 48 : i32, pto\.simt_max_threads = 256 : i32\}", + simt_resource_attr_text, + ) is not None, + "@pto.simt(max_threads=..., max_regs=...) 
should attach resource attrs to the helper function", + ) + expect_raises( + ValueError, + lambda: pto.simt(max_threads=0)(lambda: None), + "max_threads", + ) + expect_raises( + TypeError, + lambda: pto.simt(max_regs=True)(lambda: None), + "max_regs", + ) + + def _enter_inline_simt_with_resource_attr(): + with pto.simt(max_threads=256): + pass + + expect_raises( + TypeError, + _enter_inline_simt_with_resource_attr, + "function decorator", + ) for op_name in ( "pto.get_tid_x", "pto.get_tid_y", From 6dfd047c366cddc4e2d459c004ba45e82bc8dfa5 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:11:15 +0800 Subject: [PATCH 7/8] fix(ptodsl): resolve simt reserved-buffer peer symbols --- .../ptodsl-simt-micro-op-api-design.md | 18 ++++---- .../03-kernel-entry-and-subkernels.md | 11 +++++ ptodsl/ptodsl/_ops.py | 14 ++++-- ptodsl/ptodsl/_tracing/session.py | 21 +++++++++ ptodsl/tests/test_jit_compile.py | 45 +++++++++++++++++++ 5 files changed, 98 insertions(+), 11 deletions(-) diff --git a/docs/designs/ptodsl-simt-micro-op-api-design.md b/docs/designs/ptodsl-simt-micro-op-api-design.md index 466eedfba..ba0ef216c 100644 --- a/docs/designs/ptodsl-simt-micro-op-api-design.md +++ b/docs/designs/ptodsl-simt-micro-op-api-design.md @@ -273,12 +273,14 @@ Runtime SSA values must be passed positionally so they become helper function arguments. This avoids capturing values from the enclosing entry function into the generated SIMT helper body. -Because generated SIMT helper symbols are internal specialization names, other -APIs that require stable `func.func` symbols must not reference authored -`@pto.simt` helper names. In particular, `pto.import_reserved_buffer(peer_func=...)` -must refer to a real peer `func.func` containing the matching -`pto.reserve_buffer`, not to an authored SIMT helper whose generated symbol may -be specialized. 
+Because generated SIMT helper symbols are internal specialization names, APIs +that require stable `func.func` symbols must resolve authored `@pto.simt` +helpers to a materialized helper symbol before emitting IR. In particular, +`pto.import_reserved_buffer(peer_func=simt_helper)` resolves to the generated +helper symbol when the helper has exactly one materialized specialization. If +the helper has not been called/launched yet, or if multiple specializations +exist, PTO-DSL raises a frontend error and requires the caller to pass an +explicit peer function symbol. ### 5.6 `@pto.simt` Decorator Attributes @@ -401,8 +403,8 @@ Frontend files touched by the implemented surface: - documented the SIMT API surface. - `ptodsl/tests/support/docs_fragment_fixtures.py` - declares stable peer functions for docs snippets that use - `pto.import_reserved_buffer(peer_func=...)`, instead of relying on - generated SIMT helper symbols. + `pto.import_reserved_buffer(peer_func=...)` where the snippet needs a + fixed peer symbol independent of SIMT helper specialization. - `ptodsl/tests/test_jit_compile.py` - compile smoke tests for launch/query wrappers, full SIMT micro-op surface, invalid frontend argument combinations, and SIMT helper specialization. 
diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 8cdf013c1..e8086d1c4 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -729,6 +729,9 @@ def reduce_lane_value(dst: pto.ptr(pto.i32, "gm")): pto.vote_uni(pred) pto.vote_ballot(pred) + pto.shuffle_idx(lane, lane, width=32) + pto.shuffle_up(lane, 1, width=32) + pto.shuffle_down(lane, 1, width=32) value = pto.shuffle_bfly(lane, 1, width=32) total = pto.redux_add(value, signedness="signed") maximum = pto.redux_max(total, signedness="signed") @@ -780,6 +783,14 @@ def update_counter(counter: pto.ptr(pto.i32, "gm")): idx = scalar.index_cast(tid) value = pto.ldg(counter, idx, l1cache="cache", l2cache="nmfv") old = pto.atomic_add(counter, value, l2cache="nmfv", signedness="signed") + pto.atomic_exch(counter, value, signedness="signed") + pto.atomic_sub(counter, value, signedness="signed") + pto.atomic_min(counter, value, signedness="signed") + pto.atomic_max(counter, value, signedness="signed") + pto.atomic_and(counter, value, signedness="unsigned") + pto.atomic_or(counter, value, signedness="unsigned") + pto.atomic_xor(counter, value, signedness="unsigned") + pto.atomic_cas(counter, old, value, signedness="signed") pto.stg(old, counter, idx, l1cache="uncache", l2cache="wtsred") diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index 90ae0e76d..b8bcf0466 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -5078,9 +5078,17 @@ def reserve_buffer(name, *, size, location, auto=True, base=None): def import_reserved_buffer(name, *, peer_func): """``pto.import_reserved_buffer(name, peer_func=...)``.""" if not isinstance(peer_func, str): - peer_func = getattr(getattr(peer_func, "spec", None), "symbol_name", None) \ - or getattr(peer_func, "__name__", None) \ - or str(peer_func) + spec = getattr(peer_func, "spec", None) + role = getattr(spec, 
"role", None) + role_value = getattr(role, "value", role) + if role_value == "simt": + from ._tracing.active import require_active_session + session = require_active_session("pto.import_reserved_buffer") + peer_func = session.resolve_simt_peer_symbol(peer_func) + else: + peer_func = getattr(spec, "symbol_name", None) \ + or getattr(peer_func, "__name__", None) \ + or str(peer_func) op = _pto.ImportReservedBufferOp(name, peer_func) return wrap_surface_value(op.result) diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py index f8344af24..2eacb7714 100644 --- a/ptodsl/ptodsl/_tracing/session.py +++ b/ptodsl/ptodsl/_tracing/session.py @@ -252,6 +252,27 @@ def _next_simt_helper_symbol(self, base_symbol: str) -> str: self._simt_helper_symbol_counters[base_symbol] = index return symbol + def resolve_simt_peer_symbol(self, subkernel) -> str: + """Return the unique materialized helper symbol for a ``@pto.simt`` peer.""" + symbol_name = subkernel.spec.symbol_name + matches = [ + helper_fn + for key, helper_fn in self._simt_helper_specializations.items() + if key.symbol_name == symbol_name + ] + if not matches: + raise RuntimeError( + f"pto.import_reserved_buffer(..., peer_func={symbol_name}) cannot resolve " + "the @pto.simt helper symbol before the helper is called or launched" + ) + if len(matches) > 1: + raise RuntimeError( + f"pto.import_reserved_buffer(..., peer_func={symbol_name}) is ambiguous " + "because the @pto.simt helper has multiple specializations; pass the " + "materialized peer function symbol explicitly" + ) + return _symbol_name(matches[0]) + def lookup_helper(self, symbol_name: str): """Return a previously declared helper function, or ``None``.""" return self._helpers.get(symbol_name) diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index ff53a35c7..6c92ab052 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -1232,6 +1232,17 @@ def 
simt_pointer_offset_helper(meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB)): scalar.store(9, meta_ptr + 1) +@pto.simt +def simt_reserved_buffer_peer(): + pto.reserve_buffer("simt_c2v_fifo", size=8192, location="vec") + + +@pto.simt +def simt_reserved_buffer_ambiguous_peer(ptr): + _ = ptr + pto.reserve_buffer("simt_c2v_fifo", size=8192, location="vec") + + @pto.jit(target="a5") def simt_pointer_offset_probe(): meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 2]) @@ -1242,6 +1253,23 @@ def simt_pointer_offset_probe(): _ = second +@pto.jit(target="a5") +def simt_reserved_buffer_peer_probe(): + simt_reserved_buffer_peer() + imported = pto.import_reserved_buffer("simt_c2v_fifo", peer_func=simt_reserved_buffer_peer) + _ = imported + + +@pto.jit(target="a5") +def simt_reserved_buffer_ambiguous_peer_probe( + gm_i32: pto.ptr(pto.i32, "gm"), + gm_f32: pto.ptr(pto.f32, "gm"), +): + simt_reserved_buffer_ambiguous_peer(gm_i32) + simt_reserved_buffer_ambiguous_peer(gm_f32) + pto.import_reserved_buffer("simt_c2v_fifo", peer_func=simt_reserved_buffer_ambiguous_peer) + + @pto.jit(target="a5") def scalar_store_element_coercion_probe(): meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 4]) @@ -3333,6 +3361,23 @@ def _enter_inline_simt_with_resource_attr(): re.search(r"pto\.load %\d+\[%c1(?:_\d+)?\]", simt_pointer_offset_text) is not None, "@pto.simt pointer helper probe should preserve ptr+offset load syntax on the caller side", ) + simt_reserved_buffer_peer_text = simt_reserved_buffer_peer_probe.compile().mlir_text() + expect_parse_roundtrip_and_verify( + simt_reserved_buffer_peer_text, + "simt reserved-buffer peer specialization", + ) + expect( + re.search( + r"pto\.import_reserved_buffer\{[^}]*peer_func = @simt_reserved_buffer_peer__simt_\d+", + simt_reserved_buffer_peer_text, + ) is not None, + "import_reserved_buffer(peer_func=@pto.simt helper) should reference the materialized helper symbol", + ) + expect_raises( + RuntimeError, + 
lambda: simt_reserved_buffer_ambiguous_peer_probe.compile().mlir_text(), + "multiple specializations", + ) scalar_store_coercion_text = scalar_store_element_coercion_probe.compile().mlir_text() expect_parse_roundtrip_and_verify(scalar_store_coercion_text, "scalar store coercion specialization") From 2aa540ec2b666d8b4e8a8a16b549589bd64a86c6 Mon Sep 17 00:00:00 2001 From: jimmychou <47636600+jimmychou0@users.noreply.github.com> Date: Fri, 12 Jun 2026 10:11:04 +0800 Subject: [PATCH 8/8] Add PTODSL SIMT micro-op surface --- ptodsl/docs/user_guide/01-introduction.md | 4 +- .../03-kernel-entry-and-subkernels.md | 288 +-------- ptodsl/docs/user_guide/13-simt-micro-ops.md | 585 ++++++++++++++++++ ptodsl/ptodsl/_ops.py | 21 +- ptodsl/ptodsl/_subkernels.py | 25 + ptodsl/ptodsl/pto.py | 6 +- ptodsl/tests/test_jit_compile.py | 95 ++- ptodsl/tests/test_ptoas_frontend_verify.py | 41 ++ test/dsl-st/simt_gm_memory_core.py | 74 +++ 9 files changed, 857 insertions(+), 282 deletions(-) create mode 100644 ptodsl/docs/user_guide/13-simt-micro-ops.md create mode 100644 test/dsl-st/simt_gm_memory_core.py diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 7a6e6c3ee..8f5b07349 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -223,7 +223,7 @@ Chapter 11 walks through this example in full detail. | New to PTODSL | Chapter 2 (Quick Start), then Chapter 3 (Kernel Entries) | | Writing your first kernel | Chapter 2 → Chapter 4 (Type System) → Chapter 5 (Control Flow) | -| Looking up a specific operation | Chapters 6–10 (organized by topic) | +| Looking up a specific operation | Chapters 6–10 and Chapter 13 (organized by topic) | | Understanding the flash attention reference | Chapter 11 | **Chapter overview:** @@ -242,5 +242,5 @@ Chapter 11 walks through this example in full detail. 
| 10 | Synchronization: barriers, flags, memory fences | | 11 | Flash attention walkthrough | | 12 | Additional examples | -| 13 | Migration from the old `@pto.vkernel`/`@pto.ckernel` API | +| 13 | SIMT micro-ops | | 14 | Common errors and compatibility notes | diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index e8086d1c4..723fdf9bd 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -582,17 +582,21 @@ in parallel, making it efficient for per-element operations. Optional `max_threads` and `max_regs` arguments attach VPTO resource attributes to the generated `pto.simt_entry` helper. +**Signature**: `@pto.simt(fn=None, *, name=None, target="a5", max_threads=None, max_regs=None)` + +**Parameters**: + | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `max_threads` | positive Python `int` | backend default `1024` | Compile-time launch envelope for this SIMT helper | | `max_regs` | positive Python `int` | backend default `32` | Scalar register budget per work-item | -`max_threads` is not the launch size. The actual work-item count comes from -`pto.simt_launch(..., dims=(dim_x, dim_y, dim_z))`; `max_threads` should cover -the largest `dim_x * dim_y * dim_z` used for that helper. Both arguments must -be Python integers known at trace time, must be greater than zero, and must fit -in signless `i32`. `bool` values are rejected. These arguments are only valid -on decorated SIMT helper functions, not inline `with pto.simt():` scopes. +`max_threads` is not the launch size. The actual work-item count comes from the +SIMT launch dimensions. Both arguments must be Python integers known at trace +time, greater than zero, and fit in signless `i32`. They are only valid on +decorated SIMT helper functions, not inline `with pto.simt():` scopes. 
+ +**Example**: ```python @@ -605,7 +609,7 @@ def write_tid(dst: pto.ptr(pto.i32, "gm")): @pto.jit(target="a5") def kernel_entry_simt_resource_probe(dst: pto.ptr(pto.i32, "gm")): - pto.simt_launch(write_tid, dst, dims=(128, 1, 1)) + write_tid[128, 1, 1](dst) ``` **Invocation modes**: can be called from `@pto.jit` in either mode, or used @@ -614,274 +618,44 @@ inline with `with pto.simt():` (Section 3.4). #### Explicit SIMT launch dimensions Calling a decorated SIMT helper directly uses the default launch descriptor -emitted by the tracer. Use `pto.simt_launch` when the launch dimensions must be -authored explicitly. +emitted by the tracer. Use indexed launch syntax when the launch dimensions must +be authored explicitly. `pto.simt_launch(...)` is the equivalent functional +form. + +**Signatures**: ```python -pto.simt_launch(body, *args, dims=(dim_x, dim_y, dim_z)) +body[dim_x, dim_y, dim_z](*args, **static_kwargs) +pto.simt_launch(body, *args, dims=(dim_x, dim_y, dim_z), **static_kwargs) ``` +**Parameters**: + | Parameter | Type | Description | |-----------|------|-------------| | `body` | `@pto.simt` function | SIMT entry body to launch | -| `*args` | PTO values | Arguments passed to the SIMT body; types must match the body signature | -| `dims` | tuple of 3 `i32`-compatible values | Launch dimensions in `x, y, z` order | - -`pto.simt_launch` follows the source-level `x, y, z` launch order. The lower -level `pto.store_vfsimt_info(dim_z, dim_y, dim_x)` wrapper is also available -for direct VPTO authoring, but its operand order follows the backend launch -descriptor order. - - -```python -@pto.jit(target="a5") -def kernel_entry_simt_store_info_probe(): - dim_z = pto.const(1, dtype=pto.i32) - dim_y = pto.const(1, dtype=pto.i32) - dim_x = pto.const(32, dtype=pto.i32) - pto.store_vfsimt_info(dim_z, dim_y, dim_x) -``` - -#### SIMT query ops - -SIMT query ops are nullary micro-op wrappers. They return PTO scalar values -visible to the current SIMT work-item. 
- -| API | Return | Description | -|-----|--------|-------------| -| `pto.get_tid_x()` / `pto.get_tid_y()` / `pto.get_tid_z()` | `i32` | Current work-item coordinate | -| `pto.get_block_dim_x()` / `pto.get_block_dim_y()` / `pto.get_block_dim_z()` | `i32` | Block dimension in the selected axis | -| `pto.get_grid_dim_x()` / `pto.get_grid_dim_y()` / `pto.get_grid_dim_z()` | `i32` | Grid dimension in the selected axis | -| `pto.get_block_idx_x()` / `pto.get_block_idx_y()` / `pto.get_block_idx_z()` | `i32` | Block index in the selected axis | -| `pto.get_veccoreid()` | `i32` | Vector-core id visible to the work-item | -| `pto.get_clock32()` | `i32` | 32-bit clock sample | -| `pto.get_clock64()` | `i64` | 64-bit clock sample | -| `pto.get_laneid()` | `i32` | Physical SIMT lane id | -| `pto.get_lanemask_eq()` / `pto.get_lanemask_le()` / `pto.get_lanemask_lt()` / `pto.get_lanemask_ge()` / `pto.get_lanemask_gt()` | `i32` | Lane masks derived from the current lane id | +| `*args` | PTO values | Runtime arguments passed to the SIMT body | +| `dim_x`, `dim_y`, `dim_z` | `i32`-compatible values | Launch dimensions in source-level `x, y, z` order | +| `**static_kwargs` | hashable Python values | Trace-time specialization arguments for the SIMT body | - -```python -@pto.simt -def capture_query_state(dst: pto.ptr(pto.i32, "gm")): - tid_x = pto.get_tid_x() - pto.get_tid_y() - pto.get_tid_z() - pto.get_block_dim_x() - pto.get_block_dim_y() - pto.get_block_dim_z() - pto.get_grid_dim_x() - pto.get_grid_dim_y() - pto.get_grid_dim_z() - pto.get_block_idx_x() - pto.get_block_idx_y() - pto.get_block_idx_z() - pto.get_veccoreid() - pto.get_clock32() - pto.get_clock64() - lane = pto.get_laneid() - pto.get_lanemask_eq() - pto.get_lanemask_le() - pto.get_lanemask_lt() - pto.get_lanemask_ge() - pto.get_lanemask_gt() - pto.stg(tid_x, dst, scalar.index_cast(lane)) - - -@pto.jit(target="a5") -def kernel_entry_simt_query_probe(dst: pto.ptr(pto.i32, "gm")): - pto.simt_launch(capture_query_state, 
dst, dims=(32, 1, 1)) -``` - -#### SIMT lane collective ops +**Returns**: None. -These wrappers map directly to VPTO SIMT lane collective micro-ops. +**Example**: -```python -pto.vote_all(pred) -pto.vote_any(pred) -pto.vote_uni(pred) -pto.vote_ballot(pred) - -pto.shuffle_idx(value, index, *, width=32) -pto.shuffle_up(value, offset, *, width=32) -pto.shuffle_down(value, offset, *, width=32) -pto.shuffle_bfly(value, mask, *, width=32) - -pto.redux_add(value, *, signedness=None) -pto.redux_max(value, *, signedness=None) -pto.redux_min(value, *, signedness=None) -``` - -`pred` must be an `i1` predicate. Shuffle control operands are coerced to -`i32`; `width` must be `16` or `32`. Integer `redux_max` and `redux_min` -require `signedness="signed"` or `signedness="unsigned"`; floating-point redux -does not accept signedness. - - + ```python @pto.simt -def reduce_lane_value(dst: pto.ptr(pto.i32, "gm")): - pred = pto.const(1, dtype=pto.i1) - lane = pto.get_laneid() - - pto.vote_all(pred) - pto.vote_any(pred) - pto.vote_uni(pred) - pto.vote_ballot(pred) - - pto.shuffle_idx(lane, lane, width=32) - pto.shuffle_up(lane, 1, width=32) - pto.shuffle_down(lane, 1, width=32) - value = pto.shuffle_bfly(lane, 1, width=32) - total = pto.redux_add(value, signedness="signed") - maximum = pto.redux_max(total, signedness="signed") - minimum = pto.redux_min(maximum, signedness="signed") - pto.stg(minimum, dst, scalar.index_cast(lane)) - - -@pto.jit(target="a5") -def kernel_entry_simt_collective_probe(dst: pto.ptr(pto.i32, "gm")): - pto.simt_launch(reduce_lane_value, dst, dims=(32, 1, 1)) -``` - -#### SIMT scalar GM memory and atomic ops - -```python -pto.ldg(ptr, offset=0, *, l1cache="cache", l2cache="nmfv") -pto.stg(value, ptr, offset=0, *, l1cache="cache", l2cache="nmfv") - -pto.atomic_exch(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_add(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_sub(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_min(ptr, 
value, *, l2cache="nmfv", signedness=None) -pto.atomic_max(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_and(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_or(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_xor(ptr, value, *, l2cache="nmfv", signedness=None) -pto.atomic_cas(ptr, compare, value, *, l2cache="nmfv", signedness=None) -``` - -`pto.ldg` and `pto.stg` are GM scalar memory micro-ops with cache-control -clauses. Plain scalar memory remains available through `scalar.load(...)` and -`scalar.store(...)`. - -`l1cache` accepts `"cache"` or `"uncache"`. Load `l2cache` accepts `"nmfv"`, -`"nmlv"`, `"nmprs"`, `"nmpref"`, `"nakeep"`, `"naclean"`, `"nadrop"`, -`"idsfv"`, `"idslv"`, `"idsprs"`, `"idspref"`, `"exfv"`, `"exlv"`, `"exprs"`, -or `"expref"`. Store and atomic `l2cache` accepts `"nmfv"`, `"nmlv"`, -`"nmprs"`, `"nmred"`, `"naci"`, `"napw"`, `"napi"`, `"nared"`, `"wbhfv"`, -`"wbhlv"`, `"wbhprs"`, `"wbhred"`, `"wtsfv"`, `"wtslv"`, `"wtsprs"`, or -`"wtsred"`. Atomic pointers must point to GM or UB scalar storage accepted by -the VPTO verifier. Integer atomics may pass `signedness`; floating-point and -packed atomics must omit it. 
- - -```python -@pto.simt -def update_counter(counter: pto.ptr(pto.i32, "gm")): +def fill_tid(dst: pto.ptr(pto.i32, "gm")): tid = pto.get_tid_x() - idx = scalar.index_cast(tid) - value = pto.ldg(counter, idx, l1cache="cache", l2cache="nmfv") - old = pto.atomic_add(counter, value, l2cache="nmfv", signedness="signed") - pto.atomic_exch(counter, value, signedness="signed") - pto.atomic_sub(counter, value, signedness="signed") - pto.atomic_min(counter, value, signedness="signed") - pto.atomic_max(counter, value, signedness="signed") - pto.atomic_and(counter, value, signedness="unsigned") - pto.atomic_or(counter, value, signedness="unsigned") - pto.atomic_xor(counter, value, signedness="unsigned") - pto.atomic_cas(counter, old, value, signedness="signed") - pto.stg(old, counter, idx, l1cache="uncache", l2cache="wtsred") + pto.stg(tid, dst, scalar.index_cast(tid)) @pto.jit(target="a5") -def kernel_entry_simt_memory_atomic_probe(counter: pto.ptr(pto.i32, "gm")): - pto.simt_launch(update_counter, counter, dims=(32, 1, 1)) +def kernel_entry_simt_launch_probe(dst: pto.ptr(pto.i32, "gm")): + fill_tid[32, 1, 1](dst) ``` -#### SIMT scalar math, conversion, sync, and state ops - -```python -pto.prmt(lhs, rhs, selector) -pto.mulhi(lhs, rhs, *, signedness) -pto.mul_i32toi64(lhs, rhs, *, signedness) - -pto.absf(value) -pto.sqrt(value) -pto.exp(value) -pto.log(value) -pto.pow(lhs, rhs) -pto.ceil(value) -pto.floor(value) -pto.rint(value) -pto.round(value) -pto.fmin(lhs, rhs) -pto.fmax(lhs, rhs) -pto.fma(lhs, rhs, acc) - -pto.convert(src, dst_type, *, rounding, saturation, signedness=None) - -pto.syncthreads() -pto.threadfence() -pto.threadfence_block() -pto.keep(payload, *, slot) -pto.resume(result_type, *, slot) -``` - -`pto.sqrt`, `pto.exp`, `pto.log`, and related functions are VPTO SIMT -micro-ops. They are distinct from the generic `scalar.sqrt`, `scalar.exp`, and -`scalar.log` helpers in Chapter 6. 
- -`pto.convert` requires an explicit destination type plus VPTO conversion -controls. `rounding` accepts `"r"`, `"a"`, `"f"`, `"c"`, `"z"`, `"o"`, or -`"h"`. `saturation` accepts `"sat"`/`"nosat"` or `"on"`/`"off"`. -`signedness` is required when converting to or from integer types and omitted -for floating-to-floating or packed floating conversion. Integer-to-integer -conversion is not supported by `pto.convert`. - -`pto.keep` and `pto.resume` use explicit non-negative Python integer slots. -Keep/resume placement constraints are enforced by the VPTO verifier. - - -```python -@pto.simt -def save_lane_state(): - pto.keep(pto.get_tid_x(), slot=0) - - -@pto.simt -def transform_lane_state(dst: pto.ptr(pto.f32, "gm")): - lane = pto.resume(pto.i32, slot=0) - permuted = pto.prmt(lane, lane, lane) - high = pto.mulhi(permuted, lane, signedness="unsigned") - product = pto.mul_i32toi64(lane, lane, signedness="unsigned") - _ = high - _ = product - - value = pto.convert( - lane, - pto.f32, - rounding="r", - saturation="nosat", - signedness="unsigned", - ) - root = pto.sqrt(pto.absf(value)) - powered = pto.pow(root, root) - rounded = pto.round(pto.rint(pto.floor(pto.ceil(powered)))) - bounded = pto.fmin(pto.fmax(value, root), rounded) - accum = pto.fma(bounded, pto.exp(value), pto.log(pto.fmax(value, root))) - - pto.syncthreads() - pto.threadfence() - pto.threadfence_block() - pto.stg(accum, dst, scalar.index_cast(lane)) - - -@pto.jit(target="a5") -def kernel_entry_simt_math_state_probe(dst: pto.ptr(pto.f32, "gm")): - pto.simt_launch(save_lane_state, dims=(32, 1, 1)) - pto.simt_launch(transform_lane_state, dst, dims=(32, 1, 1)) -``` +Specific SIMT micro-op APIs are documented in Chapter 13. ## 3.4 Inline context manager syntax diff --git a/ptodsl/docs/user_guide/13-simt-micro-ops.md b/ptodsl/docs/user_guide/13-simt-micro-ops.md new file mode 100644 index 000000000..928cffeba --- /dev/null +++ b/ptodsl/docs/user_guide/13-simt-micro-ops.md @@ -0,0 +1,585 @@ +# 13. 
SIMT Micro-ops + +Chapter 3 introduces `@pto.simt` helpers and launch syntax. This chapter covers +the SIMT micro-op API surface used inside those helpers. These wrappers map to +VPTO SIMT operations and operate on PTO scalar values, typed pointers, and +scalar values loaded from tiles. + +## 13.1 Launch descriptor + +#### `pto.store_vfsimt_info(dim_z, dim_y, dim_x) -> None` + +**Description**: Emits the low-level VPTO launch descriptor operation. Most +code should use `body[dim_x, dim_y, dim_z](...)` or `pto.simt_launch(...)` +instead. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dim_z` | `i32`-compatible value | Launch dimension in Z | +| `dim_y` | `i32`-compatible value | Launch dimension in Y | +| `dim_x` | `i32`-compatible value | Launch dimension in X | + +**Returns**: None. + +**Example**: + + +```python +@pto.jit(target="a5") +def simt_ops_store_info_probe(): + dim_z = pto.const(1, dtype=pto.i32) + dim_y = pto.const(1, dtype=pto.i32) + dim_x = pto.const(32, dtype=pto.i32) + pto.store_vfsimt_info(dim_z, dim_y, dim_x) +``` + +## 13.2 Query ops + +#### `pto.get_tid() -> tuple[pto.i32, pto.i32, pto.i32]` +#### `pto.get_tid_x() -> pto.i32` +#### `pto.get_tid_y() -> pto.i32` +#### `pto.get_tid_z() -> pto.i32` + +**Description**: Returns the current SIMT work-item coordinate. The grouped +form returns `(x, y, z)` and lowers through the three axis-specific micro-ops. + +**Parameters**: None. + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `x`, `y`, `z` | `pto.i32` | Work-item coordinates | + +--- + +#### `pto.get_block_dim() -> tuple[pto.i32, pto.i32, pto.i32]` +#### `pto.get_block_dim_x() -> pto.i32` +#### `pto.get_block_dim_y() -> pto.i32` +#### `pto.get_block_dim_z() -> pto.i32` + +**Description**: Returns SIMT block dimensions. The grouped form returns +`(x, y, z)`. + +**Parameters**: None. 
+ +**Returns**: `pto.i32` for axis-specific forms, or a tuple of three `pto.i32` +values for `pto.get_block_dim()`. + +--- + +#### `pto.get_grid_dim() -> tuple[pto.i32, pto.i32, pto.i32]` +#### `pto.get_grid_dim_x() -> pto.i32` +#### `pto.get_grid_dim_y() -> pto.i32` +#### `pto.get_grid_dim_z() -> pto.i32` + +**Description**: Returns SIMT grid dimensions. The grouped form returns +`(x, y, z)`. + +**Parameters**: None. + +**Returns**: `pto.i32` for axis-specific forms, or a tuple of three `pto.i32` +values for `pto.get_grid_dim()`. + +--- + +#### `pto.get_block_idx_x() -> pto.i32` +#### `pto.get_block_idx_y() -> pto.i32` +#### `pto.get_block_idx_z() -> pto.i32` + +**Description**: Returns the current SIMT block index in the selected axis. + +**Parameters**: None. + +**Returns**: `pto.i32`. + +--- + +#### `pto.get_veccoreid() -> pto.i32` +#### `pto.get_clock32() -> pto.i32` +#### `pto.get_clock64() -> pto.i64` +#### `pto.get_laneid() -> pto.i32` +#### `pto.get_lanemask_eq() -> pto.i32` +#### `pto.get_lanemask_le() -> pto.i32` +#### `pto.get_lanemask_lt() -> pto.i32` +#### `pto.get_lanemask_ge() -> pto.i32` +#### `pto.get_lanemask_gt() -> pto.i32` + +**Description**: Returns SIMT execution state: vector-core id, clock samples, +lane id, or lane masks derived from the current lane id. + +**Parameters**: None. + +**Returns**: `pto.get_clock64()` returns `pto.i64`; the other query ops return +`pto.i32`. 
+ +**Example**: + + +```python +@pto.simt +def capture_query_state(dst: pto.ptr(pto.i32, "gm")): + tid_x, tid_y, tid_z = pto.get_tid() + block_x, block_y, block_z = pto.get_block_dim() + grid_x, grid_y, grid_z = pto.get_grid_dim() + pto.get_block_idx_x() + pto.get_block_idx_y() + pto.get_block_idx_z() + pto.get_veccoreid() + pto.get_clock32() + pto.get_clock64() + lane = pto.get_laneid() + pto.get_lanemask_eq() + pto.get_lanemask_le() + pto.get_lanemask_lt() + pto.get_lanemask_ge() + pto.get_lanemask_gt() + value = ( + tid_x + tid_y + tid_z + + block_x + block_y + block_z + + grid_x + grid_y + grid_z + ) + pto.stg(value, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_query_probe(dst: pto.ptr(pto.i32, "gm")): + capture_query_state[32, 1, 1](dst) +``` + +## 13.3 Lane collective ops + +#### `pto.vote_all(pred: pto.i1) -> pto.i1` +#### `pto.vote_any(pred: pto.i1) -> pto.i1` +#### `pto.vote_uni(pred: pto.i1) -> pto.i1` +#### `pto.vote_ballot(pred: pto.i1) -> pto.i32` + +**Description**: Performs a SIMT lane vote over an `i1` predicate. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pred` | `pto.i1` | Per-lane predicate | + +**Returns**: `pto.i1` for `vote_all`, `vote_any`, and `vote_uni`; `pto.i32` +for `vote_ballot`. 
+ +**Example**: + + +```python +@pto.simt +def vote_probe(dst: pto.ptr(pto.i32, "gm")): + lane = pto.get_laneid() + pred = lane < pto.const(16, dtype=pto.i32) + ballot = pto.vote_ballot(pred) + all_pred = pto.vote_all(pred) + any_pred = pto.vote_any(pred) + uni_pred = pto.vote_uni(pred) + value = ballot + all_pred + any_pred + uni_pred + pto.stg(value, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_vote_probe(dst: pto.ptr(pto.i32, "gm")): + vote_probe[32, 1, 1](dst) +``` + +--- + +#### `pto.shuffle_idx(value: ScalarType, index: Index, *, width: int = 32) -> ScalarType` +#### `pto.shuffle_up(value: ScalarType, offset: Index, *, width: int = 32) -> ScalarType` +#### `pto.shuffle_down(value: ScalarType, offset: Index, *, width: int = 32) -> ScalarType` +#### `pto.shuffle_bfly(value: ScalarType, mask: Index, *, width: int = 32) -> ScalarType` + +**Description**: Reads a scalar payload from another lane. `shuffle_idx` uses an +absolute lane index, `shuffle_up` and `shuffle_down` use relative offsets, and +`shuffle_bfly` uses a butterfly mask. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | PTO scalar | Payload to shuffle | +| `index` / `offset` / `mask` | `i32`-compatible value | Lane selector | +| `width` | Python `int` | Subgroup width, either `16` or `32` | + +**Returns**: PTO scalar with the same type as `value`. 
+ +**Example**: + + +```python +@pto.simt +def shuffle_probe(dst: pto.ptr(pto.i32, "gm")): + lane = pto.get_laneid() + shuffled = pto.shuffle_idx(lane, lane, width=32) + shifted_up = pto.shuffle_up(lane, 1, width=32) + shifted_down = pto.shuffle_down(lane, 1, width=32) + butterfly = pto.shuffle_bfly(lane, 1, width=32) + value = shuffled + shifted_up + shifted_down + butterfly + pto.stg(value, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_shuffle_probe(dst: pto.ptr(pto.i32, "gm")): + shuffle_probe[32, 1, 1](dst) +``` + +--- + +#### `pto.redux_add(value: ScalarType, *, signedness: str | None = None) -> ScalarType` +#### `pto.redux_max(value: ScalarType, *, signedness: str | None = None) -> ScalarType` +#### `pto.redux_min(value: ScalarType, *, signedness: str | None = None) -> ScalarType` + +**Description**: Reduces a scalar value across SIMT lanes. Integer +`redux_max` and `redux_min` require `signedness="signed"` or +`signedness="unsigned"`. Floating-point reductions do not accept `signedness`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | PTO scalar | Payload to reduce | +| `signedness` | `"signed"`, `"unsigned"`, or `None` | Integer signedness control | + +**Returns**: PTO scalar with the same type as `value`. 
+ +**Example**: + + +```python +@pto.simt +def reduce_lane_value(dst: pto.ptr(pto.i32, "gm")): + pred = pto.const(1, dtype=pto.i1) + lane = pto.get_laneid() + + pto.vote_all(pred) + pto.vote_any(pred) + pto.vote_uni(pred) + pto.vote_ballot(pred) + + pto.shuffle_idx(lane, lane, width=32) + pto.shuffle_up(lane, 1, width=32) + pto.shuffle_down(lane, 1, width=32) + value = pto.shuffle_bfly(lane, 1, width=32) + total = pto.redux_add(value, signedness="signed") + maximum = pto.redux_max(total, signedness="signed") + minimum = pto.redux_min(maximum, signedness="signed") + pto.stg(minimum, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_collective_probe(dst: pto.ptr(pto.i32, "gm")): + reduce_lane_value[32, 1, 1](dst) +``` + +## 13.4 Scalar GM memory and atomic ops + +#### `pto.ldg(ptr: PtrType, offset: Index = 0, *, l1cache: str = "cache", l2cache: str = "nmfv") -> ScalarType` +#### `pto.stg(value: ScalarType, ptr: PtrType, offset: Index = 0, *, l1cache: str = "cache", l2cache: str = "nmfv") -> None` + +**Description**: Loads or stores one scalar value through a typed pointer with +cache controls. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `pto.ptr(dtype, "gm")` | GM pointer | +| `value` | PTO scalar | Store payload for `pto.stg` | +| `offset` | index-like value | Element offset | +| `l1cache` | `"cache"` or `"uncache"` | L1 cache policy | +| `l2cache` | cache token string | L2 cache policy accepted by VPTO | + +**Returns**: `pto.ldg` returns the pointer element type. `pto.stg` returns None. 
+ +**Example**: + + +```python +@pto.simt +def ldg_stg_probe(src: pto.ptr(pto.i32, "gm"), dst: pto.ptr(pto.i32, "gm")): + lane = pto.get_tid_x() + idx = scalar.index_cast(lane) + value = pto.ldg(src, idx, l1cache="cache", l2cache="nmfv") + pto.stg(value, dst, idx, l1cache="uncache", l2cache="wtsred") + + +@pto.jit(target="a5") +def simt_ops_ldg_stg_probe( + src: pto.ptr(pto.i32, "gm"), + dst: pto.ptr(pto.i32, "gm"), +): + ldg_stg_probe[32, 1, 1](src, dst) +``` + +--- + +#### `pto.atomic_exch(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_add(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_sub(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_min(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_max(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_and(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_or(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_xor(ptr: PtrType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` +#### `pto.atomic_cas(ptr: PtrType, compare: ScalarType, value: ScalarType, *, l2cache: str = "nmfv", signedness: str | None = None) -> ScalarType` + +**Description**: Performs a scalar atomic operation and returns the old value. +Integer atomics may pass `signedness`; floating-point and packed atomics must +omit it. 
+ +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | typed pointer | Atomic target | +| `value` | PTO scalar | Atomic payload | +| `compare` | PTO scalar | Compare value for `atomic_cas` | +| `l2cache` | cache token string | L2 cache policy accepted by VPTO | +| `signedness` | `"signed"`, `"unsigned"`, or `None` | Integer signedness control | + +**Returns**: Old value loaded from `ptr`. + +**Example**: + + +```python +@pto.simt +def update_counter(counter: pto.ptr(pto.i32, "gm")): + tid = pto.get_tid_x() + idx = scalar.index_cast(tid) + value = pto.ldg(counter, idx, l1cache="cache", l2cache="nmfv") + old = pto.atomic_add(counter, value, l2cache="nmfv", signedness="signed") + pto.atomic_exch(counter, value, signedness="signed") + pto.atomic_sub(counter, value, signedness="signed") + pto.atomic_min(counter, value, signedness="signed") + pto.atomic_max(counter, value, signedness="signed") + pto.atomic_and(counter, value, signedness="unsigned") + pto.atomic_or(counter, value, signedness="unsigned") + pto.atomic_xor(counter, value, signedness="unsigned") + pto.atomic_cas(counter, old, value, signedness="signed") + pto.stg(old, counter, idx, l1cache="uncache", l2cache="wtsred") + + +@pto.jit(target="a5") +def simt_ops_memory_atomic_probe(counter: pto.ptr(pto.i32, "gm")): + update_counter[32, 1, 1](counter) +``` + +## 13.5 Scalar math and conversion ops + +#### `pto.prmt(lhs: ScalarType, rhs: ScalarType, selector: Index) -> ScalarType` +#### `pto.mulhi(lhs: ScalarType, rhs: ScalarType, *, signedness: str) -> ScalarType` +#### `pto.mul_i32toi64(lhs: ScalarType, rhs: ScalarType, *, signedness: str) -> pto.i64` + +**Description**: Performs integer byte permutation or multiplication helper +operations. 
+ +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `lhs`, `rhs` | integer PTO scalar | Source operands | +| `selector` | `i32`-compatible value | Byte selector for `prmt` | +| `signedness` | `"signed"` or `"unsigned"` | Integer signedness control | + +**Returns**: `pto.prmt` and `pto.mulhi` return the source integer type. +`pto.mul_i32toi64` returns `pto.i64`. + +**Example**: + + +```python +@pto.simt +def integer_math_probe(dst: pto.ptr(pto.i32, "gm")): + lane = pto.get_laneid() + permuted = pto.prmt(lane, lane, lane) + high = pto.mulhi(permuted, lane, signedness="unsigned") + wide = pto.mul_i32toi64(lane, lane, signedness="unsigned") + _ = wide + pto.stg(high, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_integer_math_probe(dst: pto.ptr(pto.i32, "gm")): + integer_math_probe[32, 1, 1](dst) +``` + +--- + +#### `pto.absf(value: ScalarType) -> ScalarType` +#### `pto.sqrt(value: ScalarType) -> ScalarType` +#### `pto.exp(value: ScalarType) -> ScalarType` +#### `pto.log(value: ScalarType) -> ScalarType` +#### `pto.pow(lhs: ScalarType, rhs: ScalarType) -> ScalarType` +#### `pto.ceil(value: ScalarType) -> ScalarType` +#### `pto.floor(value: ScalarType) -> ScalarType` +#### `pto.rint(value: ScalarType) -> ScalarType` +#### `pto.round(value: ScalarType) -> ScalarType` +#### `pto.fmin(lhs: ScalarType, rhs: ScalarType) -> ScalarType` +#### `pto.fmax(lhs: ScalarType, rhs: ScalarType) -> ScalarType` +#### `pto.fma(lhs: ScalarType, rhs: ScalarType, acc: ScalarType) -> ScalarType` + +**Description**: Performs SIMT floating-point math. These functions are VPTO +SIMT micro-ops and are distinct from the generic scalar helpers in Chapter 6. + +**Parameters**: PTO floating-point scalar operands. + +**Returns**: PTO scalar with the same type as the input value. 
+ +**Example**: + + +```python +@pto.simt +def float_math_probe(dst: pto.ptr(pto.f32, "gm")): + lane = pto.get_laneid() + value = pto.convert( + lane, + pto.f32, + rounding="r", + saturation="nosat", + signedness="unsigned", + ) + root = pto.sqrt(pto.absf(value)) + powered = pto.pow(root, root) + rounded = pto.round(pto.rint(pto.floor(pto.ceil(powered)))) + bounded = pto.fmin(pto.fmax(value, root), rounded) + accum = pto.fma(bounded, pto.exp(value), pto.log(pto.fmax(value, root))) + pto.stg(accum, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_float_math_probe(dst: pto.ptr(pto.f32, "gm")): + float_math_probe[32, 1, 1](dst) +``` + +--- + +#### `pto.convert(src: ScalarType, dst_type: Type, *, rounding: str, saturation: str, signedness: str | None = None) -> ScalarType` + +**Description**: Converts a scalar or packed value to `dst_type` with explicit +VPTO conversion controls. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | PTO scalar | Source value | +| `dst_type` | PTO type | Destination type | +| `rounding` | `"r"`, `"a"`, `"f"`, `"c"`, `"z"`, `"o"`, or `"h"` | Rounding mode | +| `saturation` | `"sat"`, `"nosat"`, `"on"`, or `"off"` | Saturation mode | +| `signedness` | `"signed"`, `"unsigned"`, or `None` | Required when converting to/from integer types | + +**Returns**: Converted PTO scalar. Integer-to-integer conversion is not +supported by `pto.convert`. 
+ +**Example**: + + +```python +@pto.simt +def transform_lane_value(dst: pto.ptr(pto.f32, "gm")): + lane = pto.get_laneid() + permuted = pto.prmt(lane, lane, lane) + high = pto.mulhi(permuted, lane, signedness="unsigned") + product = pto.mul_i32toi64(lane, lane, signedness="unsigned") + _ = high + _ = product + + value = pto.convert( + lane, + pto.f32, + rounding="r", + saturation="nosat", + signedness="unsigned", + ) + root = pto.sqrt(pto.absf(value)) + powered = pto.pow(root, root) + rounded = pto.round(pto.rint(pto.floor(pto.ceil(powered)))) + bounded = pto.fmin(pto.fmax(value, root), rounded) + accum = pto.fma(bounded, pto.exp(value), pto.log(pto.fmax(value, root))) + pto.stg(accum, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_math_probe(dst: pto.ptr(pto.f32, "gm")): + transform_lane_value[32, 1, 1](dst) +``` + +## 13.6 Sync and state ops + +#### `pto.syncthreads() -> None` +#### `pto.threadfence() -> None` +#### `pto.threadfence_block() -> None` + +**Description**: Emits SIMT synchronization or memory fence operations. + +**Parameters**: None. + +**Returns**: None. + +**Example**: + + +```python +@pto.simt +def sync_probe(dst: pto.ptr(pto.i32, "gm")): + lane = pto.get_laneid() + pto.syncthreads() + pto.threadfence() + pto.threadfence_block() + pto.stg(lane, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_sync_probe(dst: pto.ptr(pto.i32, "gm")): + sync_probe[32, 1, 1](dst) +``` + +--- + +#### `pto.keep(payload: ScalarType, *, slot: int) -> None` +#### `pto.resume(result_type: Type, *, slot: int) -> ScalarType` + +**Description**: Preserves and restores a SIMT scalar payload through an +explicit slot. Placement constraints are enforced by the VPTO verifier. 
+ +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `payload` | PTO scalar | Value to preserve | +| `result_type` | PTO type | Type restored by `resume` | +| `slot` | non-negative Python `int` | State slot | + +**Returns**: `pto.keep` returns None. `pto.resume` returns the restored scalar. + +**Example**: + + +```python +@pto.simt +def save_lane_state(): + pto.keep(pto.get_tid_x(), slot=0) + + +@pto.simt +def use_lane_state(dst: pto.ptr(pto.i32, "gm")): + lane = pto.resume(pto.i32, slot=0) + pto.syncthreads() + pto.threadfence() + pto.threadfence_block() + pto.stg(lane, dst, scalar.index_cast(lane)) + + +@pto.jit(target="a5") +def simt_ops_sync_state_probe(dst: pto.ptr(pto.i32, "gm")): + save_lane_state[32, 1, 1]() + use_lane_state[32, 1, 1](dst) +``` diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py index b8bcf0466..63a83777b 100644 --- a/ptodsl/ptodsl/_ops.py +++ b/ptodsl/ptodsl/_ops.py @@ -4433,6 +4433,11 @@ def get_tid_z(): return wrap_surface_value(_pto.GetTidZOp().result) +def get_tid(): + """``pto.get_tid`` → ``(x, y, z)`` SIMT lane coordinates.""" + return get_tid_x(), get_tid_y(), get_tid_z() + + def get_block_dim_x(): """``pto.get_block_dim_x`` → i32 SIMT block X dimension.""" return wrap_surface_value(_pto.GetBlockDimXOp().result) @@ -4448,6 +4453,11 @@ def get_block_dim_z(): return wrap_surface_value(_pto.GetBlockDimZOp().result) +def get_block_dim(): + """``pto.get_block_dim`` → ``(x, y, z)`` SIMT block dimensions.""" + return get_block_dim_x(), get_block_dim_y(), get_block_dim_z() + + def get_grid_dim_x(): """``pto.get_grid_dim_x`` → i32 SIMT grid X dimension.""" return wrap_surface_value(_pto.GetGridDimXOp().result) @@ -4463,6 +4473,11 @@ def get_grid_dim_z(): return wrap_surface_value(_pto.GetGridDimZOp().result) +def get_grid_dim(): + """``pto.get_grid_dim`` → ``(x, y, z)`` SIMT grid dimensions.""" + return get_grid_dim_x(), get_grid_dim_y(), get_grid_dim_z() + + def get_block_idx_x(): 
"""``pto.get_block_idx_x`` → i32 SIMT block X index.""" return wrap_surface_value(_pto.GetBlockIdxXOp().result) @@ -5144,9 +5159,9 @@ def import_reserved_buffer(name, *, peer_func): "mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", "store_vfsimt_info", "simt_launch", - "get_tid_x", "get_tid_y", "get_tid_z", - "get_block_dim_x", "get_block_dim_y", "get_block_dim_z", - "get_grid_dim_x", "get_grid_dim_y", "get_grid_dim_z", + "get_tid", "get_tid_x", "get_tid_y", "get_tid_z", + "get_block_dim", "get_block_dim_x", "get_block_dim_y", "get_block_dim_z", + "get_grid_dim", "get_grid_dim_x", "get_grid_dim_y", "get_grid_dim_z", "get_block_idx_x", "get_block_idx_y", "get_block_idx_z", "get_veccoreid", "get_clock32", "get_clock64", "get_laneid", "get_lanemask_eq", "get_lanemask_le", "get_lanemask_lt", diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py index 1b89547e9..5e708f821 100644 --- a/ptodsl/ptodsl/_subkernels.py +++ b/ptodsl/ptodsl/_subkernels.py @@ -76,6 +76,18 @@ def __call__(self, *args, **kwargs): self._validate_invocation(*args, **kwargs) return runtime.dispatch_subkernel_call(self, *args, **kwargs) + def __getitem__(self, dims): + if self.spec.role != KernelRole.SIMT: + raise TypeError( + f"@pto.{self.spec.role.value} kernels do not support launch dimensions; " + "only @pto.simt helpers support helper[dim_x, dim_y, dim_z](...)" + ) + if not isinstance(dims, tuple): + dims = (dims,) + if len(dims) != 3: + raise TypeError("@pto.simt launch syntax expects helper[dim_x, dim_y, dim_z](...)") + return _SimtLaunchTemplate(self, dims) + def _validate_definition(self) -> None: for param in self.signature.parameters.values(): if isinstance(param.annotation, TensorSpec): @@ -99,6 +111,19 @@ def _validate_result(self, result) -> None: raise simd_value_escape_error(escaped_type) +class _SimtLaunchTemplate: + """Callable ``helper[x, y, z]`` launch descriptor for a 
decorated SIMT helper.""" + + def __init__(self, body: SubkernelTemplate, dims): + self._body = body + self._dims = dims + + def __call__(self, *args, **kwargs): + from ._ops import simt_launch + + return simt_launch(self._body, *args, dims=self._dims, **kwargs) + + def _find_transient_simd_escape(value): if value is None: return None diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index d6753e7f9..ef5549065 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -112,9 +112,9 @@ mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, store_vfsimt_info, simt_launch, - get_tid_x, get_tid_y, get_tid_z, - get_block_dim_x, get_block_dim_y, get_block_dim_z, - get_grid_dim_x, get_grid_dim_y, get_grid_dim_z, + get_tid, get_tid_x, get_tid_y, get_tid_z, + get_block_dim, get_block_dim_x, get_block_dim_y, get_block_dim_z, + get_grid_dim, get_grid_dim_x, get_grid_dim_y, get_grid_dim_z, get_block_idx_x, get_block_idx_y, get_block_idx_z, get_veccoreid, get_clock32, get_clock64, get_laneid, get_lanemask_eq, get_lanemask_le, get_lanemask_lt, diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 6c92ab052..4338dbd91 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -402,15 +402,9 @@ def simt_tid_probe(): @pto.simt def simt_query_probe(): - pto.get_tid_x() - pto.get_tid_y() - pto.get_tid_z() - pto.get_block_dim_x() - pto.get_block_dim_y() - pto.get_block_dim_z() - pto.get_grid_dim_x() - pto.get_grid_dim_y() - pto.get_grid_dim_z() + pto.get_tid() + pto.get_block_dim() + pto.get_grid_dim() pto.get_block_idx_x() pto.get_block_idx_y() pto.get_block_idx_z() @@ -425,6 +419,22 @@ def simt_query_probe(): pto.get_lanemask_gt() +@pto.simt +def simt_grouped_query_probe(): + tid_x, tid_y, tid_z = pto.get_tid() + block_x, block_y, block_z = pto.get_block_dim() + grid_x, grid_y, grid_z = pto.get_grid_dim() + pto.keep(tid_x, slot=0) + 
pto.keep(tid_y, slot=1) + pto.keep(tid_z, slot=2) + pto.keep(block_x, slot=3) + pto.keep(block_y, slot=4) + pto.keep(block_z, slot=5) + pto.keep(grid_x, slot=6) + pto.keep(grid_y, slot=7) + pto.keep(grid_z, slot=8) + + @pto.simt(max_threads=256, max_regs=48) def simt_resource_attr_probe(): pto.get_tid_x() @@ -548,12 +558,22 @@ def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): @pto.jit(target="a5") -def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_explicit_launch_probe(*, TRACE_TOKEN: pto.const_expr = 0): pto.simt_launch(simt_query_probe, dims=(32, 2, 1)) @pto.jit(target="a5") -def simt_resource_attr_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_launch_index_sugar_probe(*, TRACE_TOKEN: pto.const_expr = 0): + simt_query_probe[32, 2, 1]() + + +@pto.jit(target="a5") +def simt_grouped_query_launch_probe(*, TRACE_TOKEN: pto.const_expr = 0): + simt_grouped_query_probe[32, 1, 1]() + + +@pto.jit(target="a5") +def simt_resource_attr_launch_probe(*, TRACE_TOKEN: pto.const_expr = 0): pto.simt_launch(simt_resource_attr_probe, dims=(128, 1, 1)) @@ -561,7 +581,7 @@ def simt_resource_attr_launch_probe(*, TRACE_TOKEN: pto.constexpr = 0): def simt_full_surface_probe( gm: pto.ptr(pto.i32, "gm"), *, - TRACE_TOKEN: pto.constexpr = 0, + TRACE_TOKEN: pto.const_expr = 0, ): pto.simt_launch(simt_collective_math_probe, dims=(32, 1, 1)) pto.simt_launch(simt_memory_atomic_probe, gm, dims=(32, 1, 1)) @@ -574,25 +594,25 @@ def simt_specialized_arg_type_probe( gm_i32: pto.ptr(pto.i32, "gm"), gm_f32: pto.ptr(pto.f32, "gm"), *, - TRACE_TOKEN: pto.constexpr = 0, + TRACE_TOKEN: pto.const_expr = 0, ): pto.simt_launch(simt_specialized_ptr_probe, gm_i32, dims=(32, 1, 1)) pto.simt_launch(simt_specialized_ptr_probe, gm_f32, dims=(32, 1, 1)) @pto.jit(target="a5") -def simt_specialized_static_kwarg_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_specialized_static_kwarg_probe(*, TRACE_TOKEN: pto.const_expr = 0): 
pto.simt_launch(simt_specialized_flag_probe, dims=(32, 1, 1), FLAG=False) pto.simt_launch(simt_specialized_flag_probe, dims=(32, 1, 1), FLAG=True) @pto.jit(target="a5") -def simt_invalid_redux_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_invalid_redux_signedness_launch(*, TRACE_TOKEN: pto.const_expr = 0): pto.simt_launch(simt_invalid_redux_signedness_probe, dims=(32, 1, 1)) @pto.jit(target="a5") -def simt_invalid_convert_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_invalid_convert_signedness_launch(*, TRACE_TOKEN: pto.const_expr = 0): pto.simt_launch(simt_invalid_convert_signedness_probe, dims=(32, 1, 1)) @@ -600,7 +620,7 @@ def simt_invalid_convert_signedness_launch(*, TRACE_TOKEN: pto.constexpr = 0): def simt_invalid_atomic_signedness_launch( gm: pto.ptr(pto.f32, "gm"), *, - TRACE_TOKEN: pto.constexpr = 0, + TRACE_TOKEN: pto.const_expr = 0, ): pto.simt_launch(simt_invalid_atomic_signedness_probe, gm, dims=(32, 1, 1)) @@ -2703,6 +2723,47 @@ def main() -> None: re.search(r"func\.func @simt_query_probe__simt_\d+\(\) attributes \{pto\.simt_entry\}", simt_launch_text) is not None, "explicit pto.simt_launch should materialize a reusable pto.simt_entry helper", ) + simt_launch_sugar_text = simt_launch_index_sugar_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_launch_sugar_text, "indexed simt launch specialization") + expect( + re.search(r"pto\.simt_launch @simt_query_probe__simt_\d+<<<", simt_launch_sugar_text) is not None, + "@pto.simt helper[x, y, z](...) 
should emit VPTO simt_launch sugar", + ) + simt_grouped_query_text = simt_grouped_query_launch_probe.compile(TRACE_TOKEN=1).mlir_text() + expect_parse_roundtrip_and_verify(simt_grouped_query_text, "grouped simt query specialization") + expect( + re.search(r"pto\.simt_launch @simt_grouped_query_probe__simt_\d+<<<", simt_grouped_query_text) is not None, + "grouped SIMT query probe should be launchable through helper[x, y, z](...)", + ) + for op_name in ( + "pto.get_tid_x", + "pto.get_tid_y", + "pto.get_tid_z", + "pto.get_block_dim_x", + "pto.get_block_dim_y", + "pto.get_block_dim_z", + "pto.get_grid_dim_x", + "pto.get_grid_dim_y", + "pto.get_grid_dim_z", + ): + expect( + simt_grouped_query_text.count(op_name) == 1, + f"grouped SIMT query helpers should lower exactly once to {op_name}", + ) + expect( + simt_grouped_query_text.count("pto.keep") == 9, + "grouped SIMT query helpers should return values that can be consumed by later micro-ops", + ) + expect_raises( + TypeError, + lambda: simt_query_probe[32, 1](), + "helper[dim_x, dim_y, dim_z]", + ) + expect_raises( + TypeError, + lambda: ast_subkernel_runtime_for_helper[32, 1, 1](pto.const(1, dtype=pto.i32)), + "only @pto.simt", + ) simt_resource_attr_text = simt_resource_attr_launch_probe.compile(TRACE_TOKEN=1).mlir_text() expect_parse_roundtrip_and_verify(simt_resource_attr_text, "simt resource attr launch specialization") expect( diff --git a/ptodsl/tests/test_ptoas_frontend_verify.py b/ptodsl/tests/test_ptoas_frontend_verify.py index 565109646..8333f40dc 100644 --- a/ptodsl/tests/test_ptoas_frontend_verify.py +++ b/ptodsl/tests/test_ptoas_frontend_verify.py @@ -19,6 +19,7 @@ sys.path.insert(0, str(REPO_ROOT / "ptodsl")) from ptodsl import pto +from ptodsl import scalar def expect(condition: bool, message: str) -> None: @@ -83,6 +84,22 @@ def host_vec_copy( pto.tile.store(o_tile, out) +@pto.simt +def simt_gm_memory_core_body(gm: pto.ptr(pto.i32, "gm")): + tx = pto.get_tid_x() + src_idx = scalar.index_cast(tx) + 
loaded = scalar.load(gm, src_idx) + with_bias = loaded + tx + 1000 + scalar.store(with_bias, gm, scalar.index_cast(tx + 32)) + scalar.store(tx, gm, scalar.index_cast(tx + 64)) + + +@pto.jit(target="a5", mode="explicit") +def simt_gm_memory_core_kernel(gm: pto.ptr(pto.i32, "gm")): + simt_gm_memory_core_body[32, 1, 1](gm) + pto.pipe_barrier(pto.Pipe.ALL) + + def main() -> None: ptoas_bin = resolve_ptoas_binary() @@ -100,6 +117,30 @@ def main() -> None: "pto.tload" in simple_frontend_text and "pto.tstore" in simple_frontend_text, "host_vec_copy frontend verification output should keep the tile IO contract visible", ) + + simt_gm_memory_text = simt_gm_memory_core_kernel.compile().mlir_text() + simt_frontend_text = run_ptoas_frontend_verify( + ptoas_bin, + simt_gm_memory_text, + "simt_gm_memory_core PTODSL artifact", + ) + expect( + "func.func @simt_gm_memory_core_kernel" in simt_frontend_text, + "simt_gm_memory_core frontend output should preserve the kernel symbol", + ) + expect( + "pto.simt_launch @simt_gm_memory_core_body__simt_" in simt_frontend_text, + "simt_gm_memory_core frontend output should preserve the SIMT launch", + ) + expect( + "pto.get_tid_x" in simt_frontend_text, + "simt_gm_memory_core frontend output should preserve SIMT thread queries", + ) + expect( + "pto.load" in simt_frontend_text and simt_frontend_text.count("pto.store") >= 2, + "simt_gm_memory_core frontend output should preserve GM load/store operations", + ) + print("ptodsl_ptoas_frontend_verify: PASS") diff --git a/test/dsl-st/simt_gm_memory_core.py b/test/dsl-st/simt_gm_memory_core.py new file mode 100644 index 000000000..c19b8b2b7 --- /dev/null +++ b/test/dsl-st/simt_gm_memory_core.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). 
+# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +import numpy as np + +if __package__ in {None, ""}: + sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from common import auto_main, golden_output_case +from ptodsl import pto +from ptodsl import scalar + + +LANES = 32 + + +@pto.simt +def simt_gm_memory_core_body( + inp: pto.ptr(pto.i32, "gm"), + out: pto.ptr(pto.i32, "gm"), +): + tid = pto.get_tid_x() + idx = scalar.index_cast(tid) + loaded = scalar.load(inp, idx) + scalar.store(loaded + tid + 1000, out, idx) + scalar.store(tid, out, scalar.index_cast(tid + LANES)) + + +@pto.jit( + name="simt_gm_memory_core_kernel", + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, +) +def simt_gm_memory_core_kernel( + inp_ptr: pto.ptr(pto.i32, "gm"), + out_ptr: pto.ptr(pto.i32, "gm"), +): + simt_gm_memory_core_body[LANES, 1, 1](inp_ptr, out_ptr) + pto.pipe_barrier(pto.Pipe.ALL) + + +def make_inputs(): + return [(np.arange(LANES, dtype=np.int32) * 3) - 17] + + +def make_expected(inp): + tid = np.arange(LANES, dtype=np.int32) + return np.concatenate([inp + tid + 1000, tid]) + + +CASES = [ + golden_output_case( + "simt_gm_memory_core", + simt_gm_memory_core_kernel, + inputs=make_inputs, + expected=make_expected, + rtol=0.0, + atol=0.0, + ), +] + + +auto_main(globals())