-
Notifications
You must be signed in to change notification settings - Fork 13
Port single-core scan from pto-kernels #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
c83f830
Port single-core scan from pto-kernels
vloncar e6893f0
Follow-up the compile flag changes to comile tests
vloncar 116ddd3
remove printops
vloncar 60257ee
Merge remote-tracking branch 'upstream/main' into scan
vloncar 021d5f2
Merge branch 'main' into scan
vloncar 061c53f
Add TPush/TPop example of scan algorithm
vloncar 8b08ac1
Fix compilation of FA example
vloncar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # Single core prefix sum (scan) | ||
|
|
||
| An implementation of prefix sum (scan) algorithm, based on https://arxiv.org/abs/2505.15112v1. Only the single core algorithm is implemented (ScanU from the paper). | ||
|
|
||
| ## Algorithm | ||
|
|
||
| The ScanU algorithm computes the prefix sum (cumulative sum) of a 1-D input vector by decomposing it into tiles and leveraging the Cube unit's matrix multiply for parallelism within each tile: | ||
|
|
||
| 1. **Reshape** the flat input vector into tiles of shape `TILE_SIZE × TILE_SIZE`. | ||
| 2. **Precompute** an upper-triangular matrix `U` of 1s (shape `TILE_SIZE × TILE_SIZE`). | ||
| 3. **For each tile** `X_i`: | ||
| - **Cube** computes `C_i = X_i @ U`. Each row of `C_i` contains partial prefix sums within that row of the tile. | ||
| - **Vector** adds a running scalar sum to every element of each row (the cross-tile carry), then extracts the last element of the row as the new running sum. | ||
| 4. The concatenation of all processed tiles is the full prefix sum. | ||
|
|
||
| The Cube and Vector units run concurrently but must synchronize: the Cube cannot overwrite the next tile's result before the Vector has finished processing the current one. | ||
|
|
||
| ## Implementations | ||
|
|
||
| There are two implementations that differ only in how Cube ↔ Vector synchronization is achieved. The algorithm logic, tile types, memory layouts, and test harness are identical. | ||
|
|
||
| ### `run_scan_single_core.py` — TSync (sync_set / sync_wait) | ||
|
|
||
| Uses a single function with `cube_section` / `vector_section` blocks. Synchronization is performed with the low-level `sync_set` / `sync_wait` primitives operating on `PIPE_FIX` and `PIPE_MTE3`: | ||
|
|
||
| - Cube signals Vector via `pto.sync_set(pto.PIPE_FIX, 0)` after storing the matmul result to GM. | ||
| - Vector waits via `pto.sync_wait(pto.PIPE_FIX, 0)`, processes the tile, then signals back via `pto.sync_set(pto.PIPE_MTE3, 1)`. | ||
| - Cube waits for the Vector's acknowledgement via `pto.sync_wait(pto.PIPE_MTE3, 1)` before advancing. | ||
|
|
||
| ```bash | ||
| python ./run_scan_single_core.py | ||
| ``` | ||
|
|
||
| ### `run_scan_single_core_tpushpop.py` — TPush / TPop | ||
|
|
||
| Uses the structured multi-function module pattern with separate `cube_kernel` and `vector_kernel` functions. Synchronization uses two unidirectional TPush/TPop pipes: | ||
|
|
||
| | Pipe | Direction | `dir_mask` | Purpose | | ||
| |------|-----------|------------|---------| | ||
| | C2V | Cube → Vector | 1 | Sends the matmul ACC tile directly to Vector VEC memory | | ||
| | V2C | Vector → Cube | 2 | Sends a dummy signal tile back for back-pressure | | ||
|
|
||
| Both pipes use GM-staged L2G2L transport with `slot_num=8`. | ||
|
|
||
| The Cube pushes the ACC tile via `tpush_to_aiv`, then blocks on `tpop_from_aiv` (waiting for the Vector's V2C signal). The Vector pops with `tpop_from_aic`, stores the tile to GM, processes rows with the running sum, then pushes a signal via `tpush_to_aic` and frees the C2V slot with `tfree_from_aic`. | ||
|
|
||
| ```bash | ||
| python ./run_scan_single_core_tpushpop.py | ||
| ``` | ||
|
|
||
| ## Differences between TSync and TPush/TPop | ||
|
|
||
| | Aspect | TSync | TPush/TPop | | ||
| |--------|-------|------------| | ||
| | Code structure | Single function with `cube_section` / `vector_section` | Separate `@pto.func(kernel="cube")` and `@pto.func(kernel="vector")` functions | | ||
| | Sync mechanism | `sync_set` / `sync_wait` on hardware pipes | `tpush` / `tpop` / `tfree` on logical pipe handles | | ||
| | Data transfer | Cube stores to GM, Vector loads from GM | Cube pushes ACC tile through pipe, Vector pops into VEC memory, then stores to GM | | ||
| | Back-pressure | Explicit `sync_wait` on `PIPE_MTE3` | V2C pipe with dummy signal tile | | ||
| | GM slot buffer | Not needed | Required (FIFO staging area for both pipes) | | ||
| | Address management | N/A | `reserve_buffer` / `import_reserved_buffer` for cross-kernel FIFO address sharing | | ||
| | Insert sync | `enable_insert_sync=False` (manual sync) | `--enable-insert-sync` (PTOAS auto-inserts intra-pipe sync) | | ||
|
|
||
| ## Implementation notes | ||
|
|
||
| - The running sum calculation is stored in a tile (`sumTile1x8`) to avoid issues with the compiler removing it as unused code. | ||
| - Due to the inability to synchronize on the scalar pipe with `record_wait_pair`, a `barrier(PIPE_ALL)` is used as a workaround before extracting the last row element. | ||
| - **`record_wait_pair` and `--enable-insert-sync` are mutually exclusive.** The TSync version uses manual sync (`enable_insert_sync=False` + explicit `record_wait_pair` calls). The TPush/TPop version uses auto sync (`--enable-insert-sync`; PTOAS InsertSync pass inserts all needed `set_flag`/`wait_flag` pairs). Mixing both causes event ID collisions that lead to data races and non-deterministic results. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,306 @@ | ||
| import torch | ||
| import torch_npu | ||
| from ptodsl import jit, pto, tile | ||
| from ptodsl import scalar as s | ||
| from ptodsl.npu_info import get_test_device | ||
|
|
||
| TILE_SIZE = 64 | ||
|
|
||
| const = s.const | ||
|
|
||
|
|
||
| def meta_data(): | ||
| dtype = pto.float32 | ||
| ptr_type = pto.PtrType(dtype) | ||
| ffts_type = pto.ffts_type | ||
| len_type = pto.int32 | ||
|
|
||
| tensor_type = pto.TensorType(rank=2, dtype=dtype) | ||
|
|
||
| subtensor_type_u = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) | ||
| subtensor_type_a = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) | ||
| subtensor_type_c = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) | ||
| subtensor_type_row = pto.SubTensorType(shape=[1, TILE_SIZE], dtype=dtype) | ||
|
|
||
| tile_cfg_mat = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") | ||
| tile_cfg_left = pto.TileBufConfig(blayout="RowMajor", slayout="RowMajor") | ||
| tile_cfg_right = pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor") | ||
| tile_cfg_acc = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") | ||
|
|
||
| tile_type_a_l1 = pto.TileBufType( | ||
| shape=[TILE_SIZE, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="MAT", | ||
| config=tile_cfg_mat, | ||
| ) | ||
| tile_type_u_l1 = pto.TileBufType( | ||
| shape=[TILE_SIZE, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="MAT", | ||
| config=tile_cfg_mat, | ||
| ) | ||
|
|
||
| tile_type_a = pto.TileBufType( | ||
| shape=[TILE_SIZE, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="LEFT", | ||
| config=tile_cfg_left, | ||
| ) | ||
| tile_type_u = pto.TileBufType( | ||
| shape=[TILE_SIZE, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="RIGHT", | ||
| config=tile_cfg_right, | ||
| ) | ||
| tile_type_c = pto.TileBufType( | ||
| shape=[TILE_SIZE, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="ACC", | ||
| config=tile_cfg_acc, | ||
| ) | ||
|
|
||
| tile_type_row = pto.TileBufType( | ||
| shape=[1, TILE_SIZE], | ||
| valid_shape=[1, TILE_SIZE], | ||
| dtype=dtype, | ||
| memory_space="VEC", | ||
| config=pto.TileBufConfig(), | ||
| ) | ||
| tile_type_1x8 = pto.TileBufType( | ||
| shape=[1, 8], | ||
| valid_shape=[1, 8], | ||
| dtype=dtype, | ||
| memory_space="VEC", | ||
| config=pto.TileBufConfig(), | ||
| ) | ||
|
|
||
| return { | ||
| "ptr_type": ptr_type, | ||
| "ffts_type": ffts_type, | ||
| "len_type": len_type, | ||
| "tensor_type": tensor_type, | ||
| "subtensor_type_u": subtensor_type_u, | ||
| "subtensor_type_a": subtensor_type_a, | ||
| "subtensor_type_c": subtensor_type_c, | ||
| "subtensor_type_row": subtensor_type_row, | ||
| "tile_type_a_l1": tile_type_a_l1, | ||
| "tile_type_u_l1": tile_type_u_l1, | ||
| "tile_type_a": tile_type_a, | ||
| "tile_type_u": tile_type_u, | ||
| "tile_type_c": tile_type_c, | ||
| "tile_type_row": tile_type_row, | ||
| "tile_type_1x8": tile_type_1x8, | ||
| } | ||
|
|
||
|
|
||
| @jit(meta_data=meta_data, block_dim=1, enable_insert_sync=False, init_ffts="ffts_addr") | ||
| def run_scan_kernel( | ||
| x_ptr: "ptr_type", | ||
| y_ptr: "ptr_type", | ||
| u_ptr: "ptr_type", | ||
| total_len_i32: "len_type", | ||
| ) -> None: | ||
| c0 = const(0) | ||
| c1 = const(1) | ||
| cLAST_ROW_ELEM = const(TILE_SIZE - 1) | ||
| cTILE_SIZE = const(TILE_SIZE) | ||
| cN_TILE_ELEM = const(TILE_SIZE * TILE_SIZE) | ||
| c0f = const(0.0, s.float32) | ||
| c0i64 = const(0, s.int64) | ||
|
|
||
| total_len = s.index_cast(total_len_i32) | ||
| num_tiles = total_len // cN_TILE_ELEM | ||
|
|
||
| with pto.cube_section(): | ||
| tvX_cube = pto.as_tensor( | ||
| tensor_type, | ||
| ptr=x_ptr, | ||
| shape=[cTILE_SIZE, cTILE_SIZE], | ||
| strides=[cTILE_SIZE, cTILE_SIZE], | ||
| layout="ND", | ||
| ) | ||
| tvU = pto.as_tensor( | ||
| tensor_type, | ||
| ptr=u_ptr, | ||
| shape=[cTILE_SIZE, cTILE_SIZE], | ||
| strides=[cTILE_SIZE, cTILE_SIZE], | ||
| layout="ND", | ||
| ) | ||
| tvOut_cube = pto.as_tensor( | ||
| tensor_type, | ||
| ptr=y_ptr, | ||
| shape=[cTILE_SIZE, cTILE_SIZE], | ||
| strides=[cTILE_SIZE, cTILE_SIZE], | ||
| layout="ND", | ||
| ) | ||
|
|
||
| uTileL1 = pto.alloc_tile(tile_type_u_l1) | ||
| uTile = pto.alloc_tile(tile_type_u) | ||
| aTileL1 = pto.alloc_tile(tile_type_a_l1) | ||
| aTile = pto.alloc_tile(tile_type_a) | ||
| cTile = pto.alloc_tile(tile_type_c) | ||
|
|
||
| svU = pto.slice_view( | ||
| subtensor_type_u, | ||
| source=tvU, | ||
| offsets=[c0, c0], | ||
| sizes=[cTILE_SIZE, cTILE_SIZE], | ||
| ) | ||
| pto.load(svU, uTileL1) | ||
| pto.record_wait_pair("LOAD", "MOV_M2L", 0) | ||
|
|
||
| tile.mov(uTileL1, uTile) | ||
| pto.record_wait_pair("MOV_M2L", "MATMUL", 0) | ||
|
|
||
| for tile_idx in pto.range(c0, num_tiles, c1): | ||
| offset = tile_idx * cTILE_SIZE | ||
| svX = pto.slice_view( | ||
| subtensor_type_a, | ||
| source=tvX_cube, | ||
| offsets=[offset, c0], | ||
| sizes=[cTILE_SIZE, cTILE_SIZE], | ||
| ) | ||
| svOut = pto.slice_view( | ||
| subtensor_type_c, | ||
| source=tvOut_cube, | ||
| offsets=[offset, c0], | ||
| sizes=[cTILE_SIZE, cTILE_SIZE], | ||
| ) | ||
|
|
||
| pto.load(svX, aTileL1) | ||
| pto.record_wait_pair("LOAD", "MOV_M2L", 1) | ||
|
|
||
| tile.mov(aTileL1, aTile) | ||
| pto.record_wait_pair("MOV_M2L", "MATMUL", 1) | ||
|
|
||
| tile.matmul(aTile, uTile, cTile) | ||
| pto.record_wait_pair("MATMUL", "STORE_ACC", 1) | ||
|
|
||
| pto.store(cTile, svOut) | ||
| pto.record_wait_pair("STORE_ACC", "LOAD", 2) | ||
|
|
||
| pto.sync_set(pto.PIPE_FIX, 0) | ||
| pto.sync_wait(pto.PIPE_MTE3, 1) | ||
|
|
||
| with pto.vector_section(): | ||
| tvOut_vec = pto.as_tensor( | ||
| tensor_type, | ||
| ptr=y_ptr, | ||
| shape=[total_len // cTILE_SIZE, cTILE_SIZE], | ||
| strides=[cTILE_SIZE, c1], | ||
| ) | ||
|
|
||
| rowTile = pto.alloc_tile(tile_type_row) | ||
| sumTile1x8 = pto.alloc_tile(tile_type_1x8) | ||
| tile.setval(sumTile1x8, c0, c0f) | ||
|
|
||
| for tile_idx in pto.range(c0, num_tiles, c1): | ||
| pto.sync_wait(pto.PIPE_FIX, 0) | ||
|
|
||
| vid = pto.get_subblock_idx() | ||
|
|
||
| with pto.if_context(vid == c0i64): | ||
| tile_offset = tile_idx * cTILE_SIZE | ||
| for r in pto.range(c0, cTILE_SIZE, c1): | ||
| offset = tile_offset + r | ||
| svRow = pto.slice_view( | ||
| subtensor_type_row, | ||
| source=tvOut_vec, | ||
| offsets=[offset, c0], | ||
| sizes=[c1, cTILE_SIZE], | ||
| ) | ||
|
|
||
| pto.load(svRow, rowTile) | ||
|
|
||
| pto.record_wait_pair("LOAD", "VEC", 2) | ||
|
|
||
| # Extract the stateful running_sum from our memory buffer | ||
| running_sum = tile.getval(sumTile1x8, c0, dtype=s.float32) | ||
|
|
||
| tile.adds(rowTile, running_sum, rowTile) | ||
| # Ideally we would synchronize PIPE_S and PIPE_V here, but that is not currently possible | ||
| # with pto.record_wait_pair, instead we use a barrier | ||
| # pto.record_wait_pair("PIPE_V", "PIPE_S", 2) | ||
| pto.barrier(pto.PIPE_ALL) | ||
| running_sum_next = tile.getval( | ||
| rowTile, cLAST_ROW_ELEM, dtype=s.float32 | ||
| ) | ||
| # Persist the new running sum back to the memory buffer to loop-carry | ||
| tile.setval(sumTile1x8, c0, running_sum_next) | ||
| pto.record_wait_pair("VEC", "STORE_VEC", 2) | ||
|
|
||
| pto.store(rowTile, svRow) | ||
|
|
||
| pto.record_wait_pair("STORE_VEC", "LOAD", 3) | ||
|
|
||
| pto.record_wait_pair("LOAD", "VEC", 3) | ||
|
|
||
| pto.sync_set(pto.PIPE_MTE3, 1) | ||
|
|
||
|
|
||
| def test_scan(n_tiles=64): | ||
| device = get_test_device() | ||
| torch.npu.set_device(device) | ||
|
|
||
| total_len = TILE_SIZE * TILE_SIZE * n_tiles | ||
| torch.manual_seed(0) | ||
| dtype = torch.float32 | ||
|
|
||
| # Prepare Inputs | ||
| x = torch.rand(size=(total_len,), device=device, dtype=dtype).contiguous() | ||
| y = torch.zeros_like(x) | ||
|
|
||
| # Generate upper triangular matrix of 1s (s x s) | ||
| u = torch.triu( | ||
| torch.ones((TILE_SIZE, TILE_SIZE), device=device, dtype=dtype) | ||
| ).contiguous() | ||
|
|
||
| # Expected PyTorch computation | ||
| expected_scan = torch.cumsum(x.cpu(), dim=0) | ||
|
|
||
| # NPU JIT Kernel execution | ||
| repeat_runs = 20 | ||
| print( | ||
| f"Running scan for {total_len} elements ({n_tiles} {TILE_SIZE}x{TILE_SIZE} tiles)" | ||
| ) | ||
| actual_scan = [] | ||
| for _ in range(repeat_runs): | ||
| y.zero_() | ||
| run_scan_kernel(x, y, u, total_len) | ||
| actual_scan.append(y.cpu().clone()) | ||
|
|
||
| torch.npu.synchronize() | ||
|
|
||
| # Check for consistency across runs and correctness against the expected count | ||
| repeat_results = [] | ||
| for i, scan in enumerate(actual_scan): | ||
| are_close = torch.allclose(scan, expected_scan, rtol=1e-3, atol=1e-3) | ||
| if not are_close: | ||
| unequal_count = torch.sum(scan != expected_scan) | ||
| else: | ||
| unequal_count = 0 | ||
| repeat_results.append([are_close, unequal_count]) | ||
|
|
||
| has_mismatch = any(not eq for eq, _ in repeat_results) | ||
| if has_mismatch: | ||
| print("Expected:\n", expected_scan[-10:]) | ||
| for i, result in enumerate(repeat_results): | ||
| eq, count = result | ||
| if not eq: | ||
| print( | ||
| f"Inconsistent results run {i}, different elements: {count}/{total_len}. Sample:" | ||
| ) | ||
| print(actual_scan[i][-10:]) | ||
| raise AssertionError( | ||
| f"Scan mismatch for tile_size={tile_size}, total_len={total_len} ({n_tiles} tiles)" | ||
| ) | ||
|
|
||
| print("All results matched. Scan test passed successfully.\n") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_scan(1) | ||
| test_scan(16) | ||
| test_scan(64) | ||
| test_scan(100) | ||
| test_scan(1000) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the
with pto.vector_section()syntax actually works for mix kernels as long as using manual sync + explicit sync_set/sync_wait... The syntax is quite different from the entry + subfunction + push/pop + auto-sync as in #98 . So we have two styles to write mix kernels now. They can co-exist for now, but will need to unify the style in the future.