feat: Add PTO-DSL GEMM performance kernel with Python build and validation#116
feat: Add PTO-DSL GEMM performance kernel with Python build and validation#116Crystal-wzy wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a PTO-DSL implementation of a GEMM performance kernel for Ascend A2/A3 processors, along with a C++ wrapper, build scripts, and a Python-based validation and benchmarking suite. The kernel features a four-stage pipeline with double buffering and multi-core partitioning. Feedback identifies a critical issue where the Python script fails to output the actual IR, as well as opportunities to improve build flexibility by making the NPU architecture configurable and ensuring broader DSL compatibility by avoiding list arguments in synchronization primitives.
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| print(build()) |
There was a problem hiding this comment.
The build() function returns the GemmPerformance kernel function object decorated with @to_ir_module. To generate and print the actual PTO IR (which is what compile.sh expects to capture), you must invoke the returned function. Printing the function object itself will only output its string representation (e.g., <function GemmPerformance at ...>), resulting in an invalid .pto file and causing the build to fail.
| print(build()) | |
| print(build()()) |
| PTO_LIB_PATH=${PTO_LIB_PATH:-$(cd "${SCRIPT_DIR}/../../.." && pwd)} | ||
|
|
||
| bisheng -fPIC -shared -xcce -O2 -std=c++17 \ | ||
| --npu-arch=dav-2201 \ |
There was a problem hiding this comment.
The NPU architecture is hardcoded to dav-2201 (which corresponds to the A2 platform). Since this kernel is intended to support both A2 and A3 (as stated in the README), hardcoding the architecture might prevent the compiler from applying A3-specific optimizations or could lead to compatibility issues on A3 hardware. Consider making the architecture configurable via an environment variable.
| --npu-arch=dav-2201 \ | |
| --npu-arch=${NPU_ARCH:-dav-2201} \ |
| pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1]) | ||
| pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) |
There was a problem hiding this comment.
Passing a list to event_id in pto.record_event may not be supported by all versions of the PTO-DSL or might lead to unexpected IR generation if the backend expects a single integer constant. It is safer and more explicit to record each event individually to ensure both synchronization slots are correctly initialized.
pto.record_event("MOV_M2L", "LOAD", event_id=0)
pto.record_event("MOV_M2L", "LOAD", event_id=1)
pto.record_event("MATMUL", "MOV_M2L", event_id=0)
pto.record_event("MATMUL", "MOV_M2L", event_id=1)| pto.wait_event("MOV_M2L", "LOAD", event_id=[0, 1]) | ||
| pto.wait_event("MATMUL", "MOV_M2L", event_id=[0, 1]) |
There was a problem hiding this comment.
Similar to the initialization at the start of the kernel, pto.wait_event should ideally be called for each event ID individually if the DSL does not explicitly support list inputs for synchronization primitives.
| pto.wait_event("MOV_M2L", "LOAD", event_id=[0, 1]) | |
| pto.wait_event("MATMUL", "MOV_M2L", event_id=[0, 1]) | |
| pto.wait_event("MOV_M2L", "LOAD", event_id=0) | |
| pto.wait_event("MOV_M2L", "LOAD", event_id=1) | |
| pto.wait_event("MATMUL", "MOV_M2L", event_id=0) | |
| pto.wait_event("MATMUL", "MOV_M2L", event_id=1) |
|
Triage review (2026-05-08): this PR is in good merge shape. GitHub reports it as clean against Two pre-merge checks I recommend: confirm the build artifacts listed in the README ( |
…ation (hw-native-sys#116) ## Summary - Add `kernels/python/gemm_performance/` with a PTO-DSL GEMM kernel (`gemm_performance.py`) targeting A2/A3, featuring 2D multi-core tiling, L1 double buffering, L0 ping-pong buffers, and a four-stage LOAD→EXTRACT→MATMUL→STORE pipeline - Add `run_gemm.py` runner with multiple shape presets, auto-build (ptoas + bisheng), Torch-NPU correctness validation, and benchmark with optional `torch.matmul` baseline comparison - Add `compile.sh` for standalone `.pto` → `.cpp` → `.so` build flow - Add bilingual README (English and Chinese) documenting operator description, tiling parameters, prerequisites, build/run instructions, and available presets - Register `python/gemm_performance/` in the top-level `kernels/` README ## Testing - [x] Correctness test passes on A2/A3 hardware (`python3 run_gemm.py`) - [x] Benchmark mode reports expected TFLOPS (`python3 run_gemm.py --benchmark`) - [x] Pre-commit hooks pass (ruff, pyright, markdownlint)
5a61598 to
7705674
Compare
…ation ## Summary - Add `kernels/python/gemm_performance/` with a PTO-DSL GEMM kernel (`gemm_performance.py`) targeting A2/A3, featuring 2D multi-core tiling, L1 double buffering, L0 ping-pong buffers, and a four-stage LOAD→EXTRACT→MATMUL→STORE pipeline - Add `run_gemm.py` runner with multiple shape presets, auto-build (ptoas + bisheng), Torch-NPU correctness validation, and benchmark with optional `torch.matmul` baseline comparison - Add `compile.sh` for standalone `.pto` → `.cpp` → `.so` build flow - Add bilingual README (English and Chinese) documenting operator description, tiling parameters, prerequisites, build/run instructions, and available presets - Register `python/gemm_performance/` in the top-level `kernels/` README ## Testing - [x] Correctness test passes on A2/A3 hardware (`python3 run_gemm.py`) - [x] Benchmark mode reports expected TFLOPS (`python3 run_gemm.py --benchmark`) - [x] Pre-commit hooks pass (ruff, pyright, markdownlint)
Summary
kernels/python/gemm_performance/with a PTO-DSL GEMM kernel(
gemm_performance.py) targeting A2/A3, featuring 2D multi-core tiling,L1 double buffering, L0 ping-pong buffers, and a four-stage
LOAD→EXTRACT→MATMUL→STORE pipeline
run_gemm.pyrunner with multiple shape presets, auto-build(ptoas + bisheng), Torch-NPU correctness validation, and benchmark
with optional
torch.matmulbaseline comparisoncompile.shfor standalone.pto→.cpp→.sobuild flowdescription, tiling parameters, prerequisites, build/run instructions,
and available presets
python/gemm_performance/in the top-levelkernels/READMETesting
python3 run_gemm.py)python3 run_gemm.py --benchmark)