Merged
2 changes: 1 addition & 1 deletion examples/aot/flash_attention/caller.cpp
@@ -23,7 +23,7 @@ extern "C" void call_kernel(
     (void)fftsLen;

     call_both<<<blockDim, nullptr, stream>>>(
-        (__gm__ int64_t *)fftsAddr,
+        (__gm__ uint64_t *)fftsAddr,
         (__gm__ float *)gmSlotBuffer,
         (__gm__ half *)q,
         (__gm__ half *)k,
2 changes: 1 addition & 1 deletion examples/aot/tpushpop/mix-kernel_mlir/caller.cpp
@@ -21,7 +21,7 @@ extern "C" void call_kernel(
     (void)fftsLen;

     call_both<<<blockDim, nullptr, stream>>>(
-        (__gm__ int64_t *)fftsAddr,
+        (__gm__ uint64_t *)fftsAddr,
         (__gm__ float *)gmSlotBuffer,
         (__gm__ float *)x,
         (__gm__ float *)y);
67 changes: 67 additions & 0 deletions examples/jit/scan/README.md
@@ -0,0 +1,67 @@
# Single-core prefix sum (scan)

An implementation of the prefix sum (scan) algorithm based on https://arxiv.org/abs/2505.15112v1. Only the single-core algorithm (ScanU in the paper) is implemented.

## Algorithm

The ScanU algorithm computes the prefix sum (cumulative sum) of a 1-D input vector by decomposing it into tiles and leveraging the Cube unit's matrix multiply for parallelism within each tile:

1. **Reshape** the flat input vector into tiles of shape `TILE_SIZE × TILE_SIZE`.
2. **Precompute** an upper-triangular matrix `U` of 1s (shape `TILE_SIZE × TILE_SIZE`).
3. **For each tile** `X_i`:
   - **Cube** computes `C_i = X_i @ U`. Each row of `C_i` contains the prefix sums within that row of the tile only.
   - **Vector** adds a running scalar sum (the carry from all preceding rows and tiles) to every element of each row, then takes the last element of the updated row as the new running sum.
4. The concatenation of all processed tiles is the full prefix sum.
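
The steps above can be modelled on the host in plain Python. This is a minimal sketch for checking the arithmetic, not the NPU kernel: the function name `scan_u_reference` and the small `tile_size=4` are illustrative assumptions, and the "Cube" and "Vector" stages are ordinary loops here.

```python
from itertools import accumulate


def scan_u_reference(x, tile_size):
    """Host-side model of the ScanU steps (hypothetical helper, not the kernel)."""
    n_tile_elem = tile_size * tile_size
    assert len(x) % n_tile_elem == 0, "input must fill whole tiles"

    # Step 2: upper-triangular matrix of 1s, U[k][j] = 1 when k <= j.
    u = [[1.0 if k <= j else 0.0 for j in range(tile_size)]
         for k in range(tile_size)]

    out = []
    running_sum = 0.0
    for start in range(0, len(x), n_tile_elem):
        # Step 1: view the next tile_size * tile_size elements as a square tile.
        tile = [x[start + r * tile_size:start + (r + 1) * tile_size]
                for r in range(tile_size)]
        # Step 3, "Cube": C = X @ U, so C[r][j] = X[r][0] + ... + X[r][j].
        c = [[sum(row[k] * u[k][j] for k in range(tile_size))
              for j in range(tile_size)]
             for row in tile]
        # Step 3, "Vector": add the carry to the row, then the row's last
        # element becomes the carry for the next row.
        for row in c:
            shifted = [v + running_sum for v in row]
            running_sum = shifted[-1]
            out.extend(shifted)
    # Step 4: the concatenated rows are the full prefix sum.
    return out


x = [float(i % 5) for i in range(2 * 4 * 4)]  # two 4x4 tiles
result = scan_u_reference(x, tile_size=4)
expected = list(accumulate(x))
```

With these small integer-valued inputs, `result` and `expected` can be compared exactly; for random `float32` data a tolerance-based comparison, as in the test harness, is the safer check.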

The Cube and Vector units run concurrently but must synchronize: the Cube cannot overwrite the next tile's result before the Vector has finished processing the current one.

## Implementations

There are two implementations that differ only in how Cube ↔ Vector synchronization is achieved. The algorithm logic, tile types, memory layouts, and test harness are identical.

### `run_scan_single_core.py` — TSync (sync_set / sync_wait)

Uses a single function with `cube_section` / `vector_section` blocks. Synchronization is performed with the low-level `sync_set` / `sync_wait` primitives operating on `PIPE_FIX` and `PIPE_MTE3`:

- Cube signals Vector via `pto.sync_set(pto.PIPE_FIX, 0)` after storing the matmul result to GM.
- Vector waits via `pto.sync_wait(pto.PIPE_FIX, 0)`, processes the tile, then signals back via `pto.sync_set(pto.PIPE_MTE3, 1)`.
- Cube waits for the Vector's acknowledgement via `pto.sync_wait(pto.PIPE_MTE3, 1)` before advancing.

```bash
python ./run_scan_single_core.py
```
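
The two-flag handshake described above can be pictured with host threads. The sketch below is only an analogy, assuming two semaphores stand in for the flags used by `sync_set` / `sync_wait`; `run_handshake`, `result_ready`, and `result_consumed` are hypothetical names, and nothing here calls the PTO API.

```python
import threading


def run_handshake(num_tiles):
    """Model the Cube/Vector handshake with two semaphores (analogy only)."""
    result_ready = threading.Semaphore(0)     # stands in for the Cube -> Vector flag
    result_consumed = threading.Semaphore(0)  # stands in for the Vector -> Cube flag
    shared = {"tile": None}                   # stands in for the shared result region
    seen = []

    def cube():
        for tile_idx in range(num_tiles):
            shared["tile"] = tile_idx   # "store the matmul result"
            result_ready.release()      # signal Vector, like sync_set
            result_consumed.acquire()   # wait for the acknowledgement, like sync_wait

    def vector():
        for _ in range(num_tiles):
            result_ready.acquire()      # wait for the Cube's result
            seen.append(shared["tile"])  # "process the tile"
            result_consumed.release()   # signal back to Cube

    threads = [threading.Thread(target=cube), threading.Thread(target=vector)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


handshake_order = run_handshake(5)
```

Because the producer waits for the acknowledgement before writing again, the consumer sees every tile exactly once and in order; dropping the second semaphore would let the producer overwrite a tile that is still being read.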

### `run_scan_single_core_tpushpop.py` — TPush / TPop

Uses the structured multi-function module pattern with separate `cube_kernel` and `vector_kernel` functions. Synchronization uses two unidirectional TPush/TPop pipes:

| Pipe | Direction | `dir_mask` | Purpose |
|------|-----------|------------|---------|
| C2V | Cube → Vector | 1 | Sends the matmul ACC tile directly to Vector VEC memory |
| V2C | Vector → Cube | 2 | Sends a dummy signal tile back for back-pressure |

Both pipes use GM-staged L2G2L transport with `slot_num=8`.

The Cube pushes the ACC tile via `tpush_to_aiv`, then blocks on `tpop_from_aiv` (waiting for the Vector's V2C signal). The Vector pops with `tpop_from_aic`, stores the tile to GM, processes rows with the running sum, then pushes a signal via `tpush_to_aic` and frees the C2V slot with `tfree_from_aic`.

```bash
python ./run_scan_single_core_tpushpop.py
```
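
As a rough host-side picture of the two pipes, bounded queues can play the role of the slot FIFOs. This is a sketch under assumptions: `queue.Queue(maxsize=8)` mirrors `slot_num=8`, the names `run_pipes`, `c2v`, and `v2c` are illustrative, and the comments map each queue call only loosely onto the TPush/TPop operations named above.

```python
import queue
import threading

SLOT_NUM = 8  # mirrors slot_num=8; everything else in this sketch is illustrative


def run_pipes(num_tiles):
    """Model the C2V data pipe and the V2C signal pipe with bounded queues."""
    c2v = queue.Queue(maxsize=SLOT_NUM)  # Cube -> Vector: carries the "ACC tile"
    v2c = queue.Queue(maxsize=SLOT_NUM)  # Vector -> Cube: carries a dummy signal
    processed = []

    def cube():
        for tile_idx in range(num_tiles):
            c2v.put(("acc_tile", tile_idx))  # roughly tpush_to_aiv
            v2c.get()                        # roughly tpop_from_aiv: block on the signal

    def vector():
        for _ in range(num_tiles):
            _, tile_idx = c2v.get()          # roughly tpop_from_aic
            processed.append(tile_idx)       # store + running-sum work would go here
            v2c.put("signal")                # roughly tpush_to_aic
            c2v.task_done()                  # loosely like tfree_from_aic

    threads = [threading.Thread(target=cube), threading.Thread(target=vector)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return processed


pipe_order = run_pipes(20)
```

In this model the Cube thread blocks on the signal after every push, so at most one tile is in flight even though eight slots exist, which matches the blocking order described above.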

## Differences between TSync and TPush/TPop

| Aspect | TSync | TPush/TPop |
|--------|-------|------------|
| Code structure | Single function with `cube_section` / `vector_section` | Separate `@pto.func(kernel="cube")` and `@pto.func(kernel="vector")` functions |
| Sync mechanism | `sync_set` / `sync_wait` on hardware pipes | `tpush` / `tpop` / `tfree` on logical pipe handles |
| Data transfer | Cube stores to GM, Vector loads from GM | Cube pushes ACC tile through pipe, Vector pops into VEC memory, then stores to GM |
| Back-pressure | Explicit `sync_wait` on `PIPE_MTE3` | V2C pipe with dummy signal tile |
| GM slot buffer | Not needed | Required (FIFO staging area for both pipes) |
| Address management | N/A | `reserve_buffer` / `import_reserved_buffer` for cross-kernel FIFO address sharing |
| Insert sync | `enable_insert_sync=False` (manual sync) | `--enable-insert-sync` (PTOAS auto-inserts intra-pipe sync) |

## Implementation notes

- The running sum calculation is stored in a tile (`sumTile1x8`) to avoid issues with the compiler removing it as unused code.
- Because `record_wait_pair` cannot synchronize on the scalar pipe, a `barrier(PIPE_ALL)` is used as a workaround before extracting the last row element.
- **`record_wait_pair` and `--enable-insert-sync` are mutually exclusive.** The TSync version uses manual sync (`enable_insert_sync=False` + explicit `record_wait_pair` calls). The TPush/TPop version uses auto sync (`--enable-insert-sync`; PTOAS InsertSync pass inserts all needed `set_flag`/`wait_flag` pairs). Mixing both causes event ID collisions that lead to data races and non-deterministic results.
306 changes: 306 additions & 0 deletions examples/jit/scan/run_scan_single_core.py
@@ -0,0 +1,306 @@
import torch
import torch_npu
from ptodsl import jit, pto, tile
from ptodsl import scalar as s
from ptodsl.npu_info import get_test_device

TILE_SIZE = 64

const = s.const


def meta_data():
    dtype = pto.float32
    ptr_type = pto.PtrType(dtype)
    ffts_type = pto.ffts_type
    len_type = pto.int32

    tensor_type = pto.TensorType(rank=2, dtype=dtype)

    subtensor_type_u = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype)
    subtensor_type_a = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype)
    subtensor_type_c = pto.SubTensorType(shape=[TILE_SIZE, TILE_SIZE], dtype=dtype)
    subtensor_type_row = pto.SubTensorType(shape=[1, TILE_SIZE], dtype=dtype)

    tile_cfg_mat = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor")
    tile_cfg_left = pto.TileBufConfig(blayout="RowMajor", slayout="RowMajor")
    tile_cfg_right = pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor")
    tile_cfg_acc = pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor")

    tile_type_a_l1 = pto.TileBufType(
        shape=[TILE_SIZE, TILE_SIZE],
        dtype=dtype,
        memory_space="MAT",
        config=tile_cfg_mat,
    )
    tile_type_u_l1 = pto.TileBufType(
        shape=[TILE_SIZE, TILE_SIZE],
        dtype=dtype,
        memory_space="MAT",
        config=tile_cfg_mat,
    )

    tile_type_a = pto.TileBufType(
        shape=[TILE_SIZE, TILE_SIZE],
        dtype=dtype,
        memory_space="LEFT",
        config=tile_cfg_left,
    )
    tile_type_u = pto.TileBufType(
        shape=[TILE_SIZE, TILE_SIZE],
        dtype=dtype,
        memory_space="RIGHT",
        config=tile_cfg_right,
    )
    tile_type_c = pto.TileBufType(
        shape=[TILE_SIZE, TILE_SIZE],
        dtype=dtype,
        memory_space="ACC",
        config=tile_cfg_acc,
    )

    tile_type_row = pto.TileBufType(
        shape=[1, TILE_SIZE],
        valid_shape=[1, TILE_SIZE],
        dtype=dtype,
        memory_space="VEC",
        config=pto.TileBufConfig(),
    )
    tile_type_1x8 = pto.TileBufType(
        shape=[1, 8],
        valid_shape=[1, 8],
        dtype=dtype,
        memory_space="VEC",
        config=pto.TileBufConfig(),
    )

    return {
        "ptr_type": ptr_type,
        "ffts_type": ffts_type,
        "len_type": len_type,
        "tensor_type": tensor_type,
        "subtensor_type_u": subtensor_type_u,
        "subtensor_type_a": subtensor_type_a,
        "subtensor_type_c": subtensor_type_c,
        "subtensor_type_row": subtensor_type_row,
        "tile_type_a_l1": tile_type_a_l1,
        "tile_type_u_l1": tile_type_u_l1,
        "tile_type_a": tile_type_a,
        "tile_type_u": tile_type_u,
        "tile_type_c": tile_type_c,
        "tile_type_row": tile_type_row,
        "tile_type_1x8": tile_type_1x8,
    }


@jit(meta_data=meta_data, block_dim=1, enable_insert_sync=False, init_ffts="ffts_addr")
def run_scan_kernel(
    x_ptr: "ptr_type",
    y_ptr: "ptr_type",
    u_ptr: "ptr_type",
    total_len_i32: "len_type",
) -> None:
    c0 = const(0)
    c1 = const(1)
    cLAST_ROW_ELEM = const(TILE_SIZE - 1)
    cTILE_SIZE = const(TILE_SIZE)
    cN_TILE_ELEM = const(TILE_SIZE * TILE_SIZE)
    c0f = const(0.0, s.float32)
    c0i64 = const(0, s.int64)

    total_len = s.index_cast(total_len_i32)
    num_tiles = total_len // cN_TILE_ELEM

    with pto.cube_section():
        tvX_cube = pto.as_tensor(
            tensor_type,
            ptr=x_ptr,
            shape=[cTILE_SIZE, cTILE_SIZE],
            strides=[cTILE_SIZE, cTILE_SIZE],
            layout="ND",
        )
        tvU = pto.as_tensor(
            tensor_type,
            ptr=u_ptr,
            shape=[cTILE_SIZE, cTILE_SIZE],
            strides=[cTILE_SIZE, cTILE_SIZE],
            layout="ND",
        )
        tvOut_cube = pto.as_tensor(
            tensor_type,
            ptr=y_ptr,
            shape=[cTILE_SIZE, cTILE_SIZE],
            strides=[cTILE_SIZE, cTILE_SIZE],
            layout="ND",
        )

        uTileL1 = pto.alloc_tile(tile_type_u_l1)
        uTile = pto.alloc_tile(tile_type_u)
        aTileL1 = pto.alloc_tile(tile_type_a_l1)
        aTile = pto.alloc_tile(tile_type_a)
        cTile = pto.alloc_tile(tile_type_c)

        svU = pto.slice_view(
            subtensor_type_u,
            source=tvU,
            offsets=[c0, c0],
            sizes=[cTILE_SIZE, cTILE_SIZE],
        )
        pto.load(svU, uTileL1)
        pto.record_wait_pair("LOAD", "MOV_M2L", 0)

        tile.mov(uTileL1, uTile)
        pto.record_wait_pair("MOV_M2L", "MATMUL", 0)

        for tile_idx in pto.range(c0, num_tiles, c1):
            offset = tile_idx * cTILE_SIZE
            svX = pto.slice_view(
                subtensor_type_a,
                source=tvX_cube,
                offsets=[offset, c0],
                sizes=[cTILE_SIZE, cTILE_SIZE],
            )
            svOut = pto.slice_view(
                subtensor_type_c,
                source=tvOut_cube,
                offsets=[offset, c0],
                sizes=[cTILE_SIZE, cTILE_SIZE],
            )

            pto.load(svX, aTileL1)
            pto.record_wait_pair("LOAD", "MOV_M2L", 1)

            tile.mov(aTileL1, aTile)
            pto.record_wait_pair("MOV_M2L", "MATMUL", 1)

            tile.matmul(aTile, uTile, cTile)
            pto.record_wait_pair("MATMUL", "STORE_ACC", 1)

            pto.store(cTile, svOut)
            pto.record_wait_pair("STORE_ACC", "LOAD", 2)

            pto.sync_set(pto.PIPE_FIX, 0)
            pto.sync_wait(pto.PIPE_MTE3, 1)

    with pto.vector_section():
        tvOut_vec = pto.as_tensor(
            tensor_type,
            ptr=y_ptr,
            shape=[total_len // cTILE_SIZE, cTILE_SIZE],
            strides=[cTILE_SIZE, c1],
        )
Review comment on lines +179 to +191 (Collaborator):

So the `with pto.vector_section()` syntax actually works for mix kernels, as long as manual sync plus explicit `sync_set`/`sync_wait` is used. The syntax is quite different from the entry + subfunction + push/pop + auto-sync style in #98, so we now have two styles for writing mix kernels. They can co-exist for now, but the style will need to be unified in the future.


        rowTile = pto.alloc_tile(tile_type_row)
        sumTile1x8 = pto.alloc_tile(tile_type_1x8)
        tile.setval(sumTile1x8, c0, c0f)

        for tile_idx in pto.range(c0, num_tiles, c1):
            pto.sync_wait(pto.PIPE_FIX, 0)

            vid = pto.get_subblock_idx()

            with pto.if_context(vid == c0i64):
                tile_offset = tile_idx * cTILE_SIZE
                for r in pto.range(c0, cTILE_SIZE, c1):
                    offset = tile_offset + r
                    svRow = pto.slice_view(
                        subtensor_type_row,
                        source=tvOut_vec,
                        offsets=[offset, c0],
                        sizes=[c1, cTILE_SIZE],
                    )

                    pto.load(svRow, rowTile)

                    pto.record_wait_pair("LOAD", "VEC", 2)

                    # Extract the stateful running_sum from our memory buffer
                    running_sum = tile.getval(sumTile1x8, c0, dtype=s.float32)

                    tile.adds(rowTile, running_sum, rowTile)
                    # Ideally we would synchronize PIPE_S and PIPE_V here, but that is
                    # not currently possible with pto.record_wait_pair, so we use a
                    # barrier instead.
                    # pto.record_wait_pair("PIPE_V", "PIPE_S", 2)
                    pto.barrier(pto.PIPE_ALL)
                    running_sum_next = tile.getval(
                        rowTile, cLAST_ROW_ELEM, dtype=s.float32
                    )
                    # Persist the new running sum back to the memory buffer to loop-carry
                    tile.setval(sumTile1x8, c0, running_sum_next)
                    pto.record_wait_pair("VEC", "STORE_VEC", 2)

                    pto.store(rowTile, svRow)

                    pto.record_wait_pair("STORE_VEC", "LOAD", 3)

                pto.record_wait_pair("LOAD", "VEC", 3)

            pto.sync_set(pto.PIPE_MTE3, 1)


def test_scan(n_tiles=64):
    device = get_test_device()
    torch.npu.set_device(device)

    total_len = TILE_SIZE * TILE_SIZE * n_tiles
    torch.manual_seed(0)
    dtype = torch.float32

    # Prepare inputs
    x = torch.rand(size=(total_len,), device=device, dtype=dtype).contiguous()
    y = torch.zeros_like(x)

    # Generate upper-triangular matrix of 1s (TILE_SIZE x TILE_SIZE)
    u = torch.triu(
        torch.ones((TILE_SIZE, TILE_SIZE), device=device, dtype=dtype)
    ).contiguous()

    # Expected PyTorch computation
    expected_scan = torch.cumsum(x.cpu(), dim=0)

    # NPU JIT kernel execution
    repeat_runs = 20
    print(
        f"Running scan for {total_len} elements ({n_tiles} {TILE_SIZE}x{TILE_SIZE} tiles)"
    )
    actual_scan = []
    for _ in range(repeat_runs):
        y.zero_()
        run_scan_kernel(x, y, u, total_len)
        actual_scan.append(y.cpu().clone())

    torch.npu.synchronize()

    # Check for consistency across runs and correctness against the expected result
    repeat_results = []
    for scan in actual_scan:
        are_close = torch.allclose(scan, expected_scan, rtol=1e-3, atol=1e-3)
        if not are_close:
            unequal_count = torch.sum(scan != expected_scan)
        else:
            unequal_count = 0
        repeat_results.append([are_close, unequal_count])

    has_mismatch = any(not eq for eq, _ in repeat_results)
    if has_mismatch:
        print("Expected:\n", expected_scan[-10:])
        for i, result in enumerate(repeat_results):
            eq, count = result
            if not eq:
                print(
                    f"Inconsistent results run {i}, different elements: {count}/{total_len}. Sample:"
                )
                print(actual_scan[i][-10:])
        raise AssertionError(
            f"Scan mismatch for tile_size={TILE_SIZE}, total_len={total_len} ({n_tiles} tiles)"
        )

    print("All results matched. Scan test passed successfully.\n")


if __name__ == "__main__":
    test_scan(1)
    test_scan(16)
    test_scan(64)
    test_scan(100)
    test_scan(1000)