From c52114f6807711d5472878455aa15956a53d322a Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 11 Jun 2026 12:28:38 +0000 Subject: [PATCH] ptodsl: remove pto.constexpr in favor of pto.const_expr --- .../ptodsl-ast-preprocess-control-flow.md | 2 +- ptodsl/README.md | 4 +- ptodsl/docs/user_guide/01-introduction.md | 8 +- ptodsl/docs/user_guide/02-quick-start.md | 10 +- .../03-kernel-entry-and-subkernels.md | 26 +-- .../user_guide/04-type-system-and-buffer.md | 2 +- ptodsl/docs/user_guide/05-control-flow.md | 20 +- .../docs/user_guide/07-data-movement-ops.md | 4 +- .../11-flash-attention-walkthrough.md | 8 +- .../docs/user_guide/12-additional-examples.md | 14 +- ptodsl/examples/flash_attention_sketch.py | 10 +- ptodsl/ptodsl/_control_flow.py | 7 +- ptodsl/ptodsl/_diagnostics.py | 15 +- ptodsl/ptodsl/_kernel_signature.py | 4 +- ptodsl/ptodsl/_surface_types.py | 13 +- ptodsl/ptodsl/pto.py | 6 +- .../tests/support/docs_fragment_fixtures.py | 196 +++++++++--------- ptodsl/tests/test_ast_rewrite_example_ir.py | 10 +- ptodsl/tests/test_jit_compile.py | 51 +++-- ptodsl/tests/test_jit_diagnostics.py | 10 +- ptodsl/tests/test_ptoas_frontend_verify.py | 2 +- ptodsl/tests/test_subkernel_diagnostics.py | 6 +- 22 files changed, 218 insertions(+), 210 deletions(-) diff --git a/docs/designs/ptodsl-ast-preprocess-control-flow.md b/docs/designs/ptodsl-ast-preprocess-control-flow.md index 8c2b94c03..c67a1fd4d 100644 --- a/docs/designs/ptodsl-ast-preprocess-control-flow.md +++ b/docs/designs/ptodsl-ast-preprocess-control-flow.md @@ -50,7 +50,7 @@ Compile-time control flow stays explicit: ```python @pto.jit(target="a5") -def kernel(*, BLOCK: pto.constexpr = 128): +def kernel(*, BLOCK: pto.const_expr = 128): if pto.const_expr(BLOCK == 128): for stage in pto.static_range(4): ... diff --git a/ptodsl/README.md b/ptodsl/README.md index b2f9ec16b..c27b010ed 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -262,7 +262,7 @@ def Softmax( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): x_view = pto.make_tensor_view(X_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -288,7 +288,7 @@ PTODSL v1 keeps the public `@pto.jit` entry ABI intentionally narrow: - positional runtime scalars use PTO scalar annotations such as `pto.i32`, `pto.f32`, and `pto.i1`, while launch-time values remain ordinary Python scalars -- keyword-only parameters annotated with `pto.constexpr` are compile-time +- keyword-only parameters annotated with `pto.const_expr` are compile-time specialization knobs The host wrapper is responsible for extracting or deriving whatever runtime diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md index 773cb544a..7a6e6c3ee 100644 --- a/ptodsl/docs/user_guide/01-introduction.md +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -101,7 +101,7 @@ Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This - **Caching**: compiled kernels are cached by specialization key (function identity + entry annotation signature + constexpr parameter values), so repeated calls with the same configuration skip recompilation. - **Launch binding**: the compiled kernel can be invoked with a grid and stream — `compiled[grid, stream](args...)` — which launches the executable on the NPU with the given SPMD grid. -The public `@pto.jit` entry contract is pointer-first. Device buffers are explicit GM pointers (`pto.ptr(..., "gm")`), launch-varying shape/stride metadata travels as runtime scalars, and the kernel body materializes `TensorView` descriptors with `make_tensor_view(ptr, shape=..., strides=...)`. Compile-time constants remain keyword-only `pto.constexpr` parameters: +The public `@pto.jit` entry contract is pointer-first. Device buffers are explicit GM pointers (`pto.ptr(..., "gm")`), launch-varying shape/stride metadata travels as runtime scalars, and the kernel body materializes `TensorView` descriptors with `make_tensor_view(ptr, shape=..., strides=...)`. Compile-time constants remain keyword-only `pto.const_expr` parameters: ```python @@ -120,9 +120,9 @@ def flash_attention_kernel( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, ): q_view = pto.make_tensor_view( Q_ptr, diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md index 4ffce0ea2..67afbd6bc 100644 --- a/ptodsl/docs/user_guide/02-quick-start.md +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -16,7 +16,7 @@ def tile_copy( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): """Copy one 2D tensor tile from A to O.""" @@ -43,10 +43,10 @@ Let us step through each piece. ```python @pto.jit(target="a5") -def tile_copy(A, O, *, BLOCK: pto.constexpr = 128): +def tile_copy(A, O, *, BLOCK: pto.const_expr = 128): ``` -`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A_ptr` and `O_ptr` are explicit GM pointers, while `rows` and `cols` are runtime scalar metadata passed at launch time. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each tile width. +`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A_ptr` and `O_ptr` are explicit GM pointers, while `rows` and `cols` are runtime scalar metadata passed at launch time. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.const_expr`; the compiler specializes the kernel for each tile width. ### Describing GM tensors @@ -114,7 +114,7 @@ def blocked_copy( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -213,7 +213,7 @@ def vec_add_micro( O_ptr: pto.ptr(pto.f32, "gm"), N: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[N], strides=[1]) b_view = pto.make_tensor_view(B_ptr, shape=[N], strides=[1]) diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 3745ee94c..67b3d51c8 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -36,8 +36,8 @@ def kernel_name( rows: pto.i32, # runtime metadata (positional) cols: pto.i32, # runtime metadata (positional) *, - CONST_A: pto.constexpr = 128, # compile-time constant (keyword-only) - CONST_B: pto.constexpr = 64, # compile-time constant (keyword-only) + CONST_A: pto.const_expr = 128, # compile-time constant (keyword-only) + CONST_B: pto.const_expr = 64, # compile-time constant (keyword-only) ): x_view = pto.make_tensor_view(x_ptr, shape=[rows, cols], strides=[cols, 1]) y_view = pto.make_tensor_view(y_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -54,7 +54,7 @@ position in the signature, and way to supply the value: |---|---|---|---| | **Device buffer** | positional (before `*`) | `pto.ptr(dtype, "gm")` | launch time | | **Runtime scalar** | positional (before `*`) | `pto.i32`, `pto.f32`, `pto.i1`, etc. | launch time | -| **Compile-time constant** | keyword-only (after `*`) | `pto.constexpr = ` | compile time | +| **Compile-time constant** | keyword-only (after `*`) | `pto.const_expr = ` | compile time | #### 1. Device-buffer parameters @@ -95,7 +95,7 @@ def my_kernel( #### 3. Compile-time constants -Declare after `*` with `pto.constexpr` and a default value. +Declare after `*` with `pto.const_expr` and a default value. Pass the value to `.compile(...)` — **not** at launch time: ```python @@ -103,7 +103,7 @@ Pass the value to `.compile(...)` — **not** at launch time: def my_kernel( X_ptr: pto.ptr(pto.f32, "gm"), *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): # BLOCK is a Python value at trace time — use it for tile shapes, # unrolled loops, or dtype arguments: @@ -128,7 +128,7 @@ def scaled_bias_add( alpha: pto.f32, # runtime scalar bias: pto.f32, # runtime scalar *, - BLOCK: pto.constexpr = 128, # compile-time constant + BLOCK: pto.const_expr = 128, # compile-time constant ): x_view = pto.make_tensor_view(X_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -225,7 +225,7 @@ def my_kernel( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) b_view = pto.make_tensor_view(B_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -283,7 +283,7 @@ def my_kernel( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) b_view = pto.make_tensor_view(B_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -655,9 +655,9 @@ pointers: | `@pto.simd` → caller | Only via `vsts`/`psts` to UB tiles; `vreg` cannot escape | | Cube-local → UB | Only via `mte_l0c_ub`; LEFT/RIGHT/ACC/BIAS are private | -## 3.6 `pto.constexpr` +## 3.6 `pto.const_expr` -`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time +`pto.const_expr` marks a `@pto.jit` keyword-only parameter as a compile-time constant. The compiler specializes the kernel for each combination of constexpr values, and the compiled artifact is cached by specialization key together with the kernel's entry annotation contract. @@ -668,8 +668,8 @@ the kernel's entry annotation contract. def kernel( A_ptr: pto.ptr(pto.f32, "gm"), *, - BLOCK: pto.constexpr = 128, - DTYPE: pto.constexpr = pto.f32, + BLOCK: pto.const_expr = 128, + DTYPE: pto.const_expr = pto.f32, ): # ... use BLOCK / DTYPE in tile shapes, loop bounds, or dtype-specialized paths ... return @@ -682,7 +682,7 @@ def kernel( - Cannot change between launches of the same compiled instance — compile a new variant for a different value. -`pto.constexpr` parameters can be used anywhere in the kernel body where a +`pto.const_expr` parameters can be used anywhere in the kernel body where a Python value is expected: tile shapes, loop bounds that are known at compile time, dtype arguments, etc. They are evaluated at trace time, so `for i in range(BLOCK)` would unroll `BLOCK` times. diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md index dc05e6ef9..7d8cecf51 100644 --- a/ptodsl/docs/user_guide/04-type-system-and-buffer.md +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -183,7 +183,7 @@ def kernel( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tv = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) return diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md index bacef136c..bb88d87e8 100644 --- a/ptodsl/docs/user_guide/05-control-flow.md +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -18,7 +18,7 @@ sequence with the body repeated: ```python @pto.jit(target="a5") -def unrolled_kernel(A, O, *, N: pto.constexpr): +def unrolled_kernel(A, O, *, N: pto.const_expr): a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) @@ -85,7 +85,7 @@ When a loop needs to propagate state from one iteration to the next, use the `.c ```python @pto.jit(target="a5") -def carry_loop_probe(*, BLOCK: pto.constexpr = 128): +def carry_loop_probe(*, BLOCK: pto.const_expr = 128): m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -212,18 +212,18 @@ conditional closes, `br.val` is the SSA-merged result seen by downstream code. This surface avoids explicit result-type declarations and explicit `pto.yield_(...)` in user code while still keeping the merge contract explicit. -## 5.4 `pto.constexpr` and tracing +## 5.4 `pto.const_expr` and tracing -`pto.constexpr` parameters (Section 3.8) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow: +`pto.const_expr` parameters (Section 3.8) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow: ```python @pto.jit(target="a5") def kernel( A, *, - BLOCK: pto.constexpr = 128, - NUM_BLOCKS: pto.constexpr = 8, - UNROLL: pto.constexpr = False, + BLOCK: pto.const_expr = 128, + NUM_BLOCKS: pto.const_expr = 8, + UNROLL: pto.const_expr = False, ): N = A.shape[0] num_blocks = (N + BLOCK - 1) // BLOCK @@ -370,7 +370,7 @@ Use `pto.const_expr(...)` for trace-time branches: ```python @pto.jit(target="a5") -def ast_rewrite_static_branch_kernel(*, ENABLE: pto.constexpr = True): +def ast_rewrite_static_branch_kernel(*, ENABLE: pto.const_expr = True): if pto.const_expr(ENABLE): pto.pipe_barrier(pto.Pipe.ALL) ``` @@ -460,7 +460,7 @@ It is not the recommended user-facing mode for new examples or kernels. ```python @pto.jit(target="a5", ast_rewrite=False) -def debug_kernel(*, BLOCK: pto.constexpr = 4): +def debug_kernel(*, BLOCK: pto.const_expr = 4): for _ in range(BLOCK): pto.pipe_barrier(pto.Pipe.ALL) @@ -480,7 +480,7 @@ The same switch is available through `frontend_options`: "ast_rewrite": False, }, ) -def debug_options_disable_rewrite_kernel(*, BLOCK: pto.constexpr = 4): +def debug_options_disable_rewrite_kernel(*, BLOCK: pto.const_expr = 4): for _ in range(BLOCK): pto.pipe_barrier(pto.Pipe.ALL) ``` diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md index eeb0dae59..7bcdcae97 100644 --- a/ptodsl/docs/user_guide/07-data-movement-ops.md +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -1342,7 +1342,7 @@ def cube_producer( gm_slot_buffer: pto.gm_ptr(pto.f32), src: pto.gm_ptr(pto.f32), *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): gm_view = pto.make_tensor_view(gm_slot_buffer, shape=[16, 16], strides=[16, 1]) c2v = pto.pipe.c2v( @@ -1371,7 +1371,7 @@ def vector_consumer( gm_slot_buffer: pto.gm_ptr(pto.f32), dst: pto.gm_ptr(pto.f32), *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): gm_view = pto.make_tensor_view(gm_slot_buffer, shape=[16, 16], strides=[16, 1]) c2v = pto.pipe.c2v( diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md index ab5d5bd41..2aaec14f0 100644 --- a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -75,10 +75,10 @@ def flash_attention_kernel( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): # Walkthrough body omitted in this signature overview. return diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md index 51be697ce..217019b34 100644 --- a/ptodsl/docs/user_guide/12-additional-examples.md +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -16,8 +16,8 @@ def mat_add( M: pto.i32, N_: pto.i32, *, - BLOCK_M: pto.constexpr = 64, - BLOCK_N: pto.constexpr = 128, + BLOCK_M: pto.const_expr = 64, + BLOCK_N: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[batch, M, N_], strides=[M * N_, N_, 1]) b_view = pto.make_tensor_view(B_ptr, shape=[batch, M, N_], strides=[M * N_, N_, 1]) @@ -114,7 +114,7 @@ def vec_add_with_tail( O_ptr: pto.ptr(pto.f32, "gm"), N: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[N], strides=[1]) b_view = pto.make_tensor_view(B_ptr, shape=[N], strides=[1]) @@ -192,9 +192,9 @@ def gemm( K_: pto.i32, N_: pto.i32, *, - BLOCK_M: pto.constexpr = 64, - BLOCK_K: pto.constexpr = 64, - BLOCK_N: pto.constexpr = 64, + BLOCK_M: pto.const_expr = 64, + BLOCK_K: pto.const_expr = 64, + BLOCK_N: pto.const_expr = 64, ): a_view = pto.make_tensor_view(A_ptr, shape=[M, K_], strides=[K_, 1]) b_view = pto.make_tensor_view(B_ptr, shape=[K_, N_], strides=[N_, 1]) @@ -296,7 +296,7 @@ def online_layernorm( O_ptr: pto.ptr(pto.f32, "gm"), N: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): x_view = pto.make_tensor_view(X_ptr, shape=[N], strides=[1]) o_view = pto.make_tensor_view(O_ptr, shape=[N], strides=[1]) diff --git a/ptodsl/examples/flash_attention_sketch.py b/ptodsl/examples/flash_attention_sketch.py index 18d7a809c..04d037a97 100644 --- a/ptodsl/examples/flash_attention_sketch.py +++ b/ptodsl/examples/flash_attention_sketch.py @@ -147,11 +147,11 @@ def flash_attention_kernel( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - HEAD_DIM: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + HEAD_DIM: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): """ Launchable device entry. diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py index 41d149d01..8ff6acef1 100644 --- a/ptodsl/ptodsl/_control_flow.py +++ b/ptodsl/ptodsl/_control_flow.py @@ -23,6 +23,7 @@ from ._bootstrap import make_context # noqa: F401 from ._runtime_index_ops import coerce_runtime_index +from ._surface_types import const_expr from ._tracing.active import current_session from ._surface_values import unwrap_surface_value, wrap_like_surface_value, wrap_surface_value @@ -57,12 +58,6 @@ def static_range(*args): """Return ``range(*args)`` for trace-time unrolling under AST rewrite.""" return range(*args) - -def const_expr(value): - """Return Python truthiness for trace-time branches under AST rewrite.""" - return bool(value) - - # ── for_ ────────────────────────────────────────────────────────────────────── class LoopHandle: diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index 219cd4e88..98217621d 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -34,7 +34,7 @@ def native_python_control_flow_error(usage: str) -> PTODSLTracingMisuseError: f"native Python {usage} cannot consume a PTODSL runtime value during tracing. " "This value is a device-side SSA/runtime-metadata value, not a Python bool/int. " "Use pto.if_(...) or pto.for_(...) for device-side control flow, or keep the " - "bound/condition in pto.constexpr." + "bound/condition in pto.const_expr." ) @@ -52,7 +52,7 @@ def jit_missing_annotation_error(name: str) -> TypeError: f"@pto.jit positional parameter '{name}' does not declare an entry ABI annotation. " 'Use an explicit GM pointer such as pto.ptr(pto.f32, "gm") for device buffers, ' "a PTO scalar type such as pto.i32/pto.f32/pto.i1 for runtime scalars, " - "or move compile-time values to keyword-only pto.constexpr parameters." + "or move compile-time values to keyword-only pto.const_expr parameters." ) @@ -62,7 +62,7 @@ def jit_illegal_formal_annotation_error(name: str, annotation: object) -> TypeEr f"@pto.jit positional parameter '{name}' uses unsupported entry annotation {annotation!r}. " 'The public @pto.jit entry ABI accepts explicit GM pointers such as pto.ptr(pto.f32, "gm"), ' "PTO scalar annotations such as pto.i32/pto.f32/pto.i1 for runtime scalars, " - "and keyword-only pto.constexpr compile-time parameters. " + "and keyword-only pto.const_expr compile-time parameters. " "Legacy host tensor annotations such as pto.tensor_spec(...), and low-level PTODSL " "types such as Tile, PartitionTensorView, VReg, or non-entry pointer forms do not " "belong at the host/kernel entry." @@ -91,16 +91,16 @@ def jit_non_gm_ptr_entry_error(name: str, annotation: object) -> TypeError: def jit_keyword_only_non_constexpr_error(name: str, annotation: object) -> TypeError: - """Return one diagnostic for keyword-only params that are not ``pto.constexpr``.""" + """Return one diagnostic for keyword-only params that are not ``pto.const_expr``.""" return TypeError( f"@pto.jit keyword-only parameter '{name}' uses unsupported compile-time annotation {annotation!r}. " - "Compile-time @pto.jit parameters must remain keyword-only pto.constexpr values in this change; " + "Compile-time @pto.jit parameters must remain keyword-only pto.const_expr values in this change; " "move runtime data to positional pointer/scalar parameters instead." ) def jit_constexpr_missing_default_error(name: str) -> TypeError: - """Return one diagnostic for ``pto.constexpr`` params missing a default value.""" + """Return one diagnostic for ``pto.const_expr`` params missing a default value.""" return TypeError( f"@pto.jit constexpr parameter '{name}' must declare a default value until explicit " "compile-time specialization is implemented. Keep this parameter keyword-only and " @@ -239,6 +239,9 @@ def unsupported_public_surface_error(name: str) -> AttributeError: "vsts_1pt": ( 'Use pto.vsts(vec, ptr, offset, mask, dist="1PT_B32") instead of the removed pto.vsts_1pt(...) helper.' ), + "constexpr": ( + "Use pto.const_expr for compile-time @pto.jit parameters and trace-time control-flow guards." + ), } suffix = hints.get(name, "Use the documented PTODSL public surface instead.") return AttributeError( diff --git a/ptodsl/ptodsl/_kernel_signature.py b/ptodsl/ptodsl/_kernel_signature.py index e53ed4053..c57c3cf94 100644 --- a/ptodsl/ptodsl/_kernel_signature.py +++ b/ptodsl/ptodsl/_kernel_signature.py @@ -22,7 +22,7 @@ ) from ._host_tensors import TensorSpec from ._surface_values import wrap_surface_value -from ._surface_types import constexpr as _constexpr_marker +from ._surface_types import const_expr as _const_expr_marker from ._types import _DType, _MaskDescriptor, _PtrDescriptor, _VRegDescriptor, _resolve @@ -206,7 +206,7 @@ def parse_jit_kernel_signature(py_fn) -> KernelSignature: continue if param.kind is inspect.Parameter.KEYWORD_ONLY: - if param.annotation is not _constexpr_marker: + if param.annotation is not _const_expr_marker: raise jit_keyword_only_non_constexpr_error(param.name, param.annotation) if param.default is inspect.Parameter.empty: raise jit_constexpr_missing_default_error(param.name) diff --git a/ptodsl/ptodsl/_surface_types.py b/ptodsl/ptodsl/_surface_types.py index 0a3d574a1..dba34d9fb 100644 --- a/ptodsl/ptodsl/_surface_types.py +++ b/ptodsl/ptodsl/_surface_types.py @@ -15,14 +15,17 @@ from mlir.dialects import pto as _pto -class _ConstexprMarker: - """Marker annotation for PTODSL compile-time specialization parameters.""" +class _ConstExprHelper: + """Callable marker for PTODSL compile-time specialization parameters and branches.""" def __repr__(self): - return "pto.constexpr" + return "pto.const_expr" + def __call__(self, value): + return bool(value) -constexpr = _ConstexprMarker() + +const_expr = _ConstExprHelper() class MemorySpace: @@ -284,7 +287,7 @@ class Tile: __all__ = [ - "constexpr", + "const_expr", "TensorSpec", "MemorySpace", "BarrierType", diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py index c02d7bb4f..66e693a90 100644 --- a/ptodsl/ptodsl/pto.py +++ b/ptodsl/ptodsl/pto.py @@ -34,7 +34,7 @@ _resolve, ) from ._surface_types import ( # noqa: F401 - constexpr, + const_expr, tensor_spec, TensorSpec, BarrierType, @@ -122,7 +122,7 @@ # ── Control flow ────────────────────────────────────────────────────────────── from ._control_flow import ( # noqa: F401 for_, if_, yield_, - const_expr, static_range, + static_range, LoopHandle, BranchHandle, ) @@ -149,6 +149,6 @@ def gm_ptr(elem): def __getattr__(name): - if name in {"ukernel", "tile_buf_type", "vecscope", "as_ptr", "vbrc_load", "vsts_1pt"}: + if name in {"ukernel", "tile_buf_type", "vecscope", "as_ptr", "vbrc_load", "vsts_1pt", "constexpr"}: raise unsupported_public_surface_error(name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ptodsl/tests/support/docs_fragment_fixtures.py b/ptodsl/tests/support/docs_fragment_fixtures.py index a2087adfe..e4a7edd07 100644 --- a/ptodsl/tests/support/docs_fragment_fixtures.py +++ b/ptodsl/tests/support/docs_fragment_fixtures.py @@ -66,7 +66,7 @@ def type_system_scalar_expr_probe(): @pto.jit(target="a5") def type_system_low_precision_types_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): {SNIPPET_PLACEHOLDER} """ @@ -77,7 +77,7 @@ def type_system_low_precision_types_probe( def type_system_tensor_view_probe( A: pto.tensor_spec(rank=2, dtype=pto.f32), *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): rows = A.shape[0] cols = A.shape[1] @@ -93,7 +93,7 @@ def type_system_partition_view_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): dim = cols row_offset = 0 @@ -106,10 +106,10 @@ def type_system_partition_view_probe( @pto.jit(target="a5") def type_system_tile_alloc_probe( *, - BLOCK: pto.constexpr = 128, - Br: pto.constexpr = 16, - Bc: pto.constexpr = 16, - dim: pto.constexpr = 16, + BLOCK: pto.const_expr = 128, + Br: pto.const_expr = 16, + Bc: pto.const_expr = 16, + dim: pto.const_expr = 16, ): {SNIPPET_PLACEHOLDER} """ @@ -119,9 +119,9 @@ def type_system_tile_alloc_probe( @pto.jit(target="a5") def type_system_tile_methods_probe( *, - Br: pto.constexpr = 16, - Bc: pto.constexpr = 16, - dim: pto.constexpr = 16, + Br: pto.const_expr = 16, + Bc: pto.const_expr = 16, + dim: pto.const_expr = 16, ): m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") @@ -137,8 +137,8 @@ def type_system_tile_methods_probe( @pto.jit(target="a5") def type_system_tile_reshape_probe( *, - BR: pto.constexpr = 8, - BC: pto.constexpr = 64, + BR: pto.const_expr = 8, + BC: pto.const_expr = 64, ): {SNIPPET_PLACEHOLDER} """ @@ -148,7 +148,7 @@ def type_system_tile_reshape_probe( @pto.jit(target="a5") def type_system_vreg_bitcast_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) row = 0 @@ -161,7 +161,7 @@ def type_system_vreg_bitcast_probe( @pto.jit(target="a5") def type_system_vreg_bitcast_ptr_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) ptr = tile.as_ptr() @@ -200,7 +200,7 @@ def quick_start_make_tensor_view_probe( @pto.jit(target="a5") def quick_start_alloc_tile_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): {SNIPPET_PLACEHOLDER} """ @@ -226,7 +226,7 @@ def quick_start_tile_io_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -266,9 +266,9 @@ def flash_attention_kernel( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, ): pto.get_block_idx() @@ -309,7 +309,7 @@ def blocked_copy( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -344,8 +344,8 @@ def kernel_name( rows: pto.i32, cols: pto.i32, *, - CONST_A: pto.constexpr = 128, - CONST_B: pto.constexpr = 64, + CONST_A: pto.const_expr = 128, + CONST_B: pto.const_expr = 64, ): pto.get_block_idx() @@ -378,8 +378,8 @@ def mat_add( rows: pto.i32, cols: pto.i32, *, - BLOCK_M: pto.constexpr = 64, - BLOCK_N: pto.constexpr = 128, + BLOCK_M: pto.const_expr = 64, + BLOCK_N: pto.const_expr = 128, ): pto.get_block_idx() @@ -413,9 +413,9 @@ def gemm( reduce_dim: pto.i32, cols: pto.i32, *, - BLOCK_M: pto.constexpr = 64, - BLOCK_K: pto.constexpr = 64, - BLOCK_N: pto.constexpr = 64, + BLOCK_M: pto.const_expr = 64, + BLOCK_K: pto.const_expr = 64, + BLOCK_N: pto.const_expr = 64, ): pto.get_block_idx() @@ -444,7 +444,7 @@ def control_flow_basic_for_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 8, + BLOCK: pto.const_expr = 8, ): start = pto.const(0, dtype=pto.i32) stop = pto.const(BLOCK, dtype=pto.i32) @@ -464,7 +464,7 @@ def control_flow_compare_loops_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 8, + BLOCK: pto.const_expr = 8, ): num_blocks = rows a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -480,7 +480,7 @@ def control_flow_nested_loops_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 8, + BLOCK: pto.const_expr = 8, ): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32, valid_shape=[rows, cols]) {SNIPPET_PLACEHOLDER} @@ -491,8 +491,8 @@ def control_flow_nested_loops_probe( @pto.jit(target="a5") def control_flow_carry_pingpong_probe( *, - Br: pto.constexpr = 16, - num_blocks: pto.constexpr = 4, + Br: pto.const_expr = 16, + num_blocks: pto.const_expr = 4, ): {SNIPPET_PLACEHOLDER} """ @@ -513,7 +513,7 @@ def scalar_ops_tile_access_probe(): "tail.chunked_inner_loop": _fixture( f""" @pto.jit(target="a5") - def tail_chunked_inner_loop_probe(*, BLOCK: pto.constexpr = 128): + def tail_chunked_inner_loop_probe(*, BLOCK: pto.const_expr = 128): cols = pto.const(BLOCK, dtype=pto.i32) tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) out_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -524,7 +524,7 @@ def tail_chunked_inner_loop_probe(*, BLOCK: pto.constexpr = 128): "tail.vector_pattern": _fixture( f""" @pto.jit(target="a5") - def tail_vector_pattern_probe(*, BLOCK: pto.constexpr = 128): + def tail_vector_pattern_probe(*, BLOCK: pto.const_expr = 128): rows = pto.const(1, dtype=pto.i32) cols = pto.const(BLOCK, dtype=pto.i32) tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -538,7 +538,7 @@ def tail_vector_pattern_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") - def tail_simd_helper_probe(*, BLOCK: pto.constexpr = 128): + def tail_simd_helper_probe(*, BLOCK: pto.const_expr = 128): a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -570,7 +570,7 @@ def kernel_entry_explicit_signature_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 16, + BLOCK: pto.const_expr = 16, ): view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) part = pto.partition_view(view, offsets=[0, 0], sizes=[1, BLOCK]) @@ -600,8 +600,8 @@ def kernel_entry_explicit_body_probe( V_ptr: pto.ptr(pto.f16, "gm"), O_ptr: pto.ptr(pto.f32, "gm"), *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, + ROWS: pto.const_expr = 8, + COLS: pto.const_expr = 16, ): k_view = pto.make_tensor_view(K_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) v_view = pto.make_tensor_view(V_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) @@ -624,7 +624,7 @@ def kernel_entry_inline_explicit_scope_probe( A: pto.tensor_spec(rank=2, dtype=pto.f32), O: pto.tensor_spec(rank=2, dtype=pto.f32), *, - BLOCK: pto.constexpr = 16, + BLOCK: pto.const_expr = 16, ): a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) @@ -647,9 +647,9 @@ def kernel_entry_inline_explicit_scope_probe( @pto.jit(target="a5", mode="explicit") def kernel_entry_cube_signature_probe( *, - BLOCK_M: pto.constexpr = 16, - BLOCK_K: pto.constexpr = 16, - BLOCK_N: pto.constexpr = 16, + BLOCK_M: pto.const_expr = 16, + BLOCK_K: pto.const_expr = 16, + BLOCK_N: pto.const_expr = 16, ): input_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f16, valid_shape=[BLOCK_M, BLOCK_K]) output_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32, valid_shape=[BLOCK_M, BLOCK_N]) @@ -665,7 +665,7 @@ def kernel_entry_cube_signature_probe( @pto.jit(target="a5") - def kernel_entry_simd_signature_probe(*, BLOCK: pto.constexpr = 128): + def kernel_entry_simd_signature_probe(*, BLOCK: pto.const_expr = 128): input_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) output_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) my_simd_kernel(input_tile, output_tile, pto.const(1, dtype=pto.i32), pto.const(BLOCK, dtype=pto.i32)) @@ -677,7 +677,7 @@ def kernel_entry_simd_signature_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") - def kernel_entry_simd_body_probe(*, BLOCK: pto.constexpr = 128): + def kernel_entry_simd_body_probe(*, BLOCK: pto.const_expr = 128): a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -696,7 +696,7 @@ def kernel_entry_simd_body_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") - def kernel_entry_simt_signature_probe(*, BLOCK: pto.constexpr = 8): + def kernel_entry_simt_signature_probe(*, BLOCK: pto.const_expr = 8): tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32, valid_shape=[1, BLOCK]) my_simt_kernel(tile, tile.as_ptr(), pto.const(0, dtype=pto.i32)) """ @@ -715,7 +715,7 @@ def kernel_entry_inline_simd_scope( @pto.jit(target="a5") - def kernel_entry_inline_simd_scope_probe(*, BLOCK: pto.constexpr = 128): + def kernel_entry_inline_simd_scope_probe(*, BLOCK: pto.const_expr = 128): a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -739,7 +739,7 @@ def kernel_entry_inline_simt_scope( @pto.jit(target="a5") - def kernel_entry_inline_simt_scope_probe(*, BLOCK: pto.constexpr = 8): + def kernel_entry_inline_simt_scope_probe(*, BLOCK: pto.const_expr = 8): one = 1 o_prev_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) pv_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -754,9 +754,9 @@ def kernel_entry_inline_simt_scope_probe(*, BLOCK: pto.constexpr = 8): @pto.jit(target="a5", mode="explicit") def kernel_entry_inline_cube_scope_probe( *, - BLOCK_M: pto.constexpr = 16, - BLOCK_K: pto.constexpr = 16, - BLOCK_N: pto.constexpr = 16, + BLOCK_M: pto.const_expr = 16, + BLOCK_K: pto.const_expr = 16, + BLOCK_N: pto.const_expr = 16, ): q_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f16, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_M, BLOCK_K]) k_tile = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f16, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_K, BLOCK_N]) @@ -789,7 +789,7 @@ def scalar_ops_helper_queries_probe(): "scalar_ops.chunk_loop": _fixture( f""" @pto.jit(target="a5") - def scalar_ops_chunk_loop_probe(*, BLOCK: pto.constexpr = 128): + def scalar_ops_chunk_loop_probe(*, BLOCK: pto.const_expr = 128): cols = pto.const(BLOCK, dtype=pto.i32) {SNIPPET_PLACEHOLDER} """ @@ -800,7 +800,7 @@ def scalar_ops_chunk_loop_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") - def scalar_ops_simt_scale_probe(*, BLOCK: pto.constexpr = 8): + def scalar_ops_simt_scale_probe(*, BLOCK: pto.const_expr = 8): src_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) dst_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) elementwise_scale( @@ -818,7 +818,7 @@ def scalar_ops_simt_scale_probe(*, BLOCK: pto.constexpr = 8): @pto.jit(target="a5") - def scalar_ops_simt_row_coeffs_probe(*, BLOCK: pto.constexpr = 8): + def scalar_ops_simt_row_coeffs_probe(*, BLOCK: pto.const_expr = 8): one = 1 o_prev_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) pv_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -865,7 +865,7 @@ def scalar_ops_pointer_sources_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 8, + BLOCK: pto.const_expr = 8, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) partition = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) @@ -891,7 +891,7 @@ def data_movement_tload_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): offset = 0 a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -919,8 +919,8 @@ def data_movement_explicit_dma_probe( V_ptr: pto.ptr(pto.f16, "gm"), O_ptr: pto.ptr(pto.f32, "gm"), *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, + ROWS: pto.const_expr = 8, + COLS: pto.const_expr = 16, ): k_view = pto.make_tensor_view(K_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) v_view = pto.make_tensor_view(V_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) @@ -955,8 +955,8 @@ def sync_ops_flag_pattern_explicit_probe( V_ptr: pto.ptr(pto.f16, "gm"), O_ptr: pto.ptr(pto.f32, "gm"), *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, + ROWS: pto.const_expr = 8, + COLS: pto.const_expr = 16, ): k_view = pto.make_tensor_view(K_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) v_view = pto.make_tensor_view(V_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) @@ -1013,8 +1013,8 @@ def sync_ops_phase_barrier_explicit_probe( K_ptr: pto.ptr(pto.f16, "gm"), V_ptr: pto.ptr(pto.f16, "gm"), *, - ROWS: pto.constexpr = 8, - COLS: pto.constexpr = 16, + ROWS: pto.const_expr = 8, + COLS: pto.const_expr = 16, ): k_view = pto.make_tensor_view(K_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) v_view = pto.make_tensor_view(V_ptr, shape=[ROWS, COLS], strides=[COLS, 1]) @@ -1065,7 +1065,7 @@ def data_movement_grouped_dma_ptrs_probe(): @pto.jit(target="a5") def data_movement_tile_slice_2d_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) col = 0 @@ -1078,7 +1078,7 @@ def data_movement_tile_slice_2d_probe( @pto.jit(target="a5") def data_movement_tile_slice_1d_probe( *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) start = pto.const(0, dtype=pto.i32) @@ -1102,9 +1102,9 @@ def qk_matmul( @pto.jit(target="a5", mode="explicit") def data_movement_cube_helper_probe( *, - BLOCK_M: pto.constexpr = 16, - BLOCK_K: pto.constexpr = 16, - BLOCK_N: pto.constexpr = 16, + BLOCK_M: pto.const_expr = 16, + BLOCK_K: pto.const_expr = 16, + BLOCK_N: pto.const_expr = 16, ): q_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f16, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_M, BLOCK_K]) k_tile = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f16, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_K, BLOCK_N]) @@ -1134,7 +1134,7 @@ def compute_ops_vector_helper(inp_tile: pto.Tile, out_tile: pto.Tile, row: pto.i @pto.jit(target="a5") - def compute_ops_vector_probe(*, BLOCK: pto.constexpr = 128): + def compute_ops_vector_probe(*, BLOCK: pto.const_expr = 128): inp_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) out_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) for row in range(0, 1, 1): @@ -1146,11 +1146,11 @@ def compute_ops_vector_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") def compute_ops_tile_window_matmul_probe( *, - BLOCK_M: pto.constexpr = 16, - BLOCK_K: pto.constexpr = 16, - BLOCK_N: pto.constexpr = 16, - CARRIER_M: pto.constexpr = 64, - CARRIER_N: pto.constexpr = 64, + BLOCK_M: pto.const_expr = 16, + BLOCK_K: pto.const_expr = 16, + BLOCK_N: pto.const_expr = 16, + CARRIER_M: pto.const_expr = 64, + CARRIER_N: pto.const_expr = 64, ): src_mat = pto.alloc_tile(shape=[CARRIER_M, CARRIER_N], dtype=pto.f32, memory_space=pto.MemorySpace.MAT) dst_mat = pto.alloc_tile( @@ -1250,10 +1250,10 @@ def flash_attention_l1_tensor_views_probe( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): Br = BLOCK_Q Bc = BLOCK_KV @@ -1278,10 +1278,10 @@ def flash_attention_l1_partitions_probe( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): q_view = pto.make_tensor_view(Q_ptr, shape=[batch, seq_q, heads, dim], strides=[seq_q * heads * dim, heads * dim, dim, 1]) k_view = pto.make_tensor_view(K_ptr, shape=[batch, seq_k, heads, dim], strides=[seq_k * heads * dim, heads * dim, dim, 1]) @@ -1307,9 +1307,9 @@ def flash_attention_l1_tiles_probe( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - HEAD_DIM: pto.constexpr = 128, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + HEAD_DIM: pto.const_expr = 128, ): Br = BLOCK_Q Bc = BLOCK_KV @@ -1370,11 +1370,11 @@ def flash_attention_l1_loop_body_probe( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - HEAD_DIM: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + HEAD_DIM: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): q_view = pto.make_tensor_view(Q_ptr, shape=[batch, seq_q, heads, dim], strides=[seq_q * heads * dim, heads * dim, dim, 1]) k_view = pto.make_tensor_view(K_ptr, shape=[batch, seq_k, heads, dim], strides=[seq_k * heads * dim, heads * dim, dim, 1]) @@ -1551,8 +1551,8 @@ def flash_attention_explicit_phase_probe( V_ptr: pto.ptr(pto.f32, "gm"), seq_k: pto.i32, *, - BLOCK_Q: pto.constexpr = 16, - BLOCK_KV: pto.constexpr = 16, + BLOCK_Q: pto.const_expr = 16, + BLOCK_KV: pto.const_expr = 16, ): Br = BLOCK_Q Bc = BLOCK_KV @@ -1602,7 +1602,7 @@ def flash_attention_explicit_phase_probe( @pto.jit(target="a5", mode="explicit") - def flash_attention_qk_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_KV: pto.constexpr = 16): + def flash_attention_qk_cube_helper_probe(*, BLOCK_Q: pto.const_expr = 16, BLOCK_KV: pto.const_expr = 16): Br = BLOCK_Q Bc = BLOCK_KV D = 16 @@ -1621,7 +1621,7 @@ def flash_attention_qk_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_K @pto.jit(target="a5", mode="explicit") - def flash_attention_pv_cube_helper_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_KV: pto.constexpr = 16): + def flash_attention_pv_cube_helper_probe(*, BLOCK_Q: pto.const_expr = 16, BLOCK_KV: pto.const_expr = 16): Br = BLOCK_Q Bc = BLOCK_KV D = 16 @@ -1645,7 +1645,7 @@ def flash_attention_inline_simt_scope( @pto.jit(target="a5") - def flash_attention_inline_simt_scope_probe(*, BLOCK_Q: pto.constexpr = 16, BLOCK_KV: pto.constexpr = 16): + def flash_attention_inline_simt_scope_probe(*, BLOCK_Q: pto.const_expr = 16, BLOCK_KV: pto.const_expr = 16): Br = BLOCK_Q Bc = BLOCK_KV D = 16 @@ -1675,7 +1675,7 @@ def flash_attention_online_softmax_loop_helper( @pto.jit(target="a5") - def flash_attention_online_softmax_loop_probe(*, BLOCK: pto.constexpr = 16): + def flash_attention_online_softmax_loop_probe(*, BLOCK: pto.const_expr = 16): one = 1 s_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) p_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -1719,7 +1719,7 @@ def flash_attention_online_softmax_compute_helper( @pto.jit(target="a5") - def flash_attention_online_softmax_compute_probe(*, BLOCK: pto.constexpr = 16): + def flash_attention_online_softmax_compute_probe(*, BLOCK: pto.const_expr = 16): one = 1 s_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) p_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -1765,7 +1765,7 @@ def flash_attention_online_softmax_store_helper( @pto.jit(target="a5") - def flash_attention_online_softmax_store_probe(*, BLOCK: pto.constexpr = 16): + def flash_attention_online_softmax_store_probe(*, BLOCK: pto.const_expr = 16): one = 1 s_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) p_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -1804,7 +1804,7 @@ def flash_attention_simt_materialize_probe(): @pto.jit(target="a5") - def flash_attention_simt_blend_probe(*, BLOCK: pto.constexpr = 8): + def flash_attention_simt_blend_probe(*, BLOCK: pto.const_expr = 8): one = 1 o_prev_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) pv_tile = pto.alloc_tile(shape=[8, BLOCK], dtype=pto.f32, valid_shape=[2, BLOCK]) @@ -1838,7 +1838,7 @@ def gemm_tile( @pto.jit(target="a5", mode="explicit") - def gemm_tile_probe(*, BLOCK_M: pto.constexpr = 64, BLOCK_K: pto.constexpr = 64, BLOCK_N: pto.constexpr = 64): + def gemm_tile_probe(*, BLOCK_M: pto.const_expr = 64, BLOCK_K: pto.const_expr = 64, BLOCK_N: pto.const_expr = 64): a_mat = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_M, BLOCK_K]) b_mat = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, memory_space=pto.MemorySpace.MAT, valid_shape=[BLOCK_K, BLOCK_N]) o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32, valid_shape=[BLOCK_M, BLOCK_N]) diff --git a/ptodsl/tests/test_ast_rewrite_example_ir.py b/ptodsl/tests/test_ast_rewrite_example_ir.py index 14c90a5f1..c15ec2e3f 100644 --- a/ptodsl/tests/test_ast_rewrite_example_ir.py +++ b/ptodsl/tests/test_ast_rewrite_example_ir.py @@ -332,11 +332,11 @@ def kernel( heads: pto.i32, dim: pto.i32, *, - BLOCK_Q: pto.constexpr = 128, - BLOCK_KV: pto.constexpr = 128, - HEAD_DIM: pto.constexpr = 128, - CAUSAL: pto.constexpr = False, - NUM_STAGES: pto.constexpr = 2, + BLOCK_Q: pto.const_expr = 128, + BLOCK_KV: pto.const_expr = 128, + HEAD_DIM: pto.const_expr = 128, + CAUSAL: pto.const_expr = False, + NUM_STAGES: pto.const_expr = 2, ): q_strides = [seq_q * heads * dim, heads * dim, dim, 1] kv_strides = [seq_k * heads * dim, heads * dim, dim, 1] diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 401c42268..4ba5a7ff3 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -102,7 +102,7 @@ def host_vec_copy( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -121,7 +121,7 @@ def host_vec_copy_explicit( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -140,7 +140,7 @@ def pointer_runtime_shape_specialization_probe( cols: pto.i32, row_stride: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): x_view = pto.make_tensor_view(x_ptr, shape=[rows, cols], strides=[row_stride, 1]) x_part = pto.partition_view(x_view, offsets=[0, 0], sizes=[rows, cols]) @@ -155,7 +155,7 @@ def tile_transfer_surface_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) @@ -208,7 +208,7 @@ def runtime_metadata_kernel( row_stride: pto.i32, col_stride: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[row_stride, col_stride]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[row_stride, col_stride]) @@ -368,14 +368,14 @@ def top_level_simd_probe(): @pto.jit(target="a5") -def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def shared_subkernel_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): top_level_cube_probe() top_level_simd_probe() nested_simd_probe() @pto.jit(target="a5", mode="explicit") -def inline_subkernel_scope_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def inline_subkernel_scope_probe(*, TRACE_TOKEN: pto.const_expr = 0): session = current_session() meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 1]) @@ -408,7 +408,7 @@ def ast_subkernel_runtime_for_helper(rows: pto.i32): @pto.jit(target="a5") -def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.constexpr = 0): +def simt_helper_lowering_probe(*, TRACE_TOKEN: pto.const_expr = 0): simt_tid_probe() simt_tid_probe() @@ -419,7 +419,7 @@ def ast_subkernel_runtime_for_probe(rows: pto.i32): @pto.jit(target="a5") -def carry_loop_lowering_probe(*, BLOCK: pto.constexpr = 128): +def carry_loop_lowering_probe(*, BLOCK: pto.const_expr = 128): m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) @@ -736,7 +736,7 @@ def set_limit(value: int): def make_ast_signature_closure_default_kernel(limit: int): @pto.jit(target="a5") - def ast_signature_closure_default_kernel(*, BLOCK: pto.constexpr = limit): + def ast_signature_closure_default_kernel(*, BLOCK: pto.const_expr = limit): for _ in pto.static_range(BLOCK): pto.pipe_barrier(pto.Pipe.ALL) @@ -757,7 +757,7 @@ def helper(): limit = 4 @pto.jit(target="a5") - def ast_rebound_subkernel_probe(*, TRACE_TOKEN: pto.constexpr = 0): + def ast_rebound_subkernel_probe(*, TRACE_TOKEN: pto.const_expr = 0): helper() return ast_rebound_subkernel_probe @@ -797,7 +797,7 @@ def sourceless_subkernel_helper(): helper = namespace["sourceless_subkernel_helper"] @pto.jit(target="a5") - def sourceless_subkernel_entry_probe(*, TRACE_TOKEN: pto.constexpr = 0): + def sourceless_subkernel_entry_probe(*, TRACE_TOKEN: pto.const_expr = 0): helper() return sourceless_subkernel_entry_probe @@ -807,7 +807,7 @@ def sourceless_subkernel_entry_probe(*, TRACE_TOKEN: pto.constexpr = 0): @pto.jit(target="a5") -def ast_static_control_flow_probe(*, ENABLE: pto.constexpr = True): +def ast_static_control_flow_probe(*, ENABLE: pto.const_expr = True): if pto.const_expr(ENABLE): for _ in pto.static_range(2): pto.pipe_barrier(pto.Pipe.ALL) @@ -816,8 +816,8 @@ def ast_static_control_flow_probe(*, ENABLE: pto.constexpr = True): @pto.jit(target="a5") def ast_python_bool_guard_probe( *, - BLOCK: pto.constexpr = 128, - ENABLE: pto.constexpr = True, + BLOCK: pto.const_expr = 128, + ENABLE: pto.const_expr = True, ): if BLOCK == 128: pto.pipe_barrier(pto.Pipe.ALL) @@ -852,7 +852,7 @@ def runtime_scalar_operator_probe( cols: pto.i32, row_stride: pto.i32, *, - BLOCK: pto.constexpr = 8, + BLOCK: pto.const_expr = 8, ): block_idx = pto.get_block_idx() o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[row_stride, 1]) @@ -926,7 +926,7 @@ def tile_slice_vector_probe(inp_tile: pto.Tile, out_tile: pto.Tile, row: pto.ind @pto.jit(target="a5") -def tile_slice_surface_probe(*, BLOCK: pto.constexpr = 128): +def tile_slice_surface_probe(*, BLOCK: pto.const_expr = 128): inp_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) out_tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) for row in range(0, 1, 1): @@ -934,7 +934,7 @@ def tile_slice_surface_probe(*, BLOCK: pto.constexpr = 128): @pto.jit(target="a5") -def tile_slice_1d_surface_probe(*, BLOCK: pto.constexpr = 128): +def tile_slice_1d_surface_probe(*, BLOCK: pto.const_expr = 128): inp_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) out_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32) start = pto.const(0, dtype=pto.i32) @@ -950,7 +950,7 @@ def tile_valid_shape_update_probe( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile( shape=[1, BLOCK], @@ -964,7 +964,7 @@ def tile_valid_shape_update_probe( def tile_valid_shape_update_1d_probe( length: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): tile = pto.alloc_tile( shape=[BLOCK], @@ -995,7 +995,7 @@ def carry_static_pyint_init_probe(): @pto.jit(target="a5") -def integer_loop_bound_probe(*, BLOCK: pto.constexpr = 8): +def integer_loop_bound_probe(*, BLOCK: pto.const_expr = 8): row_start = pto.const(0, dtype=pto.i32) row_stop = pto.const(BLOCK, dtype=pto.i32) valid_dim = pto.const(BLOCK // 2, dtype=pto.i32) @@ -1386,7 +1386,7 @@ def low_precision_storage_probe(): @pto.jit(target="a5") -def pointer_vlds_inference_probe(*, BLOCK: pto.constexpr = 128): +def pointer_vlds_inference_probe(*, BLOCK: pto.const_expr = 128): tile = pto.alloc_tile(shape=[2, BLOCK], dtype=pto.f32) vec = pto.vlds(tile.as_ptr(), pto.const(0)) vec_brc = pto.vlds(tile.as_ptr(), pto.const(0), dist="BRC_B32") @@ -1905,6 +1905,7 @@ def main() -> None: expect(not hasattr(pto, "as_ptr"), "pto.as_ptr should not remain on the public pto namespace") expect(not hasattr(pto, "vbrc_load"), "pto.vbrc_load should not remain on the public pto namespace") expect(not hasattr(pto, "vsts_1pt"), "pto.vsts_1pt should not remain on the public pto namespace") + expect(not hasattr(pto, "constexpr"), "pto.constexpr should not remain on the public pto namespace") expect(not hasattr(scalar, "sts"), "scalar.sts should not remain in the public scalar namespace") expect(not hasattr(scalar, "cmpi"), "scalar.cmpi should not remain in the public scalar namespace") expect(not hasattr(scalar, "cmpi_sgt"), "scalar.cmpi_sgt should not remain in the public scalar namespace") @@ -1933,6 +1934,12 @@ def main() -> None: "pto.vsts_1pt is not a supported PTODSL public interface" in str(removed_vsts_1pt), "removed pto.vsts_1pt should diagnose the public vsts(dist=...) replacement", ) + removed_constexpr = expect_raises(AttributeError, lambda: getattr(pto, "constexpr")) + expect( + "pto.constexpr is not a supported PTODSL public interface" in str(removed_constexpr) + and "Use pto.const_expr" in str(removed_constexpr), + "removed pto.constexpr should diagnose pto.const_expr as the replacement", + ) for name in ("max", "min", "exp", "log", "sqrt", "abs"): expect(hasattr(scalar, name), f"scalar.{name} should be exported from the public scalar namespace") diff --git a/ptodsl/tests/test_jit_diagnostics.py b/ptodsl/tests/test_jit_diagnostics.py index 56bbc6102..84c32b56f 100644 --- a/ptodsl/tests/test_jit_diagnostics.py +++ b/ptodsl/tests/test_jit_diagnostics.py @@ -66,7 +66,7 @@ def float_bitwise_probe(): @pto.jit(target="a5") -def carry_update_mismatch_probe(*, BLOCK: pto.constexpr = 8): +def carry_update_mismatch_probe(*, BLOCK: pto.const_expr = 8): acc = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) loop = pto.for_(0, 1, step=1).carry(acc=acc) with loop: @@ -74,7 +74,7 @@ def carry_update_mismatch_probe(*, BLOCK: pto.constexpr = 8): @pto.jit(target="a5") -def carry_final_mismatch_probe(*, BLOCK: pto.constexpr = 8): +def carry_final_mismatch_probe(*, BLOCK: pto.const_expr = 8): acc = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) loop = pto.for_(0, 1, step=1).carry(acc=acc) with loop: @@ -106,7 +106,7 @@ def data_ptr(self): def define_missing_constexpr_default_probe(): @pto.jit(target="a5") - def bad_probe(*, BLOCK: pto.constexpr): + def bad_probe(*, BLOCK: pto.const_expr): pto.pipe_barrier(pto.Pipe.ALL) return bad_probe @@ -359,7 +359,7 @@ def main() -> None: TypeError, "native Python if/while condition", "pto.if_(...)", - "pto.constexpr", + "pto.const_expr", ) expect_raises( native_python_range_runtime_metadata_probe.compile, @@ -420,7 +420,7 @@ def main() -> None: define_illegal_keyword_only_probe, TypeError, "@pto.jit keyword-only parameter 'BLOCK' uses unsupported compile-time annotation", - "pto.constexpr", + "pto.const_expr", "move runtime data to positional pointer/scalar parameters", ) expect_raises( diff --git a/ptodsl/tests/test_ptoas_frontend_verify.py b/ptodsl/tests/test_ptoas_frontend_verify.py index a6cfdb2b0..565109646 100644 --- a/ptodsl/tests/test_ptoas_frontend_verify.py +++ b/ptodsl/tests/test_ptoas_frontend_verify.py @@ -71,7 +71,7 @@ def host_vec_copy( rows: pto.i32, cols: pto.i32, *, - BLOCK: pto.constexpr = 128, + BLOCK: pto.const_expr = 128, ): a_view = pto.make_tensor_view(A_ptr, shape=[rows, cols], strides=[cols, 1]) o_view = pto.make_tensor_view(O_ptr, shape=[rows, cols], strides=[cols, 1]) diff --git a/ptodsl/tests/test_subkernel_diagnostics.py b/ptodsl/tests/test_subkernel_diagnostics.py index 8bd6149d0..de8375bc7 100644 --- a/ptodsl/tests/test_subkernel_diagnostics.py +++ b/ptodsl/tests/test_subkernel_diagnostics.py @@ -76,7 +76,7 @@ def illegal_simt_placement_probe(): @pto.jit(target="a5") -def nested_simt_from_simd_entry(*, TRACE_TOKEN: pto.constexpr = 0): +def nested_simt_from_simd_entry(*, TRACE_TOKEN: pto.const_expr = 0): illegal_simt_placement_probe() @@ -87,7 +87,7 @@ def illegal_inline_simt_placement_probe(): @pto.jit(target="a5") -def nested_inline_simt_from_simd_entry(*, TRACE_TOKEN: pto.constexpr = 0): +def nested_inline_simt_from_simd_entry(*, TRACE_TOKEN: pto.const_expr = 0): illegal_inline_simt_placement_probe() @@ -97,7 +97,7 @@ def simd_value_escape_probe(): @pto.jit(target="a5") -def simd_value_escape_entry(*, TRACE_TOKEN: pto.constexpr = 0): +def simd_value_escape_entry(*, TRACE_TOKEN: pto.const_expr = 0): simd_value_escape_probe()