refactor(examples): Migrate kernels and simple models to @pl.jit #1323
lyfne123 merged 2 commits into hw-native-sys:main
Replaces @pl.program + harness-based test pattern with @pl.jit decorators
across the examples/ tree and the example-tests under tests/st/examples/.
Each migrated example exits 0 from `python examples/<file>.py` and exercises
the kernel via torch.allclose against a torch reference (where one exists),
matching the runnable-example pattern in examples/models/qwen3_jit/.
Migrated examples (15):
- examples/hello_world.py
- examples/kernels/01-09 (elementwise, fused_ops, matmul, concat, activation,
softmax, normalization, assemble, dyn_valid_shape)
- examples/models/01_ffn.py, 02_vector_dag.py, 03_flash_attention.py
- examples/utils/cross_function_calls.py, error_handling.py
Migrated tests (17):
- tests/st/examples/{00_hello_world,01_beginner/basic,02_intermediate}/test_*.py
- tests/st/runtime/test_{elementwise,compiled_program,concat,matmul,assemble,
device_tensor,dag}.py
- tests/st/codegen/test_{dyn_valid_shape_loop,dynamic_valid_shape_if_else,
add_mul_orch_codegen}.py
Out of scope (deferred to follow-up issues, see KNOWN_ISSUES.md):
- examples/models/04-09 (paged_attention family + llama_mini): blocked on
JIT specializer not tracking pl.slice() results / @pl.jit.incore returns
- examples/utils/parse_from_text.py: orthogonal to JIT (IR text parsing demo)
- tests/st/examples/03_llm_models/test_llama_7b_mini_1h.py: depends on
08_llama_mini.py
- tests/st/runtime/test_dyn_orch_shape.py and the 5 paged-attention codegen
tests: still depend on un-migrated examples
The 03_flash_attention.py __main__ smoke is print-only because the original
@pl.function body has an IfStmt yield/return_vars structural mismatch that
the @pl.program path masked (the original example only ever called print()).
Tracked in KNOWN_ISSUES.md.
Refs: hw-native-sys#1320
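The runnable-example contract described above (exit 0 from `python examples/<file>.py`, kernel output checked against a torch reference) can be sketched with plain-Python stand-ins. This is only an illustration of the pattern: `kernel`, `reference`, and `run_example` are hypothetical, and `math.isclose` stands in for `torch.allclose`.

```python
# Sketch of the runnable-example pattern: run the kernel, compare against a
# reference, print a verdict, and return a process exit code accordingly.
import math


def kernel(a, b):
    # Hypothetical stand-in for the compiled @pl.jit kernel under test.
    return [x + y for x, y in zip(a, b)]


def reference(a, b):
    # Hypothetical stand-in for the torch reference implementation.
    return [x + y for x, y in zip(a, b)]


def run_example() -> int:
    a, b = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
    got, want = kernel(a, b), reference(a, b)
    # math.isclose here plays the role torch.allclose plays in the examples.
    ok = all(math.isclose(g, w, rel_tol=1e-5) for g, w in zip(got, want))
    print("PASS" if ok else "FAIL")
    return 0 if ok else 1


# In a real example's __main__ block: sys.exit(run_example())
```

The key property is that the verdict is the exit code, so CI can treat every example as a smoke test.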
📝 Walkthrough
Migrates example programs to module-level …
Changes: Examples & Kernel API Refactoring; Test Infrastructure Migration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request migrates the PyPTO examples and system tests to the new @pl.jit API, replacing the legacy @pl.program structure with a more flexible JIT-based execution model. This involves updating numerous kernels and orchestration logic to use the pl.incore() context and inline helpers. Feedback on the flash_attention implementation highlights critical semantic errors where orchestration-level operations like pl.create_tensor are incorrectly used inside device-side compute blocks, and global tensors are passed directly to matrix multiplication without being loaded into tiles.
Pull request overview
This PR modernizes the examples/ and their companion system tests by migrating from legacy @pl.program + harness-based execution to the @pl.jit entrypoint style, so examples/tests can directly specialize/compile/execute kernels against torch references (or run compile_for_test smoke where execution isn’t meaningful yet).
Changes:
- Migrated example kernels/models to `@pl.jit`/`@pl.jit.incore`/`@pl.jit.inline` with runnable `__main__` blocks (torch reference checks or `compile_for_test` smoke).
- Rewrote example and runtime tests to call JIT functions directly with the `test_config` fixture and `torch.allclose` comparisons.
- Updated codegen smoke tests to validate JIT compilation for dynamic valid-shape patterns via `compile_for_test`.
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/hello_world.py | Replaces @pl.program with a @pl.jit tile_add and a runnable __main__ torch reference check. |
| examples/kernels/01_elementwise.py | Splits elementwise into multiple @pl.jit kernels (64/128) and adds runnable __main__ validation. |
| examples/kernels/02_fused_ops.py | Migrates fused-op examples to @pl.jit + @pl.jit.incore helpers and runnable checks. |
| examples/kernels/03_matmul.py | Migrates matmul + matmul_acc examples to @pl.jit and adds runnable validation. |
| examples/kernels/04_concat.py | Migrates concat example to @pl.jit and adds runnable validation. |
| examples/kernels/05_activation.py | Migrates activation kernels to @pl.jit and adds runnable validation. |
| examples/kernels/06_softmax.py | Migrates softmax kernel to @pl.jit and adds runnable validation. |
| examples/kernels/07_normalization.py | Migrates RMSNorm/LayerNorm to @pl.jit and adds runnable validation. |
| examples/kernels/08_assemble.py | Migrates assemble patterns to @pl.jit and uses compile_for_test smoke in __main__. |
| examples/kernels/09_dyn_valid_shape.py | Refactors dynamic valid-shape example to caller-provided vlen with @pl.jit and compile_for_test smoke. |
| examples/models/01_ffn.py | Migrates FFN examples to @pl.jit orchestration with @pl.jit.incore kernels and runnable torch validation. |
| examples/models/02_vector_dag.py | Migrates vector DAG example to @pl.jit orchestration + @pl.jit.incore kernels; keeps golden() reference. |
| examples/models/03_flash_attention.py | Wraps flash attention in @pl.jit but keeps __main__ print-only due to known IR verification issue. |
| examples/utils/cross_function_calls.py | Updates cross-function composition example to @pl.jit + @pl.jit.inline deps and compile_for_test output. |
| examples/utils/error_handling.py | Replaces prior error-renderer example with a JIT “invalid kernel rejected” demonstration. |
| tests/st/examples/00_hello_world/test_hello_world.py | Rewrites hello-world test to call @pl.jit kernel with torch reference compare. |
| tests/st/examples/01_beginner/basic/test_basic_ops.py | Rewrites fused-ops tests to call JIT functions directly with torch reference checks. |
| tests/st/examples/02_intermediate/test_activation.py | Rewrites activation tests to call JIT kernels directly with torch reference checks. |
| tests/st/examples/02_intermediate/test_ffn_activations.py | Rewrites FFN activation tests to call JIT entries directly with torch reference checks. |
| tests/st/examples/02_intermediate/test_layer_norm.py | Rewrites LayerNorm test to call layer_norm JIT kernel and compare to torch reference. |
| tests/st/examples/02_intermediate/test_rms_norm.py | Rewrites RMSNorm test to call rms_norm JIT kernel and compare to torch reference. |
| tests/st/examples/02_intermediate/test_softmax.py | Rewrites softmax test to call tile_softmax JIT kernel and compare to torch reference. |
| tests/st/runtime/test_assemble.py | Rewrites assemble runtime tests to call JIT kernels directly (with existing skips/markers preserved). |
| tests/st/runtime/test_compiled_program.py | Refocuses compiled-program tests around @pl.jit cache population and direct CompiledProgram invocation. |
| tests/st/runtime/test_concat.py | Rewrites concat runtime test to call JIT kernel directly (test remains skipped). |
| tests/st/runtime/test_dag.py | Rewrites DAG runtime test to call vector_dag JIT entry and validate via shared golden(). |
| tests/st/runtime/test_device_tensor.py | Refactors DeviceTensor test to specialize via JIT then reuse cached CompiledProgram for host+device mixing. |
| tests/st/runtime/test_elementwise.py | Rewrites elementwise runtime tests to call JIT kernels directly with torch reference checks. |
| tests/st/runtime/test_matmul.py | Switches matmul-acc test to JIT (matmul_acc_64) while leaving the rest of the file harness-based. |
| tests/st/codegen/test_add_mul_orch_codegen.py | Updates orchestration codegen test to use compile_for_test on example_orch. |
| tests/st/codegen/test_dyn_valid_shape_loop.py | Updates dynamic valid-shape loop smoke test to JIT compile_for_test with caller-provided vlen. |
| tests/st/codegen/test_dynamic_valid_shape_if_else.py | Updates dynamic valid-shape if/else smoke test to JIT compile_for_test with caller-provided vlen. |
Actionable comments posted: 1
🧹 Nitpick comments (5)
examples/models/03_flash_attention.py (1)

115-117: ⚡ Quick win: Prefix unused loop-carried variables with underscore.

The final unpacking extracts `mi_final`, `li_final`, and `oi_final` from the loop, but only `attn_final` is used in the return statement. Following Python convention, prefix the unused variables with `_` to signal they are intentionally discarded.

♻️ Proposed fix

```diff
-    for i, (mi_update, li_update, attn_update, oi_update) in pl.range(
+    for _i, (mi_update, li_update, attn_update, oi_update) in pl.range(
         16,
         init_values=(
             mi_update_initial,
             li_update_initial,
             attn_initial,
             oi_update_initial,
         ),
     ):
```

And:

```diff
-    mi_final, li_final, attn_final, oi_final = pl.yield_(
+    _mi_final, _li_final, attn_final, _oi_final = pl.yield_(
         miUpdate_126, liUpdate_127, attn_128, oiUpdate_129
     )
```

tests/st/codegen/test_dyn_valid_shape_loop.py (1)
34-37: 💤 Low value: Unused constant `N_ROW` for documentation only.

The constant `N_ROW = Q_TILE` is defined but never used in the code. While the comment explains it's a "documentation marker" from the old multi-block tests, consider removing it to reduce clutter, or add a `# noqa: F841` comment if you want to keep it for historical context.

examples/kernels/05_activation.py (1)
97-107: ⚡ Quick win: Seed set once affects only the first test.

`torch.manual_seed(0)` is set at line 97, but each subsequent test (GELU, SwiGLU, GeGLU) creates new random tensors without re-seeding. This means:

- The SiLU test uses seeded `torch.randn`
- The GELU test uses non-deterministic `randn` (different on each run)
- SwiGLU and GeGLU also use non-deterministic `randn`

For fully reproducible example runs, consider calling `torch.manual_seed(0)` before each test section, or use `torch.full` / deterministic values for all tests.

tests/st/runtime/test_device_tensor.py (1)
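The re-seeding fix suggested above can be demonstrated with the stdlib `random` module as a stand-in for `torch.manual_seed` / `torch.randn` (an assumption: the seeding semantics transfer directly; the section names are taken from the activation example).

```python
# Seeding once makes only the first section's draw pinned to the seed; the
# later sections consume wherever the stream happens to be. Re-seeding before
# each section pins every section's inputs.
import random


def randn(n: int) -> list[float]:
    # Stand-in for torch.randn(n).
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def run_sections(reseed_each: bool) -> list[list[float]]:
    random.seed(0)  # the single top-of-file seed in the example
    draws = []
    for _section in ("silu", "gelu", "swiglu", "geglu"):
        if reseed_each:
            random.seed(0)  # the suggested per-section re-seed
        draws.append(randn(4))
    return draws


once = run_sections(reseed_each=False)
each = run_sections(reseed_each=True)
assert once[0] == each[0]                # first section identical either way
assert len(set(map(tuple, each))) == 1   # re-seeded sections all reproduce the same draw
assert len(set(map(tuple, once))) > 1    # seed-once sections drift apart
```

The trade-off: per-section re-seeding makes every section reproducible in isolation, at the cost of all sections seeing identical inputs.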
47-55: 💤 Low value: Direct `_cache` access suggests internal API usage.

The helper function accesses `tile_add_128._cache` directly (lines 49, 54), which is typically considered an internal implementation detail (indicated by the leading underscore). While this is acceptable in test code to verify caching behavior, consider:

- Documenting that this is intentional test-level introspection
- Adding a comment explaining why direct cache access is necessary here (to get the `CompiledProgram` for DeviceTensor testing)

The assertion at line 54 is good but could be more defensive if the cache implementation changes.

tests/st/codegen/test_add_mul_orch_codegen.py (1)
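The hardened introspection helper suggested above can be sketched with a toy cache. Here `JitFn` is a hypothetical stand-in for a `@pl.jit`-decorated function; only the `_cache` attribute name comes from the review, everything else is illustrative.

```python
# A toy "jit function" whose calls populate a specialization cache, plus a
# test helper that documents and defends its reliance on the internal cache.


class JitFn:
    def __init__(self):
        self._cache = {}  # specialization key -> compiled program

    def __call__(self, key):
        # Calling specializes and caches (stand-in for real compilation).
        self._cache.setdefault(key, f"compiled<{key}>")
        return self._cache[key]


def specialize_and_get_compiled(fn: JitFn, key):
    """Test-level introspection of fn._cache is deliberate: it is how the
    test retrieves the cached compiled program for reuse on device."""
    fn(key)
    # Defensive check with a clear message, so a cache-layout change fails loudly.
    assert len(fn._cache) >= 1, "expected at least one cached compiled program"
    return next(iter(fn._cache.values()))


tile_add_128 = JitFn()  # hypothetical name, mirroring the test under review
compiled = specialize_and_get_compiled(tile_add_128, ("fp16", 128))
```

The comment-plus-assertion pair is the whole point: future readers see why an underscore attribute is touched, and the test fails with a message rather than a `StopIteration` if the cache is ever empty.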
38-51: 💤 Low value: Consider tightening the function-count assertion.

The docstring (L38-40) promises "3 InCore + 1 Orchestration" (4 functions), but the body only asserts `> 0`. Asserting `>= 4` (or `== 4`) would actually exercise the documented contract and catch silent regressions in pass output.

♻️ Proposed tightening

```diff
 assert program is not None, "compile_for_test returned None"
-assert len(program.functions) > 0, "compile_for_test produced no functions"
+# 3 InCore kernels (add, add_scalar, mul) + 1 Orchestration entry.
+assert len(program.functions) >= 4, (
+    f"expected at least 4 functions (3 InCore + 1 Orchestration), got {len(program.functions)}"
+)
```
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/utils/error_handling.py`:
- Around line 36-40: The current try block around test_ssa_violation(x, result,
config=RunConfig()) prints an "ERROR" message but still exits with code 0 if no
PartialCodegenError is raised; change the branch where the kernel unexpectedly
succeeds to terminate with a non-zero exit (e.g., call sys.exit(1) or raise
SystemExit(1)) so CI detects the failure, and ensure sys is imported at top of
examples/utils/error_handling.py; keep the except clause catching
PartialCodegenError as-is.
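The expected-failure pattern this comment asks for can be sketched in isolation: the demo must exit non-zero when the error it exists to demonstrate does not occur. `PartialCodegenError` is the real error name from the review; `run_invalid_kernel` and `demo` are hypothetical stand-ins for the example's kernel invocation.

```python
# An "expected error" demo: success means the exception was raised; the
# exception NOT being raised is the failure CI must see as a non-zero exit.
import sys


class PartialCodegenError(Exception):
    """Stand-in for the compiler error the example expects."""


def run_invalid_kernel():
    # Stand-in for invoking the deliberately invalid kernel.
    raise PartialCodegenError("SSA violation detected")


def demo() -> int:
    try:
        run_invalid_kernel()
    except PartialCodegenError as exc:
        print(f"OK: rejected as expected: {exc}")
        return 0
    # The invalid kernel unexpectedly compiled: make CI notice.
    print("ERROR: invalid kernel was accepted", file=sys.stderr)
    return 1


# In the example's __main__ block: sys.exit(demo())
```

Returning the code from a function and passing it to `sys.exit` at the top level keeps the logic testable while still giving CI the non-zero exit.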
---
Nitpick comments:
In `@examples/kernels/05_activation.py`:
- Around line 97-107: The random seed is only set once before the SiLU test so
subsequent tests (GELU, SwiGLU, GeGLU) use non-deterministic inputs; ensure
reproducibility by re-seeding torch.manual_seed(0) immediately before each test
block that calls torch.randn (or alternatively replace random inputs with
deterministic tensors via torch.full) — update the test sections that call
silu(...), gelu(...), swiglu(...) and geglu(...) to either call
torch.manual_seed(0) right before creating their input tensors or to use fixed
values for x and any other random inputs.
In `@examples/models/03_flash_attention.py`:
- Around line 115-117: The final unpacking from pl.yield_ currently assigns
mi_final, li_final, attn_final, oi_final but only attn_final is used; update the
unpacking to prefix unused loop-carried variables (mi_final, li_final, oi_final)
with an underscore (e.g., _mi_final, _li_final, _oi_final) so they are clearly
discarded while keeping attn_final unchanged in the call to pl.yield_.
In `@tests/st/codegen/test_add_mul_orch_codegen.py`:
- Around line 38-51: The test currently asserts that program.functions has any
functions but the docstring expects "3 InCore + 1 Orchestration" (4 total);
update the assertion after calling example_orch.compile_for_test(a, b, output)
to validate the documented contract by asserting program.functions length is at
least 4 (or == 4 if you want exactness) so the test will catch regressions —
locate the assertion that references program.functions and replace the loose
check (len(program.functions) > 0) with a stricter assertion (e.g.,
len(program.functions) >= 4 or len(program.functions) == 4).
In `@tests/st/codegen/test_dyn_valid_shape_loop.py`:
- Around line 34-37: Remove the dead documentation-only constant N_ROW (assigned
from Q_TILE) or explicitly suppress the unused-variable lint by adding a # noqa:
F841 comment; locate the definition of N_ROW in the test (the line "N_ROW =
Q_TILE") and either delete that line or append "# noqa: F841" to preserve the
historical marker while silencing the unused-variable warning.
In `@tests/st/runtime/test_device_tensor.py`:
- Around line 47-55: The helper _specialize_and_get_compiled directly inspects
tile_add_128._cache (an internal API); make this intentional by adding an inline
comment above _specialize_and_get_compiled stating that test-level introspection
of tile_add_128._cache is deliberate to retrieve the cached CompiledProgram for
DeviceTensor tests, and harden the assertion by checking that the cache has at
least one entry (e.g., assert len(tile_add_128._cache) >= 1 with a clear failure
message) before returning next(iter(tile_add_128._cache.values())); keep
references to tile_add_128, _cache and CompiledProgram so future readers know
why internals are accessed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: df192d80-7513-4343-a242-70fc4f5a8eee
📒 Files selected for processing (32)
- examples/hello_world.py
- examples/kernels/01_elementwise.py
- examples/kernels/02_fused_ops.py
- examples/kernels/03_matmul.py
- examples/kernels/04_concat.py
- examples/kernels/05_activation.py
- examples/kernels/06_softmax.py
- examples/kernels/07_normalization.py
- examples/kernels/08_assemble.py
- examples/kernels/09_dyn_valid_shape.py
- examples/models/01_ffn.py
- examples/models/02_vector_dag.py
- examples/models/03_flash_attention.py
- examples/utils/cross_function_calls.py
- examples/utils/error_handling.py
- tests/st/codegen/test_add_mul_orch_codegen.py
- tests/st/codegen/test_dyn_valid_shape_loop.py
- tests/st/codegen/test_dynamic_valid_shape_if_else.py
- tests/st/examples/00_hello_world/test_hello_world.py
- tests/st/examples/01_beginner/basic/test_basic_ops.py
- tests/st/examples/02_intermediate/test_activation.py
- tests/st/examples/02_intermediate/test_ffn_activations.py
- tests/st/examples/02_intermediate/test_layer_norm.py
- tests/st/examples/02_intermediate/test_rms_norm.py
- tests/st/examples/02_intermediate/test_softmax.py
- tests/st/runtime/test_assemble.py
- tests/st/runtime/test_compiled_program.py
- tests/st/runtime/test_concat.py
- tests/st/runtime/test_dag.py
- tests/st/runtime/test_device_tensor.py
- tests/st/runtime/test_elementwise.py
- tests/st/runtime/test_matmul.py
- test_add_mul_orch_codegen: assert exactly 1 Orchestration + 3 AIV functions in the post-pass IR (was: only checked >0 functions). Addresses copilot-pull-request-reviewer feedback on the weak assertion.
- examples/utils/error_handling: sys.exit(1) when the expected PartialCodegenError is not raised, so unexpected success doesn't silently exit 0 in CI. Addresses coderabbitai feedback.
Summary

Migrates 15 example files + 17 test files from the legacy `@pl.program` + `harness.core.harness` test pattern to the modern `@pl.jit` decorator. Each migrated example exits 0 from `python examples/<file>.py` and exercises the kernel via `torch.allclose` against a torch reference, matching the runnable-example pattern in `examples/models/qwen3_jit/`.

Migrated test files drop `harness.core.harness` imports and use the `test_config` fixture + direct `@pl.jit` calls (mirroring `tests/st/runtime/test_jit.py`).

In scope

Examples (15):
- examples/hello_world.py
- examples/kernels/01-09 (elementwise, fused_ops, matmul, concat, activation, softmax, normalization, assemble, dyn_valid_shape)
- examples/models/01_ffn.py, 02_vector_dag.py, 03_flash_attention.py
- examples/utils/cross_function_calls.py, error_handling.py

Tests (17):
- tests/st/examples/{00_hello_world,01_beginner/basic,02_intermediate}/test_*.py (7 files)
- tests/st/runtime/test_{elementwise,compiled_program,concat,matmul,assemble,device_tensor,dag}.py (7 files)
- tests/st/codegen/test_{dyn_valid_shape_loop,dynamic_valid_shape_if_else,add_mul_orch_codegen}.py (3 files)

Deferred to follow-ups

- examples/models/04-09 (paged_attention family + llama_mini): blocked on the JIT specializer not tracking `pl.slice()` results or `@pl.jit.incore` return-value tensor metadata. ~30-line specializer extension needed.
- examples/utils/parse_from_text.py: orthogonal to JIT (IR text parsing demo).
- tests/st/examples/03_llm_models/test_llama_7b_mini_1h.py, tests/st/runtime/test_dyn_orch_shape.py, and the 5 paged-attention codegen tests: still depend on un-migrated examples.

Notes

- examples/models/03_flash_attention.py `__main__` is print-only because the original `@pl.function` body has an IfStmt yield/return_vars structural mismatch that the `@pl.program` path masked (the original example only ever called `print()`). Tracked as a known issue.
- examples/kernels/09_dyn_valid_shape.py collapses the original if/else-based factory into a caller-side scalar selection: the JIT specializer's alpha-renamer breaks if/else rebinding of the same scalar variable (also tracked as a known issue).
- examples/kernels/08_assemble.py and 09_dyn_valid_shape.py use `compile_for_test` smoke (no torch reference): `tile.assemble` and `valid_shapes`/`fillpad` semantics are best validated on device via the corresponding tests.

Net diff: 32 files changed, +1645 / -2927 (mostly removed `@pl.program`/orchestrator boilerplate).

Test plan

- `compile_for_test` succeeds for every migrated `@pl.jit` function (verified locally on the simulator IR pipeline).
- `pytest --collect-only` succeeds for all migrated tests.
- No `harness.core` imports remain in any migrated file.
- tests/st/examples/ + tests/st/runtime/ + tests/st/codegen/ tests run end-to-end.

Refs: #1320
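The "no harness.core imports remain" item in the test plan above can be enforced mechanically. The import name comes from this PR; the scanner itself is an illustrative sketch, not project tooling.

```python
# Scan a tree of .py files for lingering `harness.core` imports and report
# the offenders, then self-check the scanner against a throwaway tree.
import tempfile
from pathlib import Path


def files_importing_harness(root: Path) -> list[Path]:
    offenders = []
    for path in sorted(root.rglob("*.py")):
        text = path.read_text(encoding="utf-8")
        if any(
            line.strip().startswith(("import harness.core", "from harness.core"))
            for line in text.splitlines()
        ):
            offenders.append(path)
    return offenders


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "clean.py").write_text("import torch\n")      # migrated file
    (root / "legacy.py").write_text("from harness.core import harness\n")
    bad = files_importing_harness(root)
    assert [p.name for p in bad] == ["legacy.py"]
```

Run against `examples/` and `tests/st/`, an empty offender list means the migration left no legacy imports behind.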