diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cf8f26cb..068abedb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,7 +50,7 @@ jobs: RELEASE_VER: 0.24 RELEASE_TAG: v0.24 CLI_DIR: /installers/ptoas-cli - PTOISA_COMMIT: 2ee948ef636863ed149f176d5327d9db5f349bb6 + PTOISA_COMMIT: a8c3fbf42a2f4a0f609f64e138dda62deefddb8e steps: - name: Install system packages diff --git a/.gitignore b/.gitignore index 09d8265b..1ccc35d0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ __pycache__ extra-info *.ptodsl_jit + +msprof_res/ diff --git a/docker/Dockerfile b/docker/Dockerfile index 27a98f39..17a70d96 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,8 +16,8 @@ RUN pip install --no-cache-dir \ ipython jupyterlab matplotlib pandas # certain operations need latest isa header, not CANN 8.5.0 default -# header on 2026/04/01 -ARG PTOISA_COMMIT=2ee948ef636863ed149f176d5327d9db5f349bb6 +# header on 2026/04/08 +ARG PTOISA_COMMIT=a8c3fbf42a2f4a0f609f64e138dda62deefddb8e WORKDIR /sources RUN git clone https://gitcode.com/cann/pto-isa.git \ && cd pto-isa && git checkout $PTOISA_COMMIT diff --git a/examples/aot/tpushpop/.gitignore b/examples/aot/tpushpop/.gitignore new file mode 100644 index 00000000..b0d498be --- /dev/null +++ b/examples/aot/tpushpop/.gitignore @@ -0,0 +1 @@ +build_artifacts/ diff --git a/examples/aot/tpushpop/mix-kernel_mlir/README.md b/examples/aot/tpushpop/mix-kernel_mlir/README.md new file mode 100644 index 00000000..886a8c2c --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/README.md @@ -0,0 +1,80 @@ +# TPush / TPop mixed-kernel examples + +Small examples of tile FIFO communication between Cube (`AIC`) and Vector (`AIV`). + +```bash +python run.py c2v +python run.py v2c +python run.py bidi +``` + +`python run.py` defaults to `c2v`. + +Files: + +- `kernels/` has the Python builders. +- `build_artifacts/` gets generated MLIR, generated C++, and the `.so`. +- `gm_slot_buffer` is the GM backing store for the pipe. +- `caller.cpp` sets the FFTS base before launching the generated kernel. + +Core idea: + +- `aic_initialize_pipe` / `aiv_initialize_pipe` lower to matching `TPipe<...>` objects. +- `gm_slot_buffer` is the shared GM slot memory used by that `TPipe`. +- `tpush_to_aiv` / `tpush_to_aic` lower to `TPUSH(pipe, tile)`. +- `tpop_from_aic` / `tpop_from_aiv` lower to `TPOP(pipe, tile)`. +- `tfree_from_aic` / `tfree_from_aiv` lower to `TFREE(pipe)` and release the consumed slot. + +## C2V + +Cube sends. Vector receives. + +This example computes `X @ X` on Cube, sends the accumulator tile to Vector, then Vector stores it to GM. + +```text +Cube: load X -> matmul -> tpush_to_aiv +Vector: tpop_from_aic -> store Y -> tfree_from_aic +``` + +Pipe wiring: + +- Vector owns the consumer buffer: `reserve_buffer("c2v_fifo", location="VEC")` +- Cube imports it: `import_reserved_buffer("c2v_fifo", peer_func="@vector_kernel")` +- Both sides initialize with `dir_mask = 1` + +## V2C + +Vector sends. Cube receives. + +This example loads `X` on Vector, sends that tile to Cube, then Cube stores it to GM. + +```text +Vector: load X -> tpush_to_aic +Cube: tpop_from_aiv -> store Y -> tfree_from_aiv +``` + +Pipe wiring: + +- Cube owns the consumer buffer: `reserve_buffer("v2c_fifo", location="MAT")` +- Vector imports it: `import_reserved_buffer("v2c_fifo", peer_func="@cube_kernel")` +- Both sides initialize with `dir_mask = 2` + +## BIDI + +Both directions are enabled. + +This example sends `X @ X` from Cube to Vector. Vector doubles it and sends it back. Cube receives the returned tile and stores it to GM. + +```text +Cube: matmul -> tpush_to_aiv +Vector: tpop_from_aic -> add -> tpush_to_aic -> tfree_from_aic +Cube: tpop_from_aiv -> store Y -> tfree_from_aiv +``` + +Pipe wiring: + +- Vector reserves `c2v_fifo`; Cube imports it +- Cube reserves `v2c_fifo`; Vector imports it +- Both sides initialize with `dir_mask = 3` + +For `dir_mask = 3`, allocate FIFO backing for both directions. `run.py` uses `8 KiB`. diff --git a/examples/aot/tpushpop/mix-kernel_mlir/c2v.mlir b/examples/aot/tpushpop/mix-kernel_mlir/c2v.mlir new file mode 100644 index 00000000..d5eab7fb --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/c2v.mlir @@ -0,0 +1,89 @@ +// Bidirectional pipe example. +// +// This reduced version only uses the C2V pipe: +// - `c2v_fifo`: cube/kernel `@cube_kernel` pushes to vector/kernel `@vector_kernel` +// +// `gm_slot_buffer` is the GM-backed slot storage for these pipes. The reserve/import +// ops connect each side of the same named FIFO, and `aic/aiv_initialize_pipe` +// binds those FIFO endpoints to the shared GM slot buffer plus each side's local +// consumer buffer. +// +// End-to-end data flow: +// - Cube loads one input matrix `X` from GM. +// - Cube computes `Y = X @ X`. +// - Cube sends that accumulator tile to vector over `c2v_fifo`. +// - Vector pops the tile and stores it to GM as output matrix `Y`. +// +// What is transferred: +// - Cube -> Vector: one full `16 x 16` `f32` accumulator tile `Y = X @ X` +// sent with `pto.tpush_to_aiv` using `split = 0` (no split). Vector receives +// that same logical `16 x 16` tile with `pto.tpop_from_aic` in a vector tile +// type/layout, then stores it to the GM output buffer. +// +// Shape summary: +// - All transferred tiles are `rows=16, cols=16, dtype=f32` +// - Cube-produced C2V tile: `loc=acc`, `blayout=col_major`, `slayout=row_major` +// - Vector-consumed tile after C2V pop: `loc=vec`, `blayout=row_major`, `slayout=none_box` +module { + + func.func @call_both(%gm_slot_buffer: !pto.ptr, %gm_x: !pto.ptr, %gm_y: !pto.ptr) attributes {pto.entry} { + func.call @cube_kernel(%gm_slot_buffer, %gm_x) : (!pto.ptr, !pto.ptr) -> () + func.call @vector_kernel(%gm_slot_buffer, %gm_y) : (!pto.ptr, !pto.ptr) -> () + return + } + + func.func @cube_kernel(%gm_slot_buffer: !pto.ptr, %gm_x: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c2v_import = pto.import_reserved_buffer { + name = "c2v_fifo", + peer_func = @vector_kernel + } -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aic_initialize_pipe {dir_mask = 1, slot_size = 1024} + (gm_slot_buffer = %gm_slot_buffer : !pto.ptr, + c2v_consumer_buf = %c2v_import : i32, + v2c_consumer_buf = %c0_i32 : i32) + + %x_mat_tile = pto.alloc_tile : !pto.tile_buf + %x_left_tile = pto.alloc_tile : !pto.tile_buf + %x_right_tile = pto.alloc_tile : !pto.tile_buf + %acc_tile = pto.alloc_tile : !pto.tile_buf + %gm_x_view = pto.make_tensor_view %gm_x, shape = [%c16, %c16], strides = [%c16, %c1] : !pto.tensor_view + %gm_x_tile_view = pto.partition_view %gm_x_view, offsets = [%c0, %c0], sizes = [%c16, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf32> + pto.tload ins(%gm_x_tile_view : !pto.partition_tensor_view<16x16xf32>) outs(%x_mat_tile : !pto.tile_buf) + pto.tmov ins(%x_mat_tile : !pto.tile_buf) outs(%x_left_tile : !pto.tile_buf) + pto.tmov ins(%x_mat_tile : !pto.tile_buf) outs(%x_right_tile : !pto.tile_buf) + pto.tmatmul ins(%x_left_tile, %x_right_tile : !pto.tile_buf, !pto.tile_buf) outs(%acc_tile : !pto.tile_buf) + pto.tpush_to_aiv(%acc_tile : !pto.tile_buf) {split = 0} + return + } + + func.func @vector_kernel(%gm_slot_buffer: !pto.ptr, %gm_y: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c2v_local = pto.reserve_buffer { + name = "c2v_fifo", + size = 4096, + location = #pto.address_space, + auto = true + } -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aiv_initialize_pipe {dir_mask = 1, slot_size = 1024} + (gm_slot_buffer = %gm_slot_buffer : !pto.ptr, + c2v_consumer_buf = %c2v_local : i32, + v2c_consumer_buf = %c0_i32 : i32) + + %gm_y_view = pto.make_tensor_view %gm_y, shape = [%c16, %c16], strides = [%c16, %c1] : !pto.tensor_view + %gm_y_tile_view = pto.partition_view %gm_y_view, offsets = [%c0, %c0], sizes = [%c16, %c16] : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf32> + %recv_tile = pto.tpop_from_aic {split = 0} + -> !pto.tile_buf + pto.tstore ins(%recv_tile : !pto.tile_buf) outs(%gm_y_tile_view : !pto.partition_tensor_view<16x16xf32>) + pto.tfree_from_aic {split = 0} + return + } + +} diff --git a/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp b/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp new file mode 100644 index 00000000..b8a9e8b2 --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/caller.cpp @@ -0,0 +1,28 @@ +#ifndef KERNEL_CPP +#error "KERNEL_CPP must be defined at compile time." +#endif + +#include + +extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *gmSlotBuffer, + uint8_t *x, + uint8_t *y) +{ + void *fftsAddr = nullptr; + uint32_t fftsLen = 0; + (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen); + (void)fftsLen; + + call_both<<>>( + (__gm__ int64_t *)fftsAddr, + (__gm__ float *)gmSlotBuffer, + (__gm__ float *)x, + (__gm__ float *)y); +} diff --git a/examples/aot/tpushpop/mix-kernel_mlir/compile.sh b/examples/aot/tpushpop/mix-kernel_mlir/compile.sh new file mode 100644 index 00000000..6b7df346 --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/compile.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" +MODE="${TPUSHPOP_MODE:-c2v}" +BUILDER_PATH="${SCRIPT_DIR}/kernels/${MODE}_builder.py" +MLIR_GEN_PATH="${ARTIFACT_DIR}/${MODE}_gen.mlir" +GENERATED_CPP="${ARTIFACT_DIR}/${MODE}.cpp" +LIB_PATH="${ARTIFACT_DIR}/tpushpop_mlir_lib.so" + +case "${MODE}" in + c2v|c2v_add|v2c|bidi) ;; + *) + echo "Unknown TPUSHPOP_MODE: ${MODE}" >&2 + exit 2 + ;; +esac + +mkdir -p "${ARTIFACT_DIR}" +rm -f "${GENERATED_CPP}" "${LIB_PATH}" + +python "${BUILDER_PATH}" > "${MLIR_GEN_PATH}" +ptoas --pto-arch=a3 --enable-insert-sync "${MLIR_GEN_PATH}" > "${GENERATED_CPP}" +# add extern "C" to function so kernel name is not mangled +perl -0pi -e 's/\b__global__ AICORE void call_both\(/extern "C" __global__ AICORE void call_both(/' "${GENERATED_CPP}" + +bisheng \ + -I/sources/pto-isa/include/ \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 -g \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${GENERATED_CPP}\"" \ + "${SCRIPT_DIR}/caller.cpp" \ + -o "${LIB_PATH}" + +echo "Generated ${GENERATED_CPP}." +echo "Built ${LIB_PATH}." diff --git a/examples/aot/tpushpop/mix-kernel_mlir/kernels/bidi_builder.py b/examples/aot/tpushpop/mix-kernel_mlir/kernels/bidi_builder.py new file mode 100644 index 00000000..8ea9ad03 --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/kernels/bidi_builder.py @@ -0,0 +1,130 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + ffts_ty = pto.ffts_type + dtype = pto.float32 + ptr_ty = pto.PtrType(dtype) + i32 = pto.int32 + tensor_ty = pto.TensorType(rank=2, dtype=dtype) + tile_view_ty = pto.SubTensorType(shape=[16, 16], dtype=dtype) + x_mat_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="MAT") + x_left_ty = pto.TileBufType( + shape=[16, 16], + dtype=dtype, + memory_space="LEFT", + config=pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor"), + ) + x_right_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="RIGHT") + acc_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="ACC") + vec_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="VEC") + # Direct GM writeback from cube needs a row-major NoneBox tile. + cube_recv_ty = pto.TileBufType( + shape=[16, 16], + dtype=dtype, + memory_space="MAT", + config=pto.TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=512, + ), + ) + return locals() + + +@to_ir_module(meta_data=meta_data, module=True) +def module(): + @pto.func(kernel="cube") + def cube_kernel(gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty", gm_y: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c2v_import = pto.import_reserved_buffer( + name="c2v_fifo", + peer_func="@vector_kernel", + ) + v2c_local = pto.reserve_buffer(name="v2c_fifo", size=4096, location="MAT") + + # One DIR_BOTH pipe handles both legs of the round trip. + pto.aic_initialize_pipe( + dir_mask=3, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_import, + v2c_consumer_buf=v2c_local, + ) + + x_mat_tile = pto.alloc_tile(x_mat_ty) + x_left_tile = pto.alloc_tile(x_left_ty) + x_right_tile = pto.alloc_tile(x_right_ty) + acc_tile = pto.alloc_tile(acc_ty) + + gm_x_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_x, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + gm_y_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_y, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + pto.load(gm_x_tile_view, x_mat_tile) + tile.mov(x_mat_tile, x_left_tile) + tile.mov(x_mat_tile, x_right_tile) + tile.matmul(x_left_tile, x_right_tile, acc_tile) + pto.tpush_to_aiv(acc_tile, 0) + returned_tile = pto.tpop_from_aiv(cube_recv_ty, 0) + pto.store(returned_tile, gm_y_tile_view) + pto.tfree_from_aiv(0) + + @pto.func(kernel="vector") + def vector_kernel(gm_slot_buffer: "ptr_ty") -> None: + c2v_local = pto.reserve_buffer(name="c2v_fifo", size=4096, location="VEC") + v2c_import = pto.import_reserved_buffer( + name="v2c_fifo", + peer_func="@cube_kernel", + ) + + # Vector pops cube's tile, doubles it, then pushes the result back. + pto.aiv_initialize_pipe( + dir_mask=3, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_local, + v2c_consumer_buf=v2c_import, + ) + + doubled_tile = pto.alloc_tile(vec_ty) + recv_tile = pto.tpop_from_aic(vec_ty, 0) + tile.add(recv_tile, recv_tile, doubled_tile) + pto.tpush_to_aic(doubled_tile, 0) + pto.tfree_from_aic(0) + + @pto.func(entry=True) + def call_both( + ffts_addr: "ffts_ty", gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty", gm_y: "ptr_ty" + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, gm_x, gm_y) + pto.call(vector_kernel, gm_slot_buffer) + + +if __name__ == "__main__": + print(module) diff --git a/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_add_builder.py b/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_add_builder.py new file mode 100644 index 00000000..d0aef4ae --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_add_builder.py @@ -0,0 +1,117 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + ffts_ty = pto.ffts_type + dtype = pto.float32 + ptr_ty = pto.PtrType(dtype) + i32 = pto.int32 + tensor_ty = pto.TensorType(rank=2, dtype=dtype) + tile_view_ty = pto.SubTensorType(shape=[16, 16], dtype=dtype) + x_mat_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="MAT") + x_left_ty = pto.TileBufType( + shape=[16, 16], + dtype=dtype, + memory_space="LEFT", + config=pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor"), + ) + x_right_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="RIGHT") + acc_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="ACC") + vec_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="VEC") + return locals() + + +@to_ir_module(meta_data=meta_data, module=True) +def module(): + @pto.func(kernel="cube") + def cube_kernel(gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + c2v_import = pto.import_reserved_buffer( + name="c2v_fifo", + peer_func="@vector_kernel", + ) + + pto.aic_initialize_pipe( + dir_mask=1, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_import, + v2c_consumer_buf=c0_i32, + ) + + x_mat_tile = pto.alloc_tile(x_mat_ty) + x_left_tile = pto.alloc_tile(x_left_ty) + x_right_tile = pto.alloc_tile(x_right_ty) + acc_tile = pto.alloc_tile(acc_ty) + + gm_x_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_x, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + pto.load(gm_x_tile_view, x_mat_tile) + tile.mov(x_mat_tile, x_left_tile) + tile.mov(x_mat_tile, x_right_tile) + tile.matmul(x_left_tile, x_right_tile, acc_tile) + # Debug step: only send cube's result to vector. + pto.tpush_to_aiv(acc_tile, 0) + + @pto.func(kernel="vector") + def vector_kernel(gm_slot_buffer: "ptr_ty", gm_y: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + c2v_local = pto.reserve_buffer(name="c2v_fifo", size=4096, location="VEC") + + pto.aiv_initialize_pipe( + dir_mask=1, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_local, + v2c_consumer_buf=c0_i32, + ) + + gm_y_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_y, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + doubled_tile = pto.alloc_tile(vec_ty) + recv_tile = pto.tpop_from_aic(vec_ty, 0) + # First isolate the vector-side path: pop, double, store from vector. + tile.add(recv_tile, recv_tile, doubled_tile) + pto.store(doubled_tile, gm_y_tile_view) + pto.tfree_from_aic(0) + + @pto.func(entry=True) + def call_both( + ffts_addr: "ffts_ty", gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty", gm_y: "ptr_ty" + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, gm_x) + pto.call(vector_kernel, gm_slot_buffer, gm_y) + + +if __name__ == "__main__": + print(module) diff --git a/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_builder.py b/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_builder.py new file mode 100644 index 00000000..51312f36 --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/kernels/c2v_builder.py @@ -0,0 +1,112 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + ffts_ty = pto.ffts_type + dtype = pto.float32 + ptr_ty = pto.PtrType(dtype) + i32 = pto.int32 + tensor_ty = pto.TensorType(rank=2, dtype=dtype) + tile_view_ty = pto.SubTensorType(shape=[16, 16], dtype=dtype) + x_mat_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="MAT") + x_left_ty = pto.TileBufType( + shape=[16, 16], + dtype=dtype, + memory_space="LEFT", + config=pto.TileBufConfig(blayout="ColMajor", slayout="RowMajor"), + ) + x_right_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="RIGHT") + acc_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="ACC") + recv_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="VEC") + return locals() + + +@to_ir_module(meta_data=meta_data, module=True) +def module(): + @pto.func(kernel="cube") + def cube_kernel(gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + c2v_import = pto.import_reserved_buffer( + name="c2v_fifo", + peer_func="@vector_kernel", + ) + + pto.aic_initialize_pipe( + dir_mask=1, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_import, + v2c_consumer_buf=c0_i32, + ) + + x_mat_tile = pto.alloc_tile(x_mat_ty) + x_left_tile = pto.alloc_tile(x_left_ty) + x_right_tile = pto.alloc_tile(x_right_ty) + acc_tile = pto.alloc_tile(acc_ty) + + gm_x_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_x, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + pto.load(gm_x_tile_view, x_mat_tile) + tile.mov(x_mat_tile, x_left_tile) + tile.mov(x_mat_tile, x_right_tile) + tile.matmul(x_left_tile, x_right_tile, acc_tile) + pto.tpush_to_aiv(acc_tile, 0) + + @pto.func(kernel="vector") + def vector_kernel(gm_slot_buffer: "ptr_ty", gm_y: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + c2v_local = pto.reserve_buffer(name="c2v_fifo", size=4096, location="VEC") + + pto.aiv_initialize_pipe( + dir_mask=1, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c2v_local, + v2c_consumer_buf=c0_i32, + ) + + gm_y_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_y, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + pto.store(pto.tpop_from_aic(recv_ty, 0), gm_y_tile_view) + pto.tfree_from_aic(0) + + @pto.func(entry=True) + def call_both( + ffts_addr: "ffts_ty", gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty", gm_y: "ptr_ty" + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, gm_x) + pto.call(vector_kernel, gm_slot_buffer, gm_y) + + +if __name__ == "__main__": + print(module) diff --git a/examples/aot/tpushpop/mix-kernel_mlir/kernels/v2c_builder.py b/examples/aot/tpushpop/mix-kernel_mlir/kernels/v2c_builder.py new file mode 100644 index 00000000..96ba943e --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/kernels/v2c_builder.py @@ -0,0 +1,106 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + ffts_ty = pto.ffts_type + dtype = pto.float32 + ptr_ty = pto.PtrType(dtype) + i32 = pto.int32 + tensor_ty = pto.TensorType(rank=2, dtype=dtype) + tile_view_ty = pto.SubTensorType(shape=[16, 16], dtype=dtype) + vec_ty = pto.TileBufType(shape=[16, 16], dtype=dtype, memory_space="VEC") + recv_ty = pto.TileBufType( + shape=[16, 16], + dtype=dtype, + memory_space="MAT", + config=pto.TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=512, + ), + ) + return locals() + + +@to_ir_module(meta_data=meta_data, module=True) +def module(): + @pto.func(kernel="cube") + def cube_kernel(gm_slot_buffer: "ptr_ty", gm_y: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + v2c_local = pto.reserve_buffer(name="v2c_fifo", size=4096, location="MAT") + + pto.aic_initialize_pipe( + dir_mask=2, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c0_i32, + v2c_consumer_buf=v2c_local, + ) + + gm_y_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_y, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + pto.store(pto.tpop_from_aiv(recv_ty, 0), gm_y_tile_view) + pto.tfree_from_aiv(0) + + @pto.func(kernel="vector") + def vector_kernel(gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty") -> None: + c0 = const(0) + c1 = const(1) + c16 = const(16) + c0_i32 = const(0, type=i32) + v2c_import = pto.import_reserved_buffer( + name="v2c_fifo", + peer_func="@cube_kernel", + ) + + pto.aiv_initialize_pipe( + dir_mask=2, + slot_size=1024, + gm_slot_buffer=gm_slot_buffer, + c2v_consumer_buf=c0_i32, + v2c_consumer_buf=v2c_import, + ) + + gm_x_tile_view = pto.slice_view( + tile_view_ty, + source=pto.as_tensor( + tensor_ty, + ptr=gm_x, + shape=[c16, c16], + strides=[c16, c1], + ), + offsets=[c0, c0], + sizes=[c16, c16], + ) + + send_tile = pto.alloc_tile(vec_ty) + pto.load(gm_x_tile_view, send_tile) + pto.tpush_to_aic(send_tile, 0) + + @pto.func(entry=True) + def call_both( + ffts_addr: "ffts_ty", gm_slot_buffer: "ptr_ty", gm_x: "ptr_ty", gm_y: "ptr_ty" + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, gm_y) + pto.call(vector_kernel, gm_slot_buffer, gm_x) + + +if __name__ == "__main__": + print(module) diff --git a/examples/aot/tpushpop/mix-kernel_mlir/run.py b/examples/aot/tpushpop/mix-kernel_mlir/run.py new file mode 100644 index 00000000..f749e378 --- /dev/null +++ b/examples/aot/tpushpop/mix-kernel_mlir/run.py @@ -0,0 +1,130 @@ +import argparse +import ctypes +import os +import subprocess + +import torch +import torch_npu # noqa: F401 + +from ptodsl.test_util import get_test_device + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_LIB_PATH = os.path.join(THIS_DIR, "build_artifacts", "tpushpop_mlir_lib.so") +DEFAULT_COMPILE_SCRIPT = os.path.join(THIS_DIR, "compile.sh") +DEFAULT_FIFO_BYTES = 4 * 1024 +DEFAULT_FIFO_BYTES_BOTH = 8 * 1024 +M = 16 +N = 16 +ATOL = 1e-4 +RTOL = 1e-4 +MODES = ("c2v", "c2v_add", "v2c", "bidi") + + +def torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(tensor.data_ptr()) + + +def compile_example(compile_script: str, mode: str) -> None: + env = dict(os.environ, TPUSHPOP_MODE=mode) + subprocess.run( + ["bash", compile_script], + check=True, + cwd=THIS_DIR, + env=env, + ) + + +def load_lib(lib_path: str) -> ctypes.CDLL: + lib = ctypes.CDLL(lib_path) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + lib.call_kernel.restype = None + return lib + + +def make_gm_slot_buffer(*, fifo_bytes: int, device: str) -> torch.Tensor: + fifo_elems = max(1, (fifo_bytes + 3) // 4) + return torch.zeros((fifo_elems,), dtype=torch.float32, device=device) + + +def make_io_tensors(*, device: str) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.rand((M, N), dtype=torch.float32, device=device) - 0.5 + y = torch.zeros((M, N), dtype=torch.float32, device=device) + return x, y + + +def fifo_bytes_for_mode(mode: str) -> int: + return DEFAULT_FIFO_BYTES_BOTH if mode in ("v2c", "bidi") else DEFAULT_FIFO_BYTES + + +def run_kernel( + lib: ctypes.CDLL, *, gm_slot_buffer: torch.Tensor, x: torch.Tensor, y: torch.Tensor +) -> None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + lib.call_kernel( + 1, + stream_ptr, + torch_to_ctypes(gm_slot_buffer), + torch_to_ctypes(x), + torch_to_ctypes(y), + ) + torch.npu.synchronize() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("mode", nargs="?", choices=MODES, default="c2v") + return parser.parse_args() + + +def reference(mode: str, x: torch.Tensor) -> torch.Tensor: + y = x.cpu() @ x.cpu() + if mode == "c2v": + return y + if mode == "v2c": + return x.cpu() + return 2 * y + + +def main() -> None: + args = parse_args() + compile_example(DEFAULT_COMPILE_SCRIPT, args.mode) + + device = get_test_device() + torch.npu.set_device(device) + + lib = load_lib(DEFAULT_LIB_PATH) + gm_slot_buffer = make_gm_slot_buffer( + fifo_bytes=fifo_bytes_for_mode(args.mode), + device=device, + ) + torch.set_printoptions(precision=1, threshold=2000, linewidth=250, sci_mode=False) + x, y = make_io_tensors(device=device) + + print(y) + run_kernel(lib, gm_slot_buffer=gm_slot_buffer, x=x, y=y) + print(y) + + y_ref = reference(args.mode, x) + y_cpu = y.cpu() + + print(y_ref - y_cpu) + max_abs = float(torch.max(torch.abs(y_cpu - y_ref)).item()) + ok = bool(torch.allclose(y_cpu, y_ref, atol=ATOL, rtol=RTOL)) + + print(f"shape=({M}, {N}) max_abs={max_abs:.6f}") + if not ok: + raise SystemExit( + f"Validation failed with atol={ATOL} rtol={RTOL}. max_abs={max_abs:.6f}" + ) + + print(f"Validation passed for mode={args.mode} using {DEFAULT_LIB_PATH}.") + + +if __name__ == "__main__": + main() diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index f2e2d0ac..caf5cda8 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -1,16 +1,29 @@ +from ..compiler.ir import ir_func as func from .control_flow import cond, range, if_context from .scalar import Value, wrap_value from .pto_general import ( alloc_tile, + aic_initialize_pipe, + aiv_initialize_pipe, as_tensor, + call, + set_ffts, cube_section, get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, + import_reserved_buffer, load, + reserve_buffer, slice_view, store, + tfree_from_aic, + tfree_from_aiv, + tpop_from_aic, + tpop_from_aiv, + tpush_to_aic, + tpush_to_aiv, vector_section, print, ) @@ -33,15 +46,19 @@ "float32", "int16", "int32", + "ffts_type", "PtrType", "TensorType", "SubTensorType", "TileBufConfig", "TileBufType", + "func", "get_block_idx", "get_subblock_idx", "get_subblock_num", "get_block_num", + "call", + "set_ffts", "as_tensor", "slice_view", "vector_section", @@ -49,9 +66,19 @@ "range", "if_context", "cond", + "reserve_buffer", + "import_reserved_buffer", + "aic_initialize_pipe", + "aiv_initialize_pipe", "alloc_tile", "load", "store", + "tpush_to_aiv", + "tpush_to_aic", + "tpop_from_aic", + "tpop_from_aiv", + "tfree_from_aic", + "tfree_from_aiv", "print", "record_event", "wait_event", diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index c8f649ea..9187d25d 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from mlir.dialects import pto as _pto -from mlir.ir import InsertionPoint +from mlir.ir import FlatSymbolRefAttr, InsertionPoint, Operation from .scalar import Value, _unwrap @@ -30,6 +30,32 @@ def _resolve_layout_attr(layout): return layout +def _resolve_address_space_attr(location): + if isinstance(location, str): + return _pto.AddressSpaceAttr.get(getattr(_pto.AddressSpace, location.upper())) + return location + + +def _resolve_peer_func_attr(peer_func): + if hasattr(peer_func, "sym_name"): + peer_func = peer_func.sym_name + if isinstance(peer_func, str): + return FlatSymbolRefAttr.get(peer_func.removeprefix("@")) + return peer_func + + +def call(callee, *args): + return Operation.create( + "func.call", + operands=[_unwrap(arg) for arg in args], + attributes={"callee": _resolve_peer_func_attr(callee)}, + ) + + +def set_ffts(ffts): + return _pto.SetFFTsOp(_unwrap(ffts)) + + def as_tensor(tensor_type, *, ptr, shape, strides, layout=None): shape_vals = [_unwrap(v) for v in shape] stride_vals = [_unwrap(v) for v in strides] @@ -77,6 +103,111 @@ def alloc_tile(tile_type, *, addr=None, valid_row=None, valid_col=None): return _pto.AllocTileOp(tile_type, **kwargs).result +# %c2v_local = pto.reserve_buffer { +# name = "c2v_fifo", +# size = 4096, +# location = #pto.address_space, +# auto = true +# } -> i32 +def reserve_buffer(*, name, size, location, auto_alloc=True, base=None): + """ + - At most one `pto.reserve_buffer` is expected in one function + - `location` must be a supported local address space + - Op-level verification requires: + - `auto = false` must provide `base` + - `auto = true` must not provide `base` + """ + # All params are compile time attributes + # wrap reserve_buffer(name, size, location, auto_alloc, *, base=None, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Value + + return _pto.ReserveBufferOp( + name, size, _resolve_address_space_attr(location), auto_alloc, base=base + ).result + + +# %c2v_import = pto.import_reserved_buffer { +# name = "c2v_fifo", +# peer_func = @vector_kernel +# } -> i32 +def import_reserved_buffer(*, name, peer_func): + # wrap import_reserved_buffer(name, peer_func, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Value + return _pto.ImportReservedBufferOp(name, _resolve_peer_func_attr(peer_func)).result + + +def aic_initialize_pipe( + *, + dir_mask, + slot_size, + gm_slot_buffer=None, # only needed on a2/a3? + c2v_consumer_buf, + v2c_consumer_buf, +): + # wrap aic_initialize_pipe(dir_mask, slot_size, c2v_consumer_buf, v2c_consumer_buf, *, gm_slot_buffer=None, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.AicInitializePipeOp( + dir_mask, + slot_size, + c2v_consumer_buf=_unwrap(c2v_consumer_buf), + v2c_consumer_buf=_unwrap(v2c_consumer_buf), + gm_slot_buffer=_unwrap(gm_slot_buffer), + ) + + +# pto.aiv_initialize_pipe {dir_mask = 1, slot_size = 1024} ( +# gm_slot_buffer = %gm_slot_buffer : !pto.ptr, +# c2v_consumer_buf = %c2v_local : i32, +# v2c_consumer_buf = %c0_i32 : i32 +# ) +def aiv_initialize_pipe( + *, + dir_mask, + slot_size, + gm_slot_buffer=None, # only needed on a2/a3 + c2v_consumer_buf, + v2c_consumer_buf, +): + # wrap aiv_initialize_pipe(dir_mask, slot_size, c2v_consumer_buf, v2c_consumer_buf, *, gm_slot_buffer=None, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.AivInitializePipeOp( + dir_mask, + slot_size, + c2v_consumer_buf=_unwrap(c2v_consumer_buf), + v2c_consumer_buf=_unwrap(v2c_consumer_buf), + gm_slot_buffer=_unwrap(gm_slot_buffer), + ) + + +# pto.tpush_to_aiv(%acc_tile : !pto.tile_buf) {split = 0} +def tpush_to_aiv(tile, split): + # wrap tpush_to_aiv(tile, split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.TPushToAivOp(_unwrap(tile), split) + + +def tpush_to_aic(tile, split): + # wrap: tpush_to_aic(tile, split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.TPushToAicOp(_unwrap(tile), split) + + +# %recv_tile = pto.tpop_from_aic {split = 0} -> !pto.tile_buf +def tpop_from_aic(tile_type, split): + # wrap tpop_from_aic(tile, split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Value + return _pto.TPopFromAicOp(tile_type, split).result + + +def tpop_from_aiv(tile_type, split): + # wraps tpop_from_aiv(tile, split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Value + return _pto.TPopFromAivOp(tile_type, split).result + + +# pto.tfree_from_aic {split = 0} +def tfree_from_aic(split): + # wrap tfree_from_aic(split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.TFreeFromAicOp(split) + + +def tfree_from_aiv(split): + # wrap tfree_from_aiv(split, *, loc=None, ip=None) -> mlir._mlir_libs._mlir.ir.Operation + return _pto.TFreeFromAivOp(split) + + def load(source, dest): _pto.TLoadOp(None, source, dest) @@ -106,12 +237,24 @@ def print(format, scalar): "get_subblock_idx", "get_subblock_num", "get_block_num", + "call", + "set_ffts", "as_tensor", "slice_view", "vector_section", "cube_section", "alloc_tile", + "reserve_buffer", + "import_reserved_buffer", + "aic_initialize_pipe", + "aiv_initialize_pipe", "load", "store", + "tpush_to_aiv", + "tpush_to_aic", + "tpop_from_aic", + "tpop_from_aiv", + "tfree_from_aic", + "tfree_from_aiv", "print", ] diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py index 9561eb85..857f8780 100644 --- a/ptodsl/api/scalar.py +++ b/ptodsl/api/scalar.py @@ -104,8 +104,10 @@ def __getattr__(name): raise AttributeError(f"module '{__name__}' has no attribute '{name}'") -def const(value): - return Value(arith.ConstantOp(IndexType.get(), value).result) +def const(value, type=None): + if type is None: + type = IndexType.get() + return Value(arith.ConstantOp(type, value).result) def index_cast(value, index_type=IndexType): diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py index 414478c2..4f84bc4a 100644 --- a/ptodsl/api/type_def.py +++ b/ptodsl/api/type_def.py @@ -1,4 +1,5 @@ from mlir.dialects import pto as _pto +from mlir.ir import IntegerType, MemRefType from . import scalar @@ -18,6 +19,8 @@ def __getattr__(name): "int64", }: return getattr(scalar, name) + if name == "ffts_type": + return MemRefType.get([256], IntegerType.get_signless(64)) raise AttributeError(f"module '{__name__}' has no attribute '{name}'") @@ -118,5 +121,6 @@ def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None): "float32", "int16", "int32", + "ffts_type", "uint32", ] diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py index b32730ef..c2cef082 100644 --- a/ptodsl/compiler/ir.py +++ b/ptodsl/compiler/ir.py @@ -1,11 +1,21 @@ import inspect from mlir.dialects import func, pto as _pto -from mlir.ir import Context, InsertionPoint, Location, Module +from mlir.ir import Attribute, Context, InsertionPoint, Location, Module, UnitAttr from ..api.scalar import wrap_value +# For the inner decorators to be clean for the user visible API `pto.func(kernel='cube')` +# with no reference to module, we need this: +_CURRENT = None + + +class FuncRef: + def __init__(self, sym_name): + self.sym_name = sym_name + + def _resolve_meta(meta_fn): values = meta_fn() if not isinstance(values, dict): @@ -41,10 +51,7 @@ def _resolve_ret_types(signature, meta_map): if isinstance(ret_annot, (list, tuple)): out = [] for elem in ret_annot: - if isinstance(elem, str): - out.append(meta_map[elem]) - else: - out.append(elem) + out.append(meta_map[elem] if isinstance(elem, str) else elem) return out return [ret_annot] @@ -64,46 +71,94 @@ def _inject_globals(fn, values): return old -def _restore_globals(fn, old, injected_names): - for name in injected_names: +def _restore_globals(fn, old, names): + for name in names: if old[name] is None and name in fn.__globals__: del fn.__globals__[name] else: fn.__globals__[name] = old[name] -def to_ir_module(*, meta_data): +def _define(module, ctx, meta_map, fn, *, name=None, entry=False, kernel=None): + sig = inspect.signature(fn) + arg_types = _resolve_arg_types(sig, meta_map) + ret_types = _resolve_ret_types(sig, meta_map) + fn_name = name or fn.__name__ + fn_ty = func.FunctionType.get(arg_types, ret_types) + + with InsertionPoint(module.body): + ir_func = func.FuncOp(fn_name, fn_ty) + + if entry: + ir_func.operation.attributes["pto.entry"] = UnitAttr.get(ctx) + if kernel is not None: + ir_func.operation.attributes["pto.kernel_kind"] = Attribute.parse( + f"#pto.kernel_kind<{kernel}>" + ) + + block = ir_func.add_entry_block() + with InsertionPoint(block): + wrapped_args = [wrap_value(arg) for arg in block.arguments] + old = _inject_globals(fn, meta_map) + try: + fn(*wrapped_args) + finally: + _restore_globals(fn, old, meta_map.keys()) + + if not ret_types and not _has_func_return(block): + func.ReturnOp([]) + + return FuncRef(fn_name) + + +def ir_func(*, name=None, entry=False, kernel=None): + def decorator(fn): + if _CURRENT is None: + raise RuntimeError( + "`pto.func(...)` can only be used inside `@to_ir_module(..., module=True)`." + ) + return _define( + _CURRENT["module"], + _CURRENT["ctx"], + _CURRENT["meta_map"], + fn, + name=name, + entry=entry, + kernel=kernel, + ) + + return decorator + + +def to_ir_module(*, meta_data, module=False): def decorator(fn): - sig = inspect.signature(fn) + global _CURRENT with Context() as ctx, Location.unknown(): _pto.register_dialect(ctx, load=True) meta_map = _resolve_meta(meta_data) - arg_types = _resolve_arg_types(sig, meta_map) - ret_types = _resolve_ret_types(sig, meta_map) - module = Module.create() - fn_ty = func.FunctionType.get(arg_types, ret_types) - - with InsertionPoint(module.body): - ir_func = func.FuncOp(fn.__name__, fn_ty) - entry = ir_func.add_entry_block() - - with InsertionPoint(entry): - wrapped_args = [wrap_value(arg) for arg in entry.arguments] - injected = set(meta_map.keys()) - old_globals = _inject_globals(fn, meta_map) + ir_module = Module.create() + + if module: + if inspect.signature(fn).parameters: + raise ValueError( + "`module=True` expects a zero-argument builder function." + ) + old = _inject_globals(fn, meta_map) + prev = _CURRENT + _CURRENT = {"ctx": ctx, "module": ir_module, "meta_map": meta_map} try: - fn(*wrapped_args) + fn() finally: - _restore_globals(fn, old_globals, injected) - - if not ret_types and not _has_func_return(entry): - func.ReturnOp([]) + _CURRENT = prev + _restore_globals(fn, old, meta_map.keys()) + else: + _define(ir_module, ctx, meta_map, fn) - module.operation.verify() - return module + ir_module.operation.verify() + return ir_module return decorator -__all__ = ["to_ir_module"] +__all__ = ["FuncRef", "ir_func", "to_ir_module"] diff --git a/tests/frontend/test_multifunc_ir.py b/tests/frontend/test_multifunc_ir.py new file mode 100644 index 00000000..0b4343e4 --- /dev/null +++ b/tests/frontend/test_multifunc_ir.py @@ -0,0 +1,113 @@ +import subprocess + +from mlir.dialects import func, pto as _pto +from mlir.ir import ( + Attribute, + Context, + FlatSymbolRefAttr, + InsertionPoint, + Location, + Module, + Operation, + UnitAttr, +) + +from ptodsl import pto, to_ir_module + + +def meta_data(): + dtype = pto.float32 + ptr_ty = pto.PtrType(dtype) + return {"ptr_ty": ptr_ty} + + +@to_ir_module(meta_data=meta_data) +def single_kernel(arg0: "ptr_ty") -> None: + pass + + +@to_ir_module(meta_data=meta_data, module=True) +def multi_kernel_module(): + @pto.func(kernel="vector") + def worker(arg0: "ptr_ty") -> None: + pass + + @pto.func(entry=True) + def entry(arg0: "ptr_ty") -> None: + pto.call(worker, arg0) + + +def build_single_verbose(): + with Context() as ctx, Location.unknown(): + _pto.register_dialect(ctx, load=True) + module = Module.create() + ptr_ty = _pto.PtrType.get(pto.float32) + fn_ty = func.FunctionType.get([ptr_ty], []) + + with InsertionPoint(module.body): + fn = func.FuncOp("single_kernel", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + func.ReturnOp([]) + + module.operation.verify() + return module + + +def build_multi_verbose(): + with Context() as ctx, Location.unknown(): + _pto.register_dialect(ctx, load=True) + module = Module.create() + ptr_ty = _pto.PtrType.get(pto.float32) + fn_ty = func.FunctionType.get([ptr_ty], []) + + with InsertionPoint(module.body): + worker = func.FuncOp("worker", fn_ty) + entry = func.FuncOp("entry", fn_ty) + + worker.operation.attributes["pto.kernel_kind"] = Attribute.parse( + "#pto.kernel_kind" + ) + entry.operation.attributes["pto.entry"] = UnitAttr.get(ctx) + + with InsertionPoint(worker.add_entry_block()): + func.ReturnOp([]) + + entry_block = entry.add_entry_block() + with InsertionPoint(entry_block): + arg0 = entry_block.arguments[0] + Operation.create( + "func.call", + operands=[arg0], + attributes={"callee": FlatSymbolRefAttr.get("worker")}, + ) + func.ReturnOp([]) + + module.operation.verify() + return module + + +def test_old_single_function_builder_matches_raw_mlir(): + assert str(single_kernel) == str(build_single_verbose()) + + +def test_new_multi_function_builder_matches_raw_mlir(): + assert str(multi_kernel_module) == str(build_multi_verbose()) + + +def test_multi_function_module_compiles_with_ptoas(tmp_path): + pto_path = tmp_path / "multi_kernel_module.pto" + cpp_path = tmp_path / "multi_kernel_module.cpp" + pto_path.write_text(str(multi_kernel_module), encoding="utf-8") + + subprocess.run( + [ + "ptoas", + "--enable-insert-sync", + str(pto_path), + "-o", + str(cpp_path), + ], + check=True, + )