Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/designs/ptodsl-ast-preprocess-control-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Compile-time control flow stays explicit:

```python
@pto.jit(target="a5")
def kernel(*, BLOCK: pto.constexpr = 128):
def kernel(*, BLOCK: pto.const_expr = 128):
if pto.const_expr(BLOCK == 128):
for stage in pto.static_range(4):
...
Expand Down
4 changes: 2 additions & 2 deletions ptodsl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def Softmax(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
x_view = pto.make_tensor_view(X_ptr, shape=[rows, cols], strides=[cols, 1])
o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1])
Expand All @@ -288,7 +288,7 @@ PTODSL v1 keeps the public `@pto.jit` entry ABI intentionally narrow:
- positional runtime scalars use PTO scalar annotations such as `pto.i32`,
`pto.f32`, and `pto.i1`, while launch-time values remain ordinary Python
scalars
- keyword-only parameters annotated with `pto.constexpr` are compile-time
- keyword-only parameters annotated with `pto.const_expr` are compile-time
specialization knobs

The host wrapper is responsible for extracting or deriving whatever runtime
Expand Down
8 changes: 4 additions & 4 deletions ptodsl/docs/user_guide/01-introduction.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This
- **Caching**: compiled kernels are cached by specialization key (function identity + entry annotation signature + constexpr parameter values), so repeated calls with the same configuration skip recompilation.
- **Launch binding**: the compiled kernel can be invoked with a grid and stream — `compiled[grid, stream](args...)` — which launches the executable on the NPU with the given SPMD grid.

The public `@pto.jit` entry contract is pointer-first. Device buffers are explicit GM pointers (`pto.ptr(..., "gm")`), launch-varying shape/stride metadata travels as runtime scalars, and the kernel body materializes `TensorView` descriptors with `make_tensor_view(ptr, shape=..., strides=...)`. Compile-time constants remain keyword-only `pto.constexpr` parameters:
The public `@pto.jit` entry contract is pointer-first. Device buffers are explicit GM pointers (`pto.ptr(..., "gm")`), launch-varying shape/stride metadata travels as runtime scalars, and the kernel body materializes `TensorView` descriptors with `make_tensor_view(ptr, shape=..., strides=...)`. Compile-time constants remain keyword-only `pto.const_expr` parameters:

<!-- ptodsl-doc-test: {"mode":"compile","symbol":"flash_attention_kernel","compile":{"BLOCK_Q":128,"BLOCK_KV":128,"CAUSAL":false}} -->
```python
Expand All @@ -120,9 +120,9 @@ def flash_attention_kernel(
heads: pto.i32,
dim: pto.i32,
*,
BLOCK_Q: pto.constexpr = 128,
BLOCK_KV: pto.constexpr = 128,
CAUSAL: pto.constexpr = False,
BLOCK_Q: pto.const_expr = 128,
BLOCK_KV: pto.const_expr = 128,
CAUSAL: pto.const_expr = False,
):
q_view = pto.make_tensor_view(
Q_ptr,
Expand Down
10 changes: 5 additions & 5 deletions ptodsl/docs/user_guide/02-quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def tile_copy(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
"""Copy one 2D tensor tile from A to O."""

Expand All @@ -43,10 +43,10 @@ Let us step through each piece.

```python
@pto.jit(target="a5")
def tile_copy(A, O, *, BLOCK: pto.constexpr = 128):
def tile_copy(A, O, *, BLOCK: pto.const_expr = 128):
```

`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A_ptr` and `O_ptr` are explicit GM pointers, while `rows` and `cols` are runtime scalar metadata passed at launch time. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each tile width.
`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A_ptr` and `O_ptr` are explicit GM pointers, while `rows` and `cols` are runtime scalar metadata passed at launch time. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.const_expr`; the compiler specializes the kernel for each tile width.

### Describing GM tensors

Expand Down Expand Up @@ -114,7 +114,7 @@ def blocked_copy(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1])
o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1])
Expand Down Expand Up @@ -213,7 +213,7 @@ def vec_add_micro(
O_ptr: pto.ptr(pto.f32, "gm"),
N: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[N], strides=[1])
b_view = pto.make_tensor_view(B_ptr, shape=[N], strides=[1])
Expand Down
26 changes: 13 additions & 13 deletions ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def kernel_name(
rows: pto.i32, # runtime metadata (positional)
cols: pto.i32, # runtime metadata (positional)
*,
CONST_A: pto.constexpr = 128, # compile-time constant (keyword-only)
CONST_B: pto.constexpr = 64, # compile-time constant (keyword-only)
CONST_A: pto.const_expr = 128, # compile-time constant (keyword-only)
CONST_B: pto.const_expr = 64, # compile-time constant (keyword-only)
):
x_view = pto.make_tensor_view(x_ptr, shape=[rows, cols], strides=[cols, 1])
y_view = pto.make_tensor_view(y_ptr, shape=[rows, cols], strides=[cols, 1])
Expand All @@ -54,7 +54,7 @@ position in the signature, and way to supply the value:
|---|---|---|---|
| **Device buffer** | positional (before `*`) | `pto.ptr(dtype, "gm")` | launch time |
| **Runtime scalar** | positional (before `*`) | `pto.i32`, `pto.f32`, `pto.i1`, etc. | launch time |
| **Compile-time constant** | keyword-only (after `*`) | `pto.constexpr = <default>` | compile time |
| **Compile-time constant** | keyword-only (after `*`) | `pto.const_expr = <default>` | compile time |

#### 1. Device-buffer parameters

Expand Down Expand Up @@ -95,15 +95,15 @@ def my_kernel(

#### 3. Compile-time constants

Declare after `*` with `pto.constexpr` and a default value.
Declare after `*` with `pto.const_expr` and a default value.
Pass the value to `.compile(...)` — **not** at launch time:

```python
@pto.jit(target="a5")
def my_kernel(
X_ptr: pto.ptr(pto.f32, "gm"),
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
# BLOCK is a Python value at trace time — use it for tile shapes,
# unrolled loops, or dtype arguments:
Expand All @@ -128,7 +128,7 @@ def scaled_bias_add(
alpha: pto.f32, # runtime scalar
bias: pto.f32, # runtime scalar
*,
BLOCK: pto.constexpr = 128, # compile-time constant
BLOCK: pto.const_expr = 128, # compile-time constant
):
x_view = pto.make_tensor_view(X_ptr, shape=[rows, cols], strides=[cols, 1])
o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1])
Expand Down Expand Up @@ -225,7 +225,7 @@ def my_kernel(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1])
b_view = pto.make_tensor_view(B_ptr, shape=[rows, cols], strides=[cols, 1])
Expand Down Expand Up @@ -283,7 +283,7 @@ def my_kernel(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1])
b_view = pto.make_tensor_view(B_ptr, shape=[rows, cols], strides=[cols, 1])
Expand Down Expand Up @@ -655,9 +655,9 @@ pointers:
| `@pto.simd` → caller | Only via `vsts`/`psts` to UB tiles; `vreg` cannot escape |
| Cube-local → UB | Only via `mte_l0c_ub`; LEFT/RIGHT/ACC/BIAS are private |

## 3.6 `pto.constexpr`
## 3.6 `pto.const_expr`

`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time
`pto.const_expr` marks a `@pto.jit` keyword-only parameter as a compile-time
constant. The compiler specializes the kernel for each combination of constexpr
values, and the compiled artifact is cached by specialization key together with
the kernel's entry annotation contract.
Expand All @@ -668,8 +668,8 @@ the kernel's entry annotation contract.
def kernel(
A_ptr: pto.ptr(pto.f32, "gm"),
*,
BLOCK: pto.constexpr = 128,
DTYPE: pto.constexpr = pto.f32,
BLOCK: pto.const_expr = 128,
DTYPE: pto.const_expr = pto.f32,
):
# ... use BLOCK / DTYPE in tile shapes, loop bounds, or dtype-specialized paths ...
return
Expand All @@ -682,7 +682,7 @@ def kernel(
- Cannot change between launches of the same compiled instance — compile a new
variant for a different value.

`pto.constexpr` parameters can be used anywhere in the kernel body where a
`pto.const_expr` parameters can be used anywhere in the kernel body where a
Python value is expected: tile shapes, loop bounds that are known at compile
time, dtype arguments, etc. They are evaluated at trace time, so `for i in
range(BLOCK)` would unroll `BLOCK` times.
Expand Down
2 changes: 1 addition & 1 deletion ptodsl/docs/user_guide/04-type-system-and-buffer.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def kernel(
rows: pto.i32,
cols: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
tv = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1])
return
Expand Down
20 changes: 10 additions & 10 deletions ptodsl/docs/user_guide/05-control-flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ sequence with the body repeated:

```python
@pto.jit(target="a5")
def unrolled_kernel(A, O, *, N: pto.constexpr):
def unrolled_kernel(A, O, *, N: pto.const_expr):
a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides)
o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides)

Expand Down Expand Up @@ -85,7 +85,7 @@ When a loop needs to propagate state from one iteration to the next, use the `.c
<!-- ptodsl-doc-test: {"mode":"compile","symbol":"carry_loop_probe","compile":{"BLOCK":128}} -->
```python
@pto.jit(target="a5")
def carry_loop_probe(*, BLOCK: pto.constexpr = 128):
def carry_loop_probe(*, BLOCK: pto.const_expr = 128):
m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32)
l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32)
o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32)
Expand Down Expand Up @@ -212,18 +212,18 @@ conditional closes, `br.val` is the SSA-merged result seen by downstream code.
This surface avoids explicit result-type declarations and explicit
`pto.yield_(...)` in user code while still keeping the merge contract explicit.

## 5.4 `pto.constexpr` and tracing
## 5.4 `pto.const_expr` and tracing

`pto.constexpr` parameters (Section 3.8) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow:
`pto.const_expr` parameters (Section 3.6) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow:

```python
@pto.jit(target="a5")
def kernel(
A,
*,
BLOCK: pto.constexpr = 128,
NUM_BLOCKS: pto.constexpr = 8,
UNROLL: pto.constexpr = False,
BLOCK: pto.const_expr = 128,
NUM_BLOCKS: pto.const_expr = 8,
UNROLL: pto.const_expr = False,
):
N = A.shape[0]
num_blocks = (N + BLOCK - 1) // BLOCK
Expand Down Expand Up @@ -370,7 +370,7 @@ Use `pto.const_expr(...)` for trace-time branches:
<!-- ptodsl-doc-test: {"mode":"compile","symbol":"ast_rewrite_static_branch_kernel","compile":{"ENABLE":true}} -->
```python
@pto.jit(target="a5")
def ast_rewrite_static_branch_kernel(*, ENABLE: pto.constexpr = True):
def ast_rewrite_static_branch_kernel(*, ENABLE: pto.const_expr = True):
if pto.const_expr(ENABLE):
pto.pipe_barrier(pto.Pipe.ALL)
```
Expand Down Expand Up @@ -460,7 +460,7 @@ It is not the recommended user-facing mode for new examples or kernels.

```python
@pto.jit(target="a5", ast_rewrite=False)
def debug_kernel(*, BLOCK: pto.constexpr = 4):
def debug_kernel(*, BLOCK: pto.const_expr = 4):
for _ in range(BLOCK):
pto.pipe_barrier(pto.Pipe.ALL)

Expand All @@ -480,7 +480,7 @@ The same switch is available through `frontend_options`:
"ast_rewrite": False,
},
)
def debug_options_disable_rewrite_kernel(*, BLOCK: pto.constexpr = 4):
def debug_options_disable_rewrite_kernel(*, BLOCK: pto.const_expr = 4):
for _ in range(BLOCK):
pto.pipe_barrier(pto.Pipe.ALL)
```
Expand Down
4 changes: 2 additions & 2 deletions ptodsl/docs/user_guide/07-data-movement-ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ def cube_producer(
gm_slot_buffer: pto.gm_ptr(pto.f32),
src: pto.gm_ptr(pto.f32),
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
gm_view = pto.make_tensor_view(gm_slot_buffer, shape=[16, 16], strides=[16, 1])
c2v = pto.pipe.c2v(
Expand Down Expand Up @@ -1371,7 +1371,7 @@ def vector_consumer(
gm_slot_buffer: pto.gm_ptr(pto.f32),
dst: pto.gm_ptr(pto.f32),
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
gm_view = pto.make_tensor_view(gm_slot_buffer, shape=[16, 16], strides=[16, 1])
c2v = pto.pipe.c2v(
Expand Down
8 changes: 4 additions & 4 deletions ptodsl/docs/user_guide/11-flash-attention-walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def flash_attention_kernel(
heads: pto.i32,
dim: pto.i32,
*,
BLOCK_Q: pto.constexpr = 128,
BLOCK_KV: pto.constexpr = 128,
CAUSAL: pto.constexpr = False,
NUM_STAGES: pto.constexpr = 2,
BLOCK_Q: pto.const_expr = 128,
BLOCK_KV: pto.const_expr = 128,
CAUSAL: pto.const_expr = False,
NUM_STAGES: pto.const_expr = 2,
):
# Walkthrough body omitted in this signature overview.
return
Expand Down
14 changes: 7 additions & 7 deletions ptodsl/docs/user_guide/12-additional-examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def mat_add(
M: pto.i32,
N_: pto.i32,
*,
BLOCK_M: pto.constexpr = 64,
BLOCK_N: pto.constexpr = 128,
BLOCK_M: pto.const_expr = 64,
BLOCK_N: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[batch, M, N_], strides=[M * N_, N_, 1])
b_view = pto.make_tensor_view(B_ptr, shape=[batch, M, N_], strides=[M * N_, N_, 1])
Expand Down Expand Up @@ -114,7 +114,7 @@ def vec_add_with_tail(
O_ptr: pto.ptr(pto.f32, "gm"),
N: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
a_view = pto.make_tensor_view(A_ptr, shape=[N], strides=[1])
b_view = pto.make_tensor_view(B_ptr, shape=[N], strides=[1])
Expand Down Expand Up @@ -192,9 +192,9 @@ def gemm(
K_: pto.i32,
N_: pto.i32,
*,
BLOCK_M: pto.constexpr = 64,
BLOCK_K: pto.constexpr = 64,
BLOCK_N: pto.constexpr = 64,
BLOCK_M: pto.const_expr = 64,
BLOCK_K: pto.const_expr = 64,
BLOCK_N: pto.const_expr = 64,
):
a_view = pto.make_tensor_view(A_ptr, shape=[M, K_], strides=[K_, 1])
b_view = pto.make_tensor_view(B_ptr, shape=[K_, N_], strides=[N_, 1])
Expand Down Expand Up @@ -296,7 +296,7 @@ def online_layernorm(
O_ptr: pto.ptr(pto.f32, "gm"),
N: pto.i32,
*,
BLOCK: pto.constexpr = 128,
BLOCK: pto.const_expr = 128,
):
x_view = pto.make_tensor_view(X_ptr, shape=[N], strides=[1])
o_view = pto.make_tensor_view(O_ptr, shape=[N], strides=[1])
Expand Down
10 changes: 5 additions & 5 deletions ptodsl/examples/flash_attention_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def flash_attention_kernel(
heads: pto.i32,
dim: pto.i32,
*,
BLOCK_Q: pto.constexpr = 128,
BLOCK_KV: pto.constexpr = 128,
HEAD_DIM: pto.constexpr = 128,
CAUSAL: pto.constexpr = False,
NUM_STAGES: pto.constexpr = 2,
BLOCK_Q: pto.const_expr = 128,
BLOCK_KV: pto.const_expr = 128,
HEAD_DIM: pto.const_expr = 128,
CAUSAL: pto.const_expr = False,
NUM_STAGES: pto.const_expr = 2,
):
"""
Launchable device entry.
Expand Down
7 changes: 1 addition & 6 deletions ptodsl/ptodsl/_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ._bootstrap import make_context # noqa: F401
from ._runtime_index_ops import coerce_runtime_index
from ._surface_types import const_expr

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import of const_expr is unused in this file. Since const_expr is now defined in _surface_types.py and imported directly by pto.py, it is no longer needed here and can be safely removed.

from ._tracing.active import current_session
from ._surface_values import unwrap_surface_value, wrap_like_surface_value, wrap_surface_value

Expand Down Expand Up @@ -57,12 +58,6 @@ def static_range(*args):
"""Return ``range(*args)`` for trace-time unrolling under AST rewrite."""
return range(*args)


def const_expr(value):
    """Return Python truthiness for trace-time branches under AST rewrite.

    The argument is evaluated with ordinary Python truth testing and the
    outcome is always handed back as a plain ``bool`` (``True`` or ``False``),
    never as the original object.
    """
    # Explicit branch instead of ``bool(...)``: same truth test, same result.
    if value:
        return True
    return False


# ── for_ ──────────────────────────────────────────────────────────────────────

class LoopHandle:
Expand Down
Loading
Loading