diff --git a/examples/aot/flash_attention/caller.cpp b/examples/aot/flash_attention/caller.cpp index 87546451..abdbfb9f 100644 --- a/examples/aot/flash_attention/caller.cpp +++ b/examples/aot/flash_attention/caller.cpp @@ -23,7 +23,7 @@ extern "C" void call_kernel( (void)fftsLen; call_both<<>>( - (__gm__ int64_t *)fftsAddr, + (__gm__ uint64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, (__gm__ half *)q, (__gm__ half *)k, diff --git a/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp b/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp index b8a9e8b2..02fa4b50 100644 --- a/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp +++ b/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp @@ -21,7 +21,7 @@ extern "C" void call_kernel( (void)fftsLen; call_both<<>>( - (__gm__ int64_t *)fftsAddr, + (__gm__ uint64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, (__gm__ float *)x, (__gm__ float *)y); diff --git a/examples/jit/scan/README.md b/examples/jit/scan/README.md new file mode 100644 index 00000000..2ef01a19 --- /dev/null +++ b/examples/jit/scan/README.md @@ -0,0 +1,67 @@ +# Single core prefix sum (scan) + +An implementation of prefix sum (scan) algorithm, based on https://arxiv.org/abs/2505.15112v1. Only the single core algorithm is implemented (ScanU from the paper). + +## Algorithm + +The ScanU algorithm computes the prefix sum (cumulative sum) of a 1-D input vector by decomposing it into tiles and leveraging the Cube unit's matrix multiply for parallelism within each tile: + +1. **Reshape** the flat input vector into tiles of shape `TILE_SIZE × TILE_SIZE`. +2. **Precompute** an upper-triangular matrix `U` of 1s (shape `TILE_SIZE × TILE_SIZE`). +3. **For each tile** `X_i`: + - **Cube** computes `C_i = X_i @ U`. Each row of `C_i` contains partial prefix sums within that row of the tile. + - **Vector** adds a running scalar sum to every element of each row (the cross-tile carry), then extracts the last element of the row as the new running sum. +4. The concatenation of all processed tiles is the full prefix sum. + +The Cube and Vector units run concurrently but must synchronize: the Cube cannot overwrite the next tile's result before the Vector has finished processing the current one. + +## Implementations + +There are two implementations that differ only in how Cube ↔ Vector synchronization is achieved. The algorithm logic, tile types, memory layouts, and test harness are identical. + +### `run_scan_single_core.py` — TSync (sync_set / sync_wait) + +Uses a single function with `cube_section` / `vector_section` blocks. Synchronization is performed with the low-level `sync_set` / `sync_wait` primitives operating on `PIPE_FIX` and `PIPE_MTE3`: + +- Cube signals Vector via `pto.sync_set(pto.PIPE_FIX, 0)` after storing the matmul result to GM. +- Vector waits via `pto.sync_wait(pto.PIPE_FIX, 0)`, processes the tile, then signals back via `pto.sync_set(pto.PIPE_MTE3, 1)`. +- Cube waits for the Vector's acknowledgement via `pto.sync_wait(pto.PIPE_MTE3, 1)` before advancing. + +```bash +python ./run_scan_single_core.py +``` + +### `run_scan_single_core_tpushpop.py` — TPush / TPop + +Uses the structured multi-function module pattern with separate `cube_kernel` and `vector_kernel` functions. Synchronization uses two unidirectional TPush/TPop pipes: + +| Pipe | Direction | `dir_mask` | Purpose | +|------|-----------|------------|---------| +| C2V | Cube → Vector | 1 | Sends the matmul ACC tile directly to Vector VEC memory | +| V2C | Vector → Cube | 2 | Sends a dummy signal tile back for back-pressure | + +Both pipes use GM-staged L2G2L transport with `slot_num=8`. + +The Cube pushes the ACC tile via `tpush_to_aiv`, then blocks on `tpop_from_aiv` (waiting for the Vector's V2C signal). The Vector pops with `tpop_from_aic`, stores the tile to GM, processes rows with the running sum, then pushes a signal via `tpush_to_aic` and frees the C2V slot with `tfree_from_aic`. + +```bash +python ./run_scan_single_core_tpushpop.py +``` + +## Differences between TSync and TPush/TPop + +| Aspect | TSync | TPush/TPop | +|--------|-------|------------| +| Code structure | Single function with `cube_section` / `vector_section` | Separate `@pto.func(kernel="cube")` and `@pto.func(kernel="vector")` functions | +| Sync mechanism | `sync_set` / `sync_wait` on hardware pipes | `tpush` / `tpop` / `tfree` on logical pipe handles | +| Data transfer | Cube stores to GM, Vector loads from GM | Cube pushes ACC tile through pipe, Vector pops into VEC memory, then stores to GM | +| Back-pressure | Explicit `sync_wait` on `PIPE_MTE3` | V2C pipe with dummy signal tile | +| GM slot buffer | Not needed | Required (FIFO staging area for both pipes) | +| Address management | N/A | `reserve_buffer` / `import_reserved_buffer` for cross-kernel FIFO address sharing | +| Insert sync | `enable_insert_sync=False` (manual sync) | `--enable-insert-sync` (PTOAS auto-inserts intra-pipe sync) | + +## Implementation notes + +- The running sum calculation is stored in a tile (`sumTile1x8`) to avoid issues with the compiler removing it as unused code. +- Due to the inability to synchronize on the scalar pipe with `record_wait_pair`, a `barrier(PIPE_ALL)` is used as a workaround before extracting the last row element. +- **`record_wait_pair` and `--enable-insert-sync` are mutually exclusive.** The TSync version uses manual sync (`enable_insert_sync=False` + explicit `record_wait_pair` calls). The TPush/TPop version uses auto sync (`--enable-insert-sync`; PTOAS InsertSync pass inserts all needed `set_flag`/`wait_flag` pairs). Mixing both causes event ID collisions that lead to data races and non-deterministic results. diff --git a/examples/jit/scan/run_scan_single_core.py b/examples/jit/scan/run_scan_single_core.py new file mode 100644 index 00000000..8ecdcde6 --- /dev/null +++ b/examples/jit/scan/run_scan_single_core.py @@ -0,0 +1,306 @@ +import torch +import torch_npu +from ptodsl import jit, pto, tile +from ptodsl import scalar as s +from ptodsl.npu_info import get_test_device + +TILE_SIZE = 64 + +const = s.const + + +def meta_data(): + dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + ffts_type = pto.ffts_type + len_type = pto.int32 + + tensor_type = pto.TensorType(rank=2, dtype=dtype) + + subtensor_type_u = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_a = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_c = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_row = pto.SubTensorType(shape=[1, TILE_SIZE], dtype=dtype) + + tile_cfg_mat = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") + tile_cfg_left = pto.TileBufConfig(blayout="RowMajor", slayout="RowMajor") + tile_cfg_right = pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor") + tile_cfg_acc = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") + + tile_type_a_l1 = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="MAT", + config=tile_cfg_mat, + ) + tile_type_u_l1 = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="MAT", + config=tile_cfg_mat, + ) + + tile_type_a = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="LEFT", + config=tile_cfg_left, + ) + tile_type_u = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="RIGHT", + config=tile_cfg_right, + ) + tile_type_c = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="ACC", + config=tile_cfg_acc, + ) + + tile_type_row = pto.TileBufType( + shape=[1, TILE_SIZE], + valid_shape=[1, TILE_SIZE], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) + tile_type_1x8 = pto.TileBufType( + shape=[1, 8], + valid_shape=[1, 8], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) + + return { + "ptr_type": ptr_type, + "ffts_type": ffts_type, + "len_type": len_type, + "tensor_type": tensor_type, + "subtensor_type_u": subtensor_type_u, + "subtensor_type_a": subtensor_type_a, + "subtensor_type_c": subtensor_type_c, + "subtensor_type_row": subtensor_type_row, + "tile_type_a_l1": tile_type_a_l1, + "tile_type_u_l1": tile_type_u_l1, + "tile_type_a": tile_type_a, + "tile_type_u": tile_type_u, + "tile_type_c": tile_type_c, + "tile_type_row": tile_type_row, + "tile_type_1x8": tile_type_1x8, + } + + +@jit(meta_data=meta_data, block_dim=1, enable_insert_sync=False, init_ffts="ffts_addr") +def run_scan_kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + u_ptr: "ptr_type", + total_len_i32: "len_type", +) -> None: + c0 = const(0) + c1 = const(1) + cLAST_ROW_ELEM = const(TILE_SIZE - 1) + cTILE_SIZE = const(TILE_SIZE) + cN_TILE_ELEM = const(TILE_SIZE * TILE_SIZE) + c0f = const(0.0, s.float32) + c0i64 = const(0, s.int64) + + total_len = s.index_cast(total_len_i32) + num_tiles = total_len // cN_TILE_ELEM + + with pto.cube_section(): + tvX_cube = pto.as_tensor( + tensor_type, + ptr=x_ptr, + shape=[cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, cTILE_SIZE], + layout="ND", + ) + tvU = pto.as_tensor( + tensor_type, + ptr=u_ptr, + shape=[cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, cTILE_SIZE], + layout="ND", + ) + tvOut_cube = pto.as_tensor( + tensor_type, + ptr=y_ptr, + shape=[cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, cTILE_SIZE], + layout="ND", + ) + + uTileL1 = pto.alloc_tile(tile_type_u_l1) + uTile = pto.alloc_tile(tile_type_u) + aTileL1 = pto.alloc_tile(tile_type_a_l1) + aTile = pto.alloc_tile(tile_type_a) + cTile = pto.alloc_tile(tile_type_c) + + svU = pto.slice_view( + subtensor_type_u, + source=tvU, + offsets=[c0, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + pto.load(svU, uTileL1) + pto.record_wait_pair("LOAD", "MOV_M2L", 0) + + tile.mov(uTileL1, uTile) + pto.record_wait_pair("MOV_M2L", "MATMUL", 0) + + for tile_idx in pto.range(c0, num_tiles, c1): + offset = tile_idx * cTILE_SIZE + svX = pto.slice_view( + subtensor_type_a, + source=tvX_cube, + offsets=[offset, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + svOut = pto.slice_view( + subtensor_type_c, + source=tvOut_cube, + offsets=[offset, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + + pto.load(svX, aTileL1) + pto.record_wait_pair("LOAD", "MOV_M2L", 1) + + tile.mov(aTileL1, aTile) + pto.record_wait_pair("MOV_M2L", "MATMUL", 1) + + tile.matmul(aTile, uTile, cTile) + pto.record_wait_pair("MATMUL", "STORE_ACC", 1) + + pto.store(cTile, svOut) + pto.record_wait_pair("STORE_ACC", "LOAD", 2) + + pto.sync_set(pto.PIPE_FIX, 0) + pto.sync_wait(pto.PIPE_MTE3, 1) + + with pto.vector_section(): + tvOut_vec = pto.as_tensor( + tensor_type, + ptr=y_ptr, + shape=[total_len // cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, c1], + ) + + rowTile = pto.alloc_tile(tile_type_row) + sumTile1x8 = pto.alloc_tile(tile_type_1x8) + tile.setval(sumTile1x8, c0, c0f) + + for tile_idx in pto.range(c0, num_tiles, c1): + pto.sync_wait(pto.PIPE_FIX, 0) + + vid = pto.get_subblock_idx() + + with pto.if_context(vid == c0i64): + tile_offset = tile_idx * cTILE_SIZE + for r in pto.range(c0, cTILE_SIZE, c1): + offset = tile_offset + r + svRow = pto.slice_view( + subtensor_type_row, + source=tvOut_vec, + offsets=[offset, c0], + sizes=[c1, cTILE_SIZE], + ) + + pto.load(svRow, rowTile) + + pto.record_wait_pair("LOAD", "VEC", 2) + + # Extract the stateful running_sum from our memory buffer + running_sum = tile.getval(sumTile1x8, c0, dtype=s.float32) + + tile.adds(rowTile, running_sum, rowTile) + # Ideally we would synchronize PIPE_S and PIPE_V here, but that is not currently possible + # with pto.record_wait_pair, instead we use a barrier + # pto.record_wait_pair("PIPE_V", "PIPE_S", 2) + pto.barrier(pto.PIPE_ALL) + running_sum_next = tile.getval( + rowTile, cLAST_ROW_ELEM, dtype=s.float32 + ) + # Persist the new running sum back to the memory buffer to loop-carry + tile.setval(sumTile1x8, c0, running_sum_next) + pto.record_wait_pair("VEC", "STORE_VEC", 2) + + pto.store(rowTile, svRow) + + pto.record_wait_pair("STORE_VEC", "LOAD", 3) + + pto.record_wait_pair("LOAD", "VEC", 3) + + pto.sync_set(pto.PIPE_MTE3, 1) + + +def test_scan(n_tiles=64): + device = get_test_device() + torch.npu.set_device(device) + + total_len = TILE_SIZE * TILE_SIZE * n_tiles + torch.manual_seed(0) + dtype = torch.float32 + + # Prepare Inputs + x = torch.rand(size=(total_len,), device=device, dtype=dtype).contiguous() + y = torch.zeros_like(x) + + # Generate upper triangular matrix of 1s (s x s) + u = torch.triu( + torch.ones((TILE_SIZE, TILE_SIZE), device=device, dtype=dtype) + ).contiguous() + + # Expected PyTorch computation + expected_scan = torch.cumsum(x.cpu(), dim=0) + + # NPU JIT Kernel execution + repeat_runs = 20 + print( + f"Running scan for {total_len} elements ({n_tiles} {TILE_SIZE}x{TILE_SIZE} tiles)" + ) + actual_scan = [] + for _ in range(repeat_runs): + y.zero_() + run_scan_kernel(x, y, u, total_len) + actual_scan.append(y.cpu().clone()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + repeat_results = [] + for i, scan in enumerate(actual_scan): + are_close = torch.allclose(scan, expected_scan, rtol=1e-3, atol=1e-3) + if not are_close: + unequal_count = torch.sum(scan != expected_scan) + else: + unequal_count = 0 + repeat_results.append([are_close, unequal_count]) + + has_mismatch = any(not eq for eq, _ in repeat_results) + if has_mismatch: + print("Expected:\n", expected_scan[-10:]) + for i, result in enumerate(repeat_results): + eq, count = result + if not eq: + print( + f"Inconsistent results run {i}, different elements: {count}/{total_len}. Sample:" + ) + print(actual_scan[i][-10:]) + raise AssertionError( + f"Scan mismatch for tile_size={tile_size}, total_len={total_len} ({n_tiles} tiles)" + ) + + print("All results matched. Scan test passed successfully.\n") + + +if __name__ == "__main__": + test_scan(1) + test_scan(16) + test_scan(64) + test_scan(100) + test_scan(1000) diff --git a/examples/jit/scan/run_scan_single_core_tpushpop.py b/examples/jit/scan/run_scan_single_core_tpushpop.py new file mode 100644 index 00000000..d2053d1e --- /dev/null +++ b/examples/jit/scan/run_scan_single_core_tpushpop.py @@ -0,0 +1,471 @@ +"""Prefix-sum (scan) on a single AI Core — TPush/TPop synchronization. + +This is a port of ``run_scan_single_core.py`` that replaces the legacy +``sync_set`` / ``sync_wait`` (TSync-based) Cube ↔ Vector handshake with the +structured TPush / TPop pipe primitives. + +Algorithm overview (unchanged from the TSync version): + 1. Cube loads U (upper-triangular matrix of 1s) once. + 2. For each input tile X_i: + a. Cube computes C_i = X_i @ U (matmul → partial prefix sums). + b. Cube pushes the ACC tile to Vector via the C2V pipe. + c. Vector pops the tile, adds the running sum to each row, + stores the result back to GM. + d. Vector pushes a (dummy) signal tile back to Cube via V2C pipe + so Cube knows the tile is done and can advance. + +Synchronization: + - C2V pipe (dir_mask=1): Cube pushes ACC tiles to Vector. + - V2C pipe (dir_mask=2): Vector pushes a signal tile back to Cube. + Both pipes use GM-staged L2G2L transport (A2/A3 path). + +Usage: + python run_scan_single_core_tpushpop.py +""" + +import torch +import torch_npu +from ptodsl import jit, pto, tile +from ptodsl import scalar as s +from ptodsl.npu_info import get_test_device + +TILE_SIZE = 64 + +const = s.const + +# --- Pipe parameters --- +# C2V: Cube pushes TILE_SIZE x TILE_SIZE f32 ACC tiles to Vector. +# slot_size = TILE_SIZE * TILE_SIZE * sizeof(f32) +SLOT_SIZE_C2V = TILE_SIZE * TILE_SIZE * 4 +# V2C: Vector pushes a tiny 1x8 f32 VEC signal tile back to Cube. +# We use a minimal 1x8 tile (32 bytes, but slot_size must cover the +# fractal footprint; for a 1×8 f32 VEC tile the minimum is 1*8*4=32 bytes, +# but pto-isa rounds up internally). We keep this at 32 bytes. +SLOT_SIZE_V2C = 1 * 8 * 4 + +# dir_mask=1 or 2 → slot_num=8 (§4.4 of the design doc). +SLOT_NUM = 8 + +# GM buffer sizes (in f32 elements): slot_size * slot_num / sizeof(f32) +GM_C2V_ELEMS = (SLOT_SIZE_C2V * SLOT_NUM) // 4 +GM_V2C_ELEMS = (SLOT_SIZE_V2C * SLOT_NUM) // 4 +GM_TOTAL_ELEMS = GM_C2V_ELEMS + GM_V2C_ELEMS + +# reserve_buffer sizes (in bytes) +C2V_FIFO_BYTES = SLOT_SIZE_C2V * SLOT_NUM # consumer-side VEC buffer +V2C_FIFO_BYTES = SLOT_SIZE_V2C * SLOT_NUM # consumer-side MAT buffer + +# Frontend pipe IDs (arbitrary, must be unique per-function). +ID_C2V = 0 +ID_V2C = 1 + +SPLIT_NONE = 0 + + +def meta_data(): + dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + ffts_type = pto.ffts_type + len_type = pto.int32 + i32 = pto.int32 + + tensor_type = pto.TensorType(rank=2, dtype=dtype) + + subtensor_type_u = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_a = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_c = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype) + subtensor_type_row = pto.SubTensorType(shape=[1, TILE_SIZE], dtype=dtype) + + tile_cfg_mat = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") + tile_cfg_left = pto.TileBufConfig(blayout="RowMajor", slayout="RowMajor") + tile_cfg_right = pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor") + tile_cfg_acc = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor") + + tile_type_a_l1 = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="MAT", + config=tile_cfg_mat, + ) + tile_type_u_l1 = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="MAT", + config=tile_cfg_mat, + ) + + tile_type_a = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="LEFT", + config=tile_cfg_left, + ) + tile_type_u = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="RIGHT", + config=tile_cfg_right, + ) + tile_type_c = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="ACC", + config=tile_cfg_acc, + ) + + # Vector-side tile type for receiving the C2V accumulator pop. + # The ACC tile from cube is popped as a VEC-space tile. + tile_type_c_vec = pto.TileBufType( + shape=[TILE_SIZE, TILE_SIZE], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) + + tile_type_row = pto.TileBufType( + shape=[1, TILE_SIZE], + valid_shape=[1, TILE_SIZE], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) + tile_type_1x8 = pto.TileBufType( + shape=[1, 8], + valid_shape=[1, 8], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) + + # Signal tile: Cube side receives into MAT space for V2C pop. + tile_type_signal_mat = pto.TileBufType( + shape=[1, 8], + valid_shape=[1, 8], + dtype=dtype, + memory_space="MAT", + config=pto.TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=32, + ), + ) + + return locals() + + +@jit( + meta_data=meta_data, + block_dim=1, + module=True, + enable_insert_sync=True, + init_ffts="ffts_addr", +) +def scan_module(): + # --------------------------------------------------------------- + # Cube kernel + # --------------------------------------------------------------- + @pto.func(kernel="cube") + def cube_kernel( + gm_slot_buffer: "ptr_type", + x_ptr: "ptr_type", + u_ptr: "ptr_type", + total_len_i32: "len_type", + ) -> None: + c0 = const(0) + c1 = const(1) + cTILE_SIZE = const(TILE_SIZE) + cN_TILE_ELEM = const(TILE_SIZE * TILE_SIZE) + c0_i32 = const(0, type=i32) + + total_len = s.index_cast(total_len_i32) + num_tiles = total_len // cN_TILE_ELEM + + # --- GM slot buffer partitioning (per-block) --- + # Single-core, so block_idx is always 0, but keep the pattern + # for correctness in a multi-block future. + gm_c2v = gm_slot_buffer + gm_v2c = pto.add_ptr(gm_slot_buffer, const(GM_C2V_ELEMS)) + + # --- Initialize C2V pipe (Cube is producer, dir_mask=1) --- + c2v_import = pto.import_reserved_buffer( + name="c2v_fifo", peer_func="@vector_kernel" + ) + pto.aic_initialize_pipe( + id=ID_C2V, + dir_mask=1, + slot_size=SLOT_SIZE_C2V, + gm_slot_buffer=gm_c2v, + c2v_consumer_buf=c2v_import, + v2c_consumer_buf=c0_i32, + ) + + # --- Initialize V2C pipe (Cube is consumer, dir_mask=2) --- + v2c_local = pto.reserve_buffer( + name="v2c_fifo", size=V2C_FIFO_BYTES, location="MAT" + ) + pto.aic_initialize_pipe( + id=ID_V2C, + dir_mask=2, + slot_size=SLOT_SIZE_V2C, + gm_slot_buffer=gm_v2c, + c2v_consumer_buf=c0_i32, + v2c_consumer_buf=v2c_local, + ) + + # --- Tile allocations --- + uTileL1 = pto.alloc_tile(tile_type_u_l1) + uTile = pto.alloc_tile(tile_type_u) + aTileL1 = pto.alloc_tile(tile_type_a_l1) + aTile = pto.alloc_tile(tile_type_a) + cTile = pto.alloc_tile(tile_type_c) + + # --- Load U matrix once --- + tvU = pto.as_tensor( + tensor_type, + ptr=u_ptr, + shape=[cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, cTILE_SIZE], + layout="ND", + ) + svU = pto.slice_view( + subtensor_type_u, + source=tvU, + offsets=[c0, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + pto.load(svU, uTileL1) + tile.mov(uTileL1, uTile) + + # --- GM tensor views for input --- + tvX = pto.as_tensor( + tensor_type, + ptr=x_ptr, + shape=[cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, cTILE_SIZE], + layout="ND", + ) + + for tile_idx in pto.range(c0, num_tiles, c1): + offset = tile_idx * cTILE_SIZE + + svX = pto.slice_view( + subtensor_type_a, + source=tvX, + offsets=[offset, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + + pto.load(svX, aTileL1) + + tile.mov(aTileL1, aTile) + + tile.matmul(aTile, uTile, cTile) + + pto.tpush_to_aiv(cTile, SPLIT_NONE, id=ID_C2V) + + signal = pto.tpop_from_aiv(tile_type_signal_mat, SPLIT_NONE, id=ID_V2C) + pto.tfree_from_aiv(SPLIT_NONE, id=ID_V2C) + + # --------------------------------------------------------------- + # Vector kernel + # --------------------------------------------------------------- + @pto.func(kernel="vector") + def vector_kernel( + gm_slot_buffer: "ptr_type", + y_ptr: "ptr_type", + total_len_i32: "len_type", + ) -> None: + c0 = const(0) + c1 = const(1) + cLAST_ROW_ELEM = const(TILE_SIZE - 1) + cTILE_SIZE = const(TILE_SIZE) + cN_TILE_ELEM = const(TILE_SIZE * TILE_SIZE) + c0f = const(0.0, s.float32) + c0i64 = const(0, s.int64) + c0_i32 = const(0, type=i32) + + total_len = s.index_cast(total_len_i32) + num_tiles = total_len // cN_TILE_ELEM + + # --- GM slot buffer partitioning (must match cube) --- + gm_c2v = gm_slot_buffer + gm_v2c = pto.add_ptr(gm_slot_buffer, const(GM_C2V_ELEMS)) + + # --- Initialize C2V pipe (Vector is consumer, dir_mask=1) --- + c2v_local = pto.reserve_buffer( + name="c2v_fifo", size=C2V_FIFO_BYTES, location="VEC" + ) + pto.aiv_initialize_pipe( + id=ID_C2V, + dir_mask=1, + slot_size=SLOT_SIZE_C2V, + gm_slot_buffer=gm_c2v, + c2v_consumer_buf=c2v_local, + v2c_consumer_buf=c0_i32, + ) + + # --- Initialize V2C pipe (Vector is producer, dir_mask=2) --- + v2c_import = pto.import_reserved_buffer( + name="v2c_fifo", peer_func="@cube_kernel" + ) + pto.aiv_initialize_pipe( + id=ID_V2C, + dir_mask=2, + slot_size=SLOT_SIZE_V2C, + gm_slot_buffer=gm_v2c, + c2v_consumer_buf=c0_i32, + v2c_consumer_buf=v2c_import, + ) + + tvOut_vec = pto.as_tensor( + tensor_type, + ptr=y_ptr, + shape=[total_len // cTILE_SIZE, cTILE_SIZE], + strides=[cTILE_SIZE, c1], + ) + + rowTile = pto.alloc_tile(tile_type_row) + sumTile1x8 = pto.alloc_tile(tile_type_1x8) + signalTile = pto.alloc_tile(tile_type_1x8) + tile.setval(sumTile1x8, c0, c0f) + + for tile_idx in pto.range(c0, num_tiles, c1): + # Pop the ACC tile from Cube via C2V pipe. + # The tile arrives in VEC space as a full TILE_SIZE x TILE_SIZE tile. + cTile_vec = pto.tpop_from_aic(tile_type_c_vec, SPLIT_NONE, id=ID_C2V) + + vid = pto.get_subblock_idx() + + with pto.if_context(vid == c0i64): + tile_offset = tile_idx * cTILE_SIZE + + # Step 1: Store the full popped tile to GM output. + # This writes the within-tile prefix sums (matmul result). + svOut_full = pto.slice_view( + subtensor_type_c, + source=tvOut_vec, + offsets=[tile_offset, c0], + sizes=[cTILE_SIZE, cTILE_SIZE], + ) + pto.store(cTile_vec, svOut_full) + + # Step 2: Re-read rows from GM and add the running sum. + for r in pto.range(c0, cTILE_SIZE, c1): + offset = tile_offset + r + svRow = pto.slice_view( + subtensor_type_row, + source=tvOut_vec, + offsets=[offset, c0], + sizes=[c1, cTILE_SIZE], + ) + + pto.load(svRow, rowTile) + + # Extract the running sum from the persistent buffer + running_sum = tile.getval(sumTile1x8, c0, dtype=s.float32) + + tile.adds(rowTile, running_sum, rowTile) + # Sync scalar pipe and vector pipe before extracting value + pto.barrier(pto.PIPE_ALL) + running_sum_next = tile.getval( + rowTile, cLAST_ROW_ELEM, dtype=s.float32 + ) + # Persist new running sum + tile.setval(sumTile1x8, c0, running_sum_next) + + pto.store(rowTile, svRow) + + # Free the C2V slot after processing + pto.tfree_from_aic(SPLIT_NONE, id=ID_C2V) + + # Signal Cube that this tile is done (V2C push) + pto.tpush_to_aic(signalTile, SPLIT_NONE, id=ID_V2C) + + # --------------------------------------------------------------- + # Entry point (dispatches to both kernels) + # --------------------------------------------------------------- + @pto.func + def run_scan_tpushpop( + ffts_addr: "ffts_type", + gm_slot_buffer: "ptr_type", + x_ptr: "ptr_type", + y_ptr: "ptr_type", + u_ptr: "ptr_type", + total_len_i32: "len_type", + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, x_ptr, u_ptr, total_len_i32) + pto.call(vector_kernel, gm_slot_buffer, y_ptr, total_len_i32) + + +def test_scan(n_tiles=64): + device = get_test_device() + torch.npu.set_device(device) + + total_len = TILE_SIZE * TILE_SIZE * n_tiles + torch.manual_seed(0) + dtype = torch.float32 + + # Prepare Inputs + x = torch.rand(size=(total_len,), device=device, dtype=dtype).contiguous() + y = torch.zeros_like(x) + + # Generate upper triangular matrix of 1s (s x s) + u = torch.triu( + torch.ones((TILE_SIZE, TILE_SIZE), device=device, dtype=dtype) + ).contiguous() + + # GM slot buffer for TPush/TPop FIFO staging + gm_slot_buffer = torch.zeros((GM_TOTAL_ELEMS,), dtype=torch.float32, device=device) + + # Expected PyTorch computation + expected_scan = torch.cumsum(x.cpu(), dim=0) + + # NPU kernel execution — scan_module is a JitWrapper that lazily + # compiles on first call; ffts_addr is injected automatically. + repeat_runs = 20 + print( + f"Running TPush/TPop scan for {total_len} elements " + f"({n_tiles} {TILE_SIZE}x{TILE_SIZE} tiles)" + ) + actual_scan = [] + for _ in range(repeat_runs): + y.zero_() + scan_module(gm_slot_buffer, x, y, u, total_len) + actual_scan.append(y.cpu().clone()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + repeat_results = [] + for i, scan in enumerate(actual_scan): + are_close = torch.allclose(scan, expected_scan, rtol=1e-3, atol=1e-3) + if not are_close: + unequal_count = torch.sum(scan != expected_scan) + else: + unequal_count = 0 + repeat_results.append([are_close, unequal_count]) + + has_mismatch = any(not eq for eq, _ in repeat_results) + if has_mismatch: + print("Expected:\n", expected_scan[-10:]) + for i, result in enumerate(repeat_results): + eq, count = result + if not eq: + print( + f"Inconsistent results run {i}, different elements: {count}/{total_len}. Sample:" + ) + print(actual_scan[i][-10:]) + raise AssertionError( + f"Scan mismatch for tile_size={TILE_SIZE}, total_len={total_len} ({n_tiles} tiles)" + ) + + print("All results matched. TPush/TPop scan test passed successfully.\n") + + +if __name__ == "__main__": + test_scan(1) + test_scan(16) + test_scan(64) + test_scan(100) + test_scan(1000) diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index 0c69ea9f..f2d80e67 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -9,6 +9,8 @@ as_tensor, call, set_ffts, + sync_set, + sync_wait, cube_section, declare_global, declare_tile, diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index 360216c3..b6bfe3cf 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -62,6 +62,16 @@ def set_ffts(ffts): return _pto.SetFFTsOp(_unwrap(ffts)) +@with_loc +def sync_set(pipe, event_id, ffts_mode=2): + return _pto.sync_set(pipe, event_id, ffts_mode) + + +@with_loc +def sync_wait(pipe, event_id): + return _pto.sync_wait(pipe, event_id) + + @with_loc def add_ptr(ptr, offset): """Return ptr advanced by offset elements, preserving the !pto.ptr type. @@ -367,6 +377,8 @@ def print(format, scalar): "get_block_num", "call", "set_ffts", + "sync_set", + "sync_wait", "add_ptr", "as_tensor", "slice_view", diff --git a/ptodsl/api/synchronization.py b/ptodsl/api/synchronization.py index 1a0801ee..380848ff 100644 --- a/ptodsl/api/synchronization.py +++ b/ptodsl/api/synchronization.py @@ -6,12 +6,16 @@ def _resolve_sync_op(sync_op): if isinstance(sync_op, str): normalized = sync_op.strip().upper() - if not normalized.startswith("T"): - normalized = f"T{normalized}" try: + if normalized.startswith("PIPE_"): + return _pto.PipeAttr.get(getattr(_pto.PIPE, normalized)) + elif not normalized.startswith("T"): + normalized = f"T{normalized}" return getattr(_pto, normalized) except AttributeError as exc: - raise ValueError(f"Unsupported sync op type '{sync_op}'.") from exc + raise ValueError( + f"Unsupported sync op type '{sync_op}', attrs {dir(_pto.PIPE)}." + ) from exc return sync_op diff --git a/ptodsl/api/tile.py b/ptodsl/api/tile.py index b75a988d..a107e95f 100644 --- a/ptodsl/api/tile.py +++ b/ptodsl/api/tile.py @@ -279,6 +279,19 @@ def quant(src, fp, dst, quant_type, *, offset=None): _pto.TQuantOp(src=src, fp=fp, dst=dst, quant_type=qtype_attr, offset=offset) +def getval(src, offset, dtype=None): + """Reads a single element from a tile at a linear offset.""" + if dtype is None: + # TODO extract dtype from the src tile + raise ValueError("getval requires an explicit dtype argument.") + return _pto.tgetval(dtype, src, _unwrap(offset)) + + +def setval(dst, offset, val): + """Writes a scalar value into a tile at a linear offset.""" + _pto.tsetval(dst, _unwrap(offset), _unwrap(val)) + + def print(source): _pto.tprint(source) @@ -335,5 +348,7 @@ def print(source): "adds", "cvt", "quant", + "getval", + "setval", "subview", ] diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py index 575b4a24..a09c8f76 100644 --- a/ptodsl/api/type_def.py +++ b/ptodsl/api/type_def.py @@ -20,7 +20,11 @@ def __getattr__(name): }: return getattr(scalar, name) if name == "ffts_type": - return MemRefType.get([256], IntegerType.get_signless(64)) + return MemRefType.get([256], IntegerType.get_unsigned(64)) + + if name.startswith("PIPE_"): + return _pto.PipeAttr.get(getattr(_pto.PIPE, name)) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py index 42bdcf8f..bab1b9c1 100644 --- a/ptodsl/compiler/ir.py +++ b/ptodsl/compiler/ir.py @@ -64,20 +64,28 @@ def _has_func_return(block): return last_name == "func.return" +def _get_globals(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + return fn.__globals__ + + def _inject_globals(fn, values): + globs = _get_globals(fn) old = {} for name, value in values.items(): - old[name] = fn.__globals__.get(name, None) - fn.__globals__[name] = value + old[name] = globs.get(name, None) + globs[name] = value return old def _restore_globals(fn, old, names): + globs = _get_globals(fn) for name in names: - if old[name] is None and name in fn.__globals__: - del fn.__globals__[name] + if old[name] is None and name in globs: + del globs[name] else: - fn.__globals__[name] = old[name] + globs[name] = old[name] def _define(module, ctx, meta_map, fn, *, name=None, entry=False, kernel=None): @@ -111,6 +119,13 @@ def _define(module, ctx, meta_map, fn, *, name=None, entry=False, kernel=None): if not ret_types and not _has_func_return(block): func.ReturnOp([]) + # When building a multi-function module, record the entry function's + # metadata so that JitWrapper can discover the signature for caller.cpp. + if entry and _CURRENT is not None: + _CURRENT["entry_name"] = fn_name + _CURRENT["entry_sig"] = sig + _CURRENT["entry_arg_types"] = arg_types + return FuncRef(fn_name) @@ -138,9 +153,21 @@ def decorator(fn): return decorator +# Stores entry function metadata from the last to_ir_module(module=True) call. +# Read by JitWrapper._build immediately after calling to_ir_module (synchronous). +_LAST_ENTRY_META = None + + +def get_last_entry_meta(): + """Return the entry function metadata from the last module build, or None.""" + return _LAST_ENTRY_META + + def to_ir_module(*, meta_data, module=False): def decorator(fn): - global _CURRENT + global _CURRENT, _LAST_ENTRY_META + _LAST_ENTRY_META = None + with Context() as ctx, get_user_code_loc(): _pto.register_dialect(ctx, load=True) meta_map = _resolve_meta(meta_data) @@ -156,6 +183,12 @@ def decorator(fn): _CURRENT = {"ctx": ctx, "module": ir_module, "meta_map": meta_map} try: fn() + # Capture entry metadata before _CURRENT is restored. + _LAST_ENTRY_META = { + "entry_name": _CURRENT.get("entry_name"), + "entry_sig": _CURRENT.get("entry_sig"), + "entry_arg_types": _CURRENT.get("entry_arg_types"), + } finally: _CURRENT = prev _restore_globals(fn, old, meta_map.keys()) @@ -168,4 +201,4 @@ def decorator(fn): return decorator -__all__ = ["FuncRef", "ir_func", "to_ir_module"] +__all__ = ["FuncRef", "get_last_entry_meta", "ir_func", "to_ir_module"] diff --git a/ptodsl/compiler/jit.py b/ptodsl/compiler/jit.py index 0ddb6cd1..626a824c 100644 --- a/ptodsl/compiler/jit.py +++ b/ptodsl/compiler/jit.py @@ -3,7 +3,7 @@ import os import pathlib import subprocess -from functools import update_wrapper +from functools import update_wrapper, wraps from mlir.dialects import pto as _pto from mlir.ir import Context, Location @@ -28,27 +28,31 @@ def _ptr_elem_cpp_type(type_obj): return "__fp16" if "bf16" in type_repr: return "__bf16" + if "ui8" in type_repr or "u8" in type_repr: + return "uint8_t" + if "ui16" in type_repr or "u16" in type_repr: + return "uint16_t" + if "ui32" in type_repr or "u32" in type_repr: + return "uint32_t" + if "ui64" in type_repr or "u64" in type_repr: + return "uint64_t" if "i8" in type_repr: return "int8_t" - if "u8" in type_repr: - return "uint8_t" if "i16" in type_repr: return "int16_t" - if "u16" in type_repr: - return "uint16_t" if "i32" in type_repr: return "int32_t" - if "u32" in type_repr: - return "uint32_t" if "i64" in type_repr: return "int64_t" - if "u64" in type_repr: - return "uint64_t" return "float" def _scalar_cpp_type(type_obj): type_repr = _type_repr(type_obj) + if "ui32" in type_repr or "u32" in type_repr: + return "uint32_t" + if "ui64" in type_repr or "u64" in type_repr: + return "uint64_t" if "i32" in type_repr: return "int32_t" if "i64" in type_repr or "index" in type_repr: @@ -62,12 +66,16 @@ def _scalar_cpp_type(type_obj): def _scalar_ctype(type_obj): type_repr = _type_repr(type_obj) + if "ui64" in type_repr or "u64" in type_repr: + return ctypes.c_uint64 if "i64" in type_repr or "index" in type_repr: return ctypes.c_int64 if "f32" in type_repr: return ctypes.c_float if "f16" in type_repr: return ctypes.c_uint16 + if "ui32" in type_repr or "u32" in type_repr: + return ctypes.c_uint32 return ctypes.c_int32 @@ -90,11 +98,14 @@ def __init__( output_dir=None, block_dim=None, enable_insert_sync=True, + init_ffts=None, npu_arch="dav-2201", + module=False, ): self._fn = fn + self._orig_sig = inspect.signature(fn) + self._sig = self._orig_sig self._meta_data = meta_data - self._sig = inspect.signature(fn) self._arg_types = None self._output_dir = ( pathlib.Path(output_dir) @@ -103,12 +114,41 @@ def __init__( ) self._block_dim = block_dim if block_dim is not None else get_num_cube_cores() self._enable_insert_sync = enable_insert_sync + self._init_ffts = init_ffts self._npu_arch = npu_arch + self._module = module self._compiled = False self._lib = None self._lib_path = self._output_dir / "kernel.so" + self._entry_name = None update_wrapper(self, fn) + # When module=True, the user explicitly declares the ffts parameter + # in the entry function and calls pto.set_ffts() themselves. + # init_ffts is only used as a name hint for caller.cpp generation. + if self._init_ffts is not None and not self._module: + original_fn = self._fn + + @wraps(original_fn) + def wrapper(*args, **kwargs): + # Automatically emit the MLIR operation before tracing the rest of the kernel + from ..api import pto as pto_api + + pto_api.set_ffts(args[-1]) + return original_fn(*args[:-1], **kwargs) + + new_params = list(self._sig.parameters.values()) + new_params.append( + inspect.Parameter( + self._init_ffts, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation="ffts_type", + ) + ) + self._sig = self._sig.replace(parameters=new_params) + wrapper.__signature__ = self._sig + self._fn = wrapper + def _artifact_paths(self): pto_path = self._output_dir / "kernel.pto" cpp_path = self._output_dir / "kernel.cpp" @@ -120,22 +160,42 @@ def _generate_caller_cpp(self, kernel_cpp_name): cpp_args = [] launch_args = [] for param, arg_type in zip(params, self._arg_types): - if _is_ptr_type(arg_type): - cpp_args.append(f"uint8_t *{param.name}") - launch_args.append(f"({_ptr_elem_cpp_type(arg_type)} *){param.name}") + if param.name == self._init_ffts: + launch_args.append(f"reinterpret_cast(fftsAddr)") else: - cpp_t = _scalar_cpp_type(arg_type) - cpp_args.append(f"{cpp_t} {param.name}") - launch_args.append(param.name) + if _is_ptr_type(arg_type): + cpp_args.append(f"uint8_t *{param.name}") + launch_args.append( + f"({_ptr_elem_cpp_type(arg_type)} *){param.name}" + ) + else: + cpp_t = _scalar_cpp_type(arg_type) + cpp_args.append(f"{cpp_t} {param.name}") + launch_args.append(param.name) wrapper_sig = ", ".join(["uint32_t blockDim", "void *stream"] + cpp_args) kernel_call = ", ".join(launch_args) + + ffts_init_code = "" + if self._init_ffts is not None: + ffts_init_code = ( + " void *fftsAddr = nullptr;\n" + " uint32_t fftsLen = 0;\n" + " (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen);\n" + ) + + # In module mode, the kernel launch name is the entry function name + # discovered from the IR module, not the builder function name. + launch_name = self._entry_name or self.__name__ + return ( f'#include "{kernel_cpp_name}"\n' - f"#include \n\n" + f"#include \n" + f'#include "runtime/rt.h"\n\n' f'extern "C" void call_kernel({wrapper_sig})\n' "{\n" - f" {self._fn.__name__}<<>>({kernel_call});\n" + f"{ffts_init_code}" + f" {launch_name}<<>>({kernel_call});\n" "}\n" ) @@ -146,9 +206,14 @@ def _compile_shared_library(self, caller_cpp_path, lib_path): raise RuntimeError( "PTO_LIB_PATH is required to compile generated caller.cpp." ) + ascend_home = os.environ.get("ASCEND_TOOLKIT_HOME") cmd = [ "bisheng", f"-I{pto_isa}/include", + f"-I{ascend_home}/include", + f"-I{ascend_home}/pkg_inc", + f"-I{ascend_home}/pkg_inc/runtime", + f"-I{ascend_home}/pkg_inc/profiling", "-fPIC", "-shared", "-D_FORTIFY_SOURCE=2", @@ -177,7 +242,17 @@ def _compile_shared_library(self, caller_cpp_path, lib_path): "-o", str(lib_path), ] - subprocess.run(cmd, check=True, cwd=str(self._output_dir)) + try: + subprocess.run(cmd, check=True, cwd=str(self._output_dir)) + except Exception as e: + output = ( + e.stdout.decode("utf-8", errors="replace") + if hasattr(e, "stdout") and e.stdout + else "" + ) + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e def _resolve_runtime_arg_types(self): from .ir import _resolve_arg_types, _resolve_meta @@ -190,9 +265,27 @@ def _resolve_runtime_arg_types(self): def _build(self): self._output_dir.mkdir(parents=True, exist_ok=True) pto_path, cpp_path, caller_path, lib_path = self._artifact_paths() - self._arg_types = self._resolve_runtime_arg_types() - ir_module = to_ir_module(meta_data=self._meta_data)(self._fn) + if self._module: + # Multi-function module mode: build the module and extract + # the entry function signature from the module-level metadata. + from .ir import get_last_entry_meta + + ir_module = to_ir_module(meta_data=self._meta_data, module=True)(self._fn) + entry_meta = get_last_entry_meta() + if entry_meta is None or entry_meta.get("entry_name") is None: + raise RuntimeError( + "module=True requires at least one `@pto.func` (without " + "kernel=) as the entry point." + ) + self._entry_name = entry_meta["entry_name"] + self._sig = entry_meta["entry_sig"] + self._arg_types = entry_meta["entry_arg_types"] + else: + # Single-function mode (original path). + self._arg_types = self._resolve_runtime_arg_types() + ir_module = to_ir_module(meta_data=self._meta_data)(self._fn) + pto_path.write_text(f"{ir_module}\n", encoding="utf-8") ptoas_cmd = ["ptoas"] @@ -207,10 +300,16 @@ def _build(self): self._compile_shared_library(caller_path, lib_path) self._lib = ctypes.CDLL(str(lib_path)) - self._lib.call_kernel.argtypes = [ctypes.c_uint32, ctypes.c_void_p] + [ - ctypes.c_void_p if _is_ptr_type(arg_type) else _scalar_ctype(arg_type) - for arg_type in self._arg_types - ] + argtypes = [ctypes.c_uint32, ctypes.c_void_p] + for param, arg_type in zip(self._sig.parameters.values(), self._arg_types): + if self._init_ffts is not None and param.name == self._init_ffts: + continue + if _is_ptr_type(arg_type): + argtypes.append(ctypes.c_void_p) + else: + argtypes.append(_scalar_ctype(arg_type)) + + self._lib.call_kernel.argtypes = argtypes self._compiled = True def _convert_ptr(self, value): @@ -224,23 +323,32 @@ def _convert_ptr(self, value): def _prepare_call_args(self, args): params = list(self._sig.parameters.values()) - if len(args) > len(params): + orig_params = [ + p for p in params if self._init_ffts is None or p.name != self._init_ffts + ] + orig_arg_types = [ + t + for p, t in zip(params, self._arg_types) + if self._init_ffts is None or p.name != self._init_ffts + ] + + if len(args) > len(orig_params): raise TypeError( - f"Expected at most {len(params)} arguments, got {len(args)}." + f"Expected at most {len(orig_params)} arguments, got {len(args)}." ) filled_args = list(args) - for idx in range(len(args), len(params)): - param = params[idx] + for idx in range(len(filled_args), len(orig_params)): + param = orig_params[idx] if param.default is not inspect._empty: filled_args.append(param.default) continue - arg_type = self._arg_types[idx] + arg_type = orig_arg_types[idx] if _is_ptr_type(arg_type): raise TypeError(f"Missing required pointer argument '{param.name}'.") converted = [] - for value, arg_type in zip(filled_args, self._arg_types): + for value, arg_type in zip(filled_args, orig_arg_types): if _is_ptr_type(arg_type): converted.append(self._convert_ptr(value)) else: @@ -286,7 +394,9 @@ def jit( output_dir=None, block_dim=1, enable_insert_sync=True, + init_ffts=None, npu_arch="dav-2201", + module=False, ): def decorator(fn): return JitWrapper( @@ -295,7 +405,9 @@ def decorator(fn): output_dir=output_dir, block_dim=block_dim, enable_insert_sync=enable_insert_sync, + init_ffts=init_ffts, npu_arch=npu_arch, + module=module, ) return decorator diff --git a/tests/frontend/tooling.py b/tests/frontend/tooling.py index c228cb88..dad057de 100644 --- a/tests/frontend/tooling.py +++ b/tests/frontend/tooling.py @@ -48,9 +48,14 @@ def run_bisheng(caller_cpp, output_so, *, npu_arch="dav-2201", cwd=None): Raises :class:`subprocess.CalledProcessError` on failure. """ pto_isa = os.environ.get("PTO_LIB_PATH", "/sources/pto-isa") + ascend_home = os.environ.get("ASCEND_TOOLKIT_HOME") cmd = [ "bisheng", f"-I{pto_isa}/include", + f"-I{ascend_home}/include", + f"-I{ascend_home}/pkg_inc", + f"-I{ascend_home}/pkg_inc/runtime", + f"-I{ascend_home}/pkg_inc/profiling", "-fPIC", "-shared", "-D_FORTIFY_SOURCE=2",