diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d578d8b..041288ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,10 +47,10 @@ jobs: env: RELEASE_REPO: hw-native-sys/PTOAS - RELEASE_VER: 0.31 - RELEASE_TAG: v0.31 + RELEASE_VER: 0.37 + RELEASE_TAG: v0.37 CLI_DIR: /installers/ptoas-cli - PTOISA_COMMIT: 0af942568a4f2868673da0a35b0f5b64f27a20d5 + PTOISA_COMMIT: 933ad5d84c98377ca19f1de2e6616ba79136056a steps: - name: Install system packages diff --git a/docker/Dockerfile b/docker/Dockerfile index dc27a0ef..882f5082 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -15,8 +15,8 @@ RUN pip install --no-cache-dir \ pytest pybind11 nanobind setuptools wheel \ ipython jupyterlab matplotlib pandas -# This specific commit might not be found as it has been forced push over -ARG PTOISA_COMMIT=4e27a104f948e883e0bef44670252381bff794c5 +# This commit is dated 2024-05-12 +ARG PTOISA_COMMIT=933ad5d84c98377ca19f1de2e6616ba79136056a WORKDIR /sources RUN git clone --single-branch --branch master https://gitcode.com/cann/pto-isa.git \ && cd pto-isa && git checkout $PTOISA_COMMIT @@ -29,8 +29,8 @@ ARG CACHE_BURST=1 # ARG ARCH=x86_64 ARG ARCH=aarch64 ARG RELEASE_REPO=hw-native-sys/PTOAS -ARG RELEASE_VER=0.36 -ARG RELEASE_TAG=v${RELEASE_VER} +ARG RELEASE_VER=0.37 +ARG RELEASE_TAG=v0.37 ARG WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_${ARCH}.whl ARG CLI_TAR_NAME=ptoas-bin-${ARCH}.tar.gz diff --git a/examples/aot/flash_attention/140tflops/README.md b/examples/aot/flash_attention/140tflops/README.md new file mode 100644 index 00000000..98c64298 --- /dev/null +++ b/examples/aot/flash_attention/140tflops/README.md @@ -0,0 +1,82 @@ +# Flash Attention 140 TFLOP/s DSL Builders + +This directory has two PTODSL Flash Attention builders: + +- `fa_dsl_builder.py`: default `TILE_S1=256` builder. +- `fa_dsl_builder_tile512.py`: experimental `TILE_S1=512` builder. 
+ +Both compile scripts write the same runtime artifact: + +```text +build_artifacts/fa_dsl.so +``` + +`run.py` always loads that file, together with the builder copy `build_artifacts/fa_dsl_runtime_builder.py` that each compile script writes next to it. Compile first, then run. + +## Build + +Default 256-tile kernel: + +```bash +bash compile.sh +``` + +Experimental 512-tile kernel: + +```bash +bash compile_tile512.sh +``` + +The 512-tile builder uses `TILE_S1=512` and `QK_PRELOAD=4`, so it requires +`S1 >= 2048`. The default `run.py` sweep skips `S1=1024` for this builder. + +## Run + +Run one or more specific sequence lengths: + +```bash +python run.py --s1-values 8192 +python run.py --s1-values 8192,131072 +``` + +Benchmark PTODSL perf only for one sequence length: + +```bash +python run.py --perf-mode 131072 +``` + +## Vector Barrier Removal Experiment + +Both compile scripts can remove selected vector barriers from the generated C++; the line numbers refer to `build_artifacts/fa_dsl.cpp` (1-based): + +```bash +bash compile.sh --remove-vec-barriers line1,line2,... +bash compile_tile512.sh --remove-vec-barriers line1,line2,... +``` + +A listed line is replaced with a comment only if it contains: + +```cpp +pipe_barrier(PIPE_V); +``` + +The patched C++ is emitted as: + +```text +build_artifacts/fa_dsl_patched.cpp +``` + +The compiled shared object is still: + +```text +build_artifacts/fa_dsl.so +``` + +### Known Useful 256-Tile Variant + +Use: + +```bash +bash compile.sh --remove-vec-barriers 1264,1267,1272,1275,1279,1282,1311,1313,1316,1320,1322,1325,1328,1330,1333,1362,1364,1367,1371,1373,1376,1379,1381,1384,1390 +python run.py --perf-mode 131072 +``` diff --git a/examples/aot/flash_attention/140tflops/caller.cpp b/examples/aot/flash_attention/140tflops/caller.cpp index d17a0ec8..0dfe6530 100644 --- a/examples/aot/flash_attention/140tflops/caller.cpp +++ b/examples/aot/flash_attention/140tflops/caller.cpp @@ -25,7 +25,7 @@ extern "C" void call_kernel( (void)fftsLen; call_both<<>>( - (__gm__ int64_t *)fftsAddr, + (__gm__ uint64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, (__gm__ half *)gmSlotBuffer, (__gm__ half *)q, diff --git 
a/examples/aot/flash_attention/140tflops/compile.sh b/examples/aot/flash_attention/140tflops/compile.sh index 4e5a6a6b..80f41ab0 100755 --- a/examples/aot/flash_attention/140tflops/compile.sh +++ b/examples/aot/flash_attention/140tflops/compile.sh @@ -2,36 +2,23 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/compile_common.sh" + ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}" NPU_ARCH="${NPU_ARCH:-dav-2201}" - +PTO_LEVEL="${PTO_LEVEL:-}" MLIR_PATH="${ARTIFACT_DIR}/fa_dsl.mlir" GENERATED_CPP="${ARTIFACT_DIR}/fa_dsl.cpp" +PATCHED_CPP="${ARTIFACT_DIR}/fa_dsl_patched.cpp" LIB_PATH="${ARTIFACT_DIR}/fa_dsl.so" -FA_DSL_BUILDER="${FA_DSL_BUILDER:-fa_dsl_builder.py}" -BUILDER_PATH="${SCRIPT_DIR}/${FA_DSL_BUILDER}" -PTOAS_SYNC_ARGS=(--enable-insert-sync) +RUNTIME_BUILDER_PATH="${ARTIFACT_DIR}/fa_dsl_runtime_builder.py" +BUILDER_PATH="${SCRIPT_DIR}/fa_dsl_builder.py" -if [[ $# -gt 1 ]]; then - echo "Usage: $0 [--manual-sync]" >&2 - exit 2 -fi - -if [[ $# -eq 1 ]]; then - case "$1" in - --manual-sync) - PTOAS_SYNC_ARGS=() - ;; - *) - echo "Usage: $0 [--manual-sync]" >&2 - exit 2 - ;; - esac -fi +parse_common_compile_args "$@" mkdir -p "${ARTIFACT_DIR}" -rm -f "${MLIR_PATH}" "${GENERATED_CPP}" "${LIB_PATH}" +rm -f "${MLIR_PATH}" "${GENERATED_CPP}" "${PATCHED_CPP}" "${LIB_PATH}" "${RUNTIME_BUILDER_PATH}" if [[ ! -f "${BUILDER_PATH}" ]]; then echo "Builder not found: ${BUILDER_PATH}" >&2 @@ -39,7 +26,15 @@ if [[ ! 
-f "${BUILDER_PATH}" ]]; then fi python "${BUILDER_PATH}" > "${MLIR_PATH}" -ptoas --pto-arch=a3 "${PTOAS_SYNC_ARGS[@]}" "${MLIR_PATH}" > "${GENERATED_CPP}" + +PTOAS_ARGS=(--pto-arch=a3) +if [[ -n "${PTO_LEVEL}" ]]; then + PTOAS_ARGS+=("--pto-level=${PTO_LEVEL}") +fi +PTOAS_ARGS+=("${PTOAS_SYNC_ARGS[@]}") + +ptoas "${PTOAS_ARGS[@]}" "${MLIR_PATH}" > "${GENERATED_CPP}" +maybe_patch_vec_barriers "${GENERATED_CPP}" "${PATCHED_CPP}" "${REMOVE_VEC_BARRIER_LINES}" bisheng \ -I"${PTO_LIB_PATH}/include" \ @@ -54,9 +49,11 @@ bisheng \ -cce-enable-mix \ --npu-arch="${NPU_ARCH}" -DMEMORY_BASE \ -std=gnu++17 \ - -DKERNEL_CPP="\"${GENERATED_CPP}\"" \ + -DKERNEL_CPP="\"${PATCHED_CPP}\"" \ "${SCRIPT_DIR}/caller.cpp" \ -o "${LIB_PATH}" echo "Generated ${GENERATED_CPP}." echo "Built ${LIB_PATH}." +cp "${BUILDER_PATH}" "${RUNTIME_BUILDER_PATH}" +echo "Runtime builder ${RUNTIME_BUILDER_PATH}." diff --git a/examples/aot/flash_attention/140tflops/compile_common.sh b/examples/aot/flash_attention/140tflops/compile_common.sh new file mode 100644 index 00000000..5390ce50 --- /dev/null +++ b/examples/aot/flash_attention/140tflops/compile_common.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +parse_common_compile_args() { + PTOAS_SYNC_ARGS=(--enable-insert-sync) + REMOVE_VEC_BARRIER_LINES="" + + while [[ $# -gt 0 ]]; do + case "$1" in + --remove-vec-barriers) + if [[ $# -lt 2 || -z "$2" ]]; then + echo "--remove-vec-barriers requires a comma-separated line list" >&2 + exit 2 + fi + REMOVE_VEC_BARRIER_LINES="$2" + shift 2 + ;; + *) + echo "Usage: $0 [--remove-vec-barriers line1,line2,...]" >&2 + exit 2 + ;; + esac + done +} + +maybe_patch_vec_barriers() { + local src_cpp="$1" + local dst_cpp="$2" + local raw_lines="$3" + + if [[ -z "${raw_lines}" ]]; then + PATCHED_CPP="${src_cpp}" + return + fi + + python - "${src_cpp}" "${dst_cpp}" "${raw_lines}" <<'PY' +from pathlib import Path +import sys + +src = Path(sys.argv[1]) +dst = Path(sys.argv[2]) +remove_lines = {int(part.strip()) for part in 
sys.argv[3].split(",") if part.strip()} + +lines = src.read_text().splitlines() +patched = [] +for i, line in enumerate(lines, start=1): + if i in remove_lines and "pipe_barrier(PIPE_V);" in line: + patched.append(" /* removed PIPE_V barrier via --remove-vec-barriers */") + else: + patched.append(line) + +dst.write_text("\n".join(patched) + "\n") +print(f"Patched generated C++ -> {dst}") +PY + + PATCHED_CPP="${dst_cpp}" +} diff --git a/examples/aot/flash_attention/140tflops/compile_tile512.sh b/examples/aot/flash_attention/140tflops/compile_tile512.sh new file mode 100755 index 00000000..3c951e59 --- /dev/null +++ b/examples/aot/flash_attention/140tflops/compile_tile512.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/compile_common.sh" + +ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" +PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}" +NPU_ARCH="${NPU_ARCH:-dav-2201}" +PTO_LEVEL="${PTO_LEVEL:-}" + +MLIR_PATH="${ARTIFACT_DIR}/fa_dsl.mlir" +GENERATED_CPP="${ARTIFACT_DIR}/fa_dsl.cpp" +PATCHED_CPP="${ARTIFACT_DIR}/fa_dsl_patched.cpp" +LIB_PATH="${ARTIFACT_DIR}/fa_dsl.so" +RUNTIME_BUILDER_PATH="${ARTIFACT_DIR}/fa_dsl_runtime_builder.py" +BUILDER_PATH="${SCRIPT_DIR}/fa_dsl_builder_tile512.py" + +parse_common_compile_args "$@" + +mkdir -p "${ARTIFACT_DIR}" +rm -f "${MLIR_PATH}" "${GENERATED_CPP}" "${PATCHED_CPP}" "${LIB_PATH}" "${RUNTIME_BUILDER_PATH}" + +python "${BUILDER_PATH}" > "${MLIR_PATH}" + +PTOAS_ARGS=(--pto-arch=a3) +if [[ -n "${PTO_LEVEL}" ]]; then + PTOAS_ARGS+=("--pto-level=${PTO_LEVEL}") +fi +PTOAS_ARGS+=("${PTOAS_SYNC_ARGS[@]}") + +ptoas "${PTOAS_ARGS[@]}" "${MLIR_PATH}" > "${GENERATED_CPP}" +maybe_patch_vec_barriers "${GENERATED_CPP}" "${PATCHED_CPP}" "${REMOVE_VEC_BARRIER_LINES}" + +bisheng \ + -I"${PTO_LIB_PATH}/include" \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce 
-Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -cce-enable-mix \ + --npu-arch="${NPU_ARCH}" -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${PATCHED_CPP}\"" \ + "${SCRIPT_DIR}/caller.cpp" \ + -o "${LIB_PATH}" + +cp "${BUILDER_PATH}" "${RUNTIME_BUILDER_PATH}" + +echo "Generated ${GENERATED_CPP}." +echo "Built ${LIB_PATH}." +echo "Runtime builder ${RUNTIME_BUILDER_PATH}." diff --git a/examples/aot/flash_attention/140tflops/fa_dsl_builder.py b/examples/aot/flash_attention/140tflops/fa_dsl_builder.py index 192b1c5b..644bf727 100644 --- a/examples/aot/flash_attention/140tflops/fa_dsl_builder.py +++ b/examples/aot/flash_attention/140tflops/fa_dsl_builder.py @@ -16,7 +16,6 @@ SPLIT_UP_DOWN = 1 SLOT_NUM = 8 -LOCAL_SLOT_NUM = 1 QK_PRELOAD = int(os.environ.get("FA_DSL_QK_PRELOAD", "3")) EXP_RING = int(os.environ.get("FA_DSL_EXP_RING", "3")) @@ -52,7 +51,6 @@ def meta_data(): q_sub_ty = pto.SubTensorType(shape=[CUBE_S0, HEAD], dtype=fp16) kt_sub_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16) v_sub_ty = pto.SubTensorType(shape=[CUBE_S1, HEAD], dtype=fp16) - o_sub_half_ty = pto.SubTensorType(shape=[S0_HALF, HEAD], dtype=fp32) o_sub_vec_ty = pto.SubTensorType(shape=[VEC_ROWS, HEAD], dtype=fp32) qk_slot_part_ty = pto.SubTensorType(shape=[CUBE_S0, CUBE_S1], dtype=fp32) qk_vec_slot_part_ty = pto.SubTensorType(shape=[VEC_ROWS, TILE_S1], dtype=fp32) @@ -456,17 +454,8 @@ def update_softmax(qk0, qk1, exp_max_first, exp_max_second): ) cEXP_RING = const(EXP_RING) - - def dispatch_exp(tile_id, builder, idx=0): - if idx == EXP_RING - 1: - builder(exp_max_first_tiles[idx], exp_max_second_tiles[idx]) - return - with pto.if_context( - (tile_id % cEXP_RING) == const(idx), has_else=True - ) as branch: - builder(exp_max_first_tiles[idx], exp_max_second_tiles[idx]) - with 
branch.else_context(): - dispatch_exp(tile_id, builder, idx + 1) + if EXP_RING != 3: + raise ValueError("fa_dsl_builder.py fast path expects EXP_RING == 3") qk_entry = pto.declare_global(qk_vec_slot_ty) p_entry = pto.declare_global(p_vec_slot_ty) @@ -522,15 +511,29 @@ def compute_p_update(tile_id, exp_max_first, exp_max_second): def compute_p_update_dispatch(tile_id): pop_qk_slot() - dispatch_exp( - tile_id, - lambda exp_max_first, exp_max_second: update_softmax( + mod = tile_id % cEXP_RING + with pto.if_context(mod == c0, has_else=True) as branch0: + update_softmax( qk_first, qk_second, - exp_max_first, - exp_max_second, - ), - ) + exp_max_first_tiles[0], + exp_max_second_tiles[0], + ) + with branch0.else_context(): + with pto.if_context(mod == c1, has_else=True) as branch1: + update_softmax( + qk_first, + qk_second, + exp_max_first_tiles[1], + exp_max_second_tiles[1], + ) + with branch1.else_context(): + update_softmax( + qk_first, + qk_second, + exp_max_first_tiles[2], + exp_max_second_tiles[2], + ) push_p_slot() def pop_pv_slot(): @@ -559,24 +562,22 @@ def compute_gu_init(): tile.mov(recv_second, o_second) free_pv_slot() - def compute_gu_update(exp_max_first, exp_max_second): - pop_pv_slot() + def apply_gu_update(exp_max_first, exp_max_second): tile.row_expand_mul(o_first, exp_max_first, o_first) tile.add(o_first, recv_first, o_first) tile.row_expand_mul(o_second, exp_max_second, o_second) tile.add(o_second, recv_second, o_second) - free_pv_slot() def compute_gu_update_dispatch(tile_id): pop_pv_slot() - - def update_o(exp_max_first, exp_max_second): - tile.row_expand_mul(o_first, exp_max_first, o_first) - tile.add(o_first, recv_first, o_first) - tile.row_expand_mul(o_second, exp_max_second, o_second) - tile.add(o_second, recv_second, o_second) - - dispatch_exp(tile_id, update_o) + mod = tile_id % cEXP_RING + with pto.if_context(mod == c0, has_else=True) as branch0: + apply_gu_update(exp_max_first_tiles[0], exp_max_second_tiles[0]) + with branch0.else_context(): 
+ with pto.if_context(mod == c1, has_else=True) as branch1: + apply_gu_update(exp_max_first_tiles[1], exp_max_second_tiles[1]) + with branch1.else_context(): + apply_gu_update(exp_max_first_tiles[2], exp_max_second_tiles[2]) free_pv_slot() def compute_gu(tile_id): @@ -585,14 +586,14 @@ def compute_gu(tile_id): tile.mov(recv_first, o_first) tile.mov(recv_second, o_second) with branch.else_context(): - - def update_o(exp_max_first, exp_max_second): - tile.row_expand_mul(o_first, exp_max_first, o_first) - tile.add(o_first, recv_first, o_first) - tile.row_expand_mul(o_second, exp_max_second, o_second) - tile.add(o_second, recv_second, o_second) - - dispatch_exp(tile_id, update_o) + mod = tile_id % cEXP_RING + with pto.if_context(mod == c0, has_else=True) as branch0: + apply_gu_update(exp_max_first_tiles[0], exp_max_second_tiles[0]) + with branch0.else_context(): + with pto.if_context(mod == c1, has_else=True) as branch1: + apply_gu_update(exp_max_first_tiles[1], exp_max_second_tiles[1]) + with branch1.else_context(): + apply_gu_update(exp_max_first_tiles[2], exp_max_second_tiles[2]) free_pv_slot() compute_p_init() diff --git a/examples/aot/flash_attention/140tflops/fa_dsl_builder_tile512.py b/examples/aot/flash_attention/140tflops/fa_dsl_builder_tile512.py new file mode 100644 index 00000000..d73ae38f --- /dev/null +++ b/examples/aot/flash_attention/140tflops/fa_dsl_builder_tile512.py @@ -0,0 +1,658 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +import math +import os + +const = s.const + +CUBE_S0 = 128 +S0_HALF = CUBE_S0 // 2 +HEAD = 128 +CUBE_S1 = 128 +TILE_S1 = 512 +SUBTILES = TILE_S1 // CUBE_S1 +VEC_ROWS = S0_HALF // SUBTILES + +SPLIT_UP_DOWN = 1 +SLOT_NUM = 8 +QK_PRELOAD = int(os.environ.get("FA_DSL_QK_PRELOAD", "4")) +EXP_RING = int(os.environ.get("FA_DSL_EXP_RING", "4")) + +SLOT_SIZE_QK = CUBE_S0 * TILE_S1 * 4 +SLOT_SIZE_PV = CUBE_S0 * HEAD * 4 +SLOT_SIZE_P = CUBE_S0 * TILE_S1 * 2 + +GM_BYTES_PER_BLOCK = (SLOT_SIZE_QK + 
SLOT_SIZE_P + SLOT_SIZE_PV) * SLOT_NUM +GM_ELEMS_PER_BLOCK = GM_BYTES_PER_BLOCK // 4 +GM_HALF_ELEMS_PER_BLOCK = GM_BYTES_PER_BLOCK // 2 +GM_QK_OFF_F32 = 0 +GM_P_OFF_F16 = (SLOT_SIZE_QK * SLOT_NUM) // 2 +GM_PV_OFF_F32 = ((SLOT_SIZE_QK + SLOT_SIZE_P) * SLOT_NUM) // 4 + + +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + ffts_ty = pto.ffts_type + ptr_fp16 = pto.PtrType(fp16) + ptr_fp32 = pto.PtrType(fp32) + i64 = pto.int64 + + qkv_tensor_ty = pto.TensorType(rank=2, dtype=fp16) + o_tensor_ty = pto.TensorType(rank=2, dtype=fp32) + qk_slot_ty = pto.TensorType(shape=[CUBE_S0, TILE_S1], dtype=fp32) + qk_vec_slot_ty = pto.TensorType(shape=[S0_HALF, TILE_S1], dtype=fp32) + p_slot_ty = pto.TensorType(shape=[CUBE_S0, TILE_S1], dtype=fp16) + p_vec_slot_ty = pto.TensorType(shape=[S0_HALF, TILE_S1], dtype=fp16) + pv_slot_ty = pto.TensorType(shape=[CUBE_S0, HEAD], dtype=fp32) + pv_vec_slot_ty = pto.TensorType(shape=[S0_HALF, HEAD], dtype=fp32) + + q_sub_ty = pto.SubTensorType(shape=[CUBE_S0, HEAD], dtype=fp16) + kt_sub_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16) + v_sub_ty = pto.SubTensorType(shape=[CUBE_S1, HEAD], dtype=fp16) + o_sub_vec_ty = pto.SubTensorType(shape=[VEC_ROWS, HEAD], dtype=fp32) + qk_slot_part_ty = pto.SubTensorType(shape=[CUBE_S0, CUBE_S1], dtype=fp32) + qk_vec_slot_part_ty = pto.SubTensorType(shape=[VEC_ROWS, TILE_S1], dtype=fp32) + p_slot_part_ty = pto.SubTensorType(shape=[CUBE_S0, CUBE_S1], dtype=fp16) + p_vec_slot_part_ty = pto.SubTensorType(shape=[VEC_ROWS, TILE_S1], dtype=fp16) + pv_slot_part_ty = pto.SubTensorType(shape=[CUBE_S0, HEAD], dtype=fp32) + pv_vec_slot_part_ty = pto.SubTensorType(shape=[VEC_ROWS, HEAD], dtype=fp32) + + q_mat_ty = pto.TileBufType(shape=[CUBE_S0, HEAD], dtype=fp16, memory_space="MAT") + q_left_ty = pto.TileBufType(shape=[CUBE_S0, HEAD], dtype=fp16, memory_space="LEFT") + k_mat_ty = pto.TileBufType( + shape=[HEAD, CUBE_S1], + dtype=fp16, + memory_space="MAT", + config=pto.TileBufConfig(blayout="RowMajor", 
slayout="ColMajor"), + ) + k_right_ty = pto.TileBufType( + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="RIGHT" + ) + qk_acc_ty = pto.TileBufType( + shape=[CUBE_S0, CUBE_S1], dtype=fp32, memory_space="ACC" + ) + + p_recv_ty = pto.TileBufType( + shape=[CUBE_S0, CUBE_S1], dtype=fp16, memory_space="MAT" + ) + p_left_ty = pto.TileBufType( + shape=[CUBE_S0, CUBE_S1], dtype=fp16, memory_space="LEFT" + ) + v_mat_ty = pto.TileBufType(shape=[CUBE_S1, HEAD], dtype=fp16, memory_space="MAT") + v_right_ty = pto.TileBufType( + shape=[CUBE_S1, HEAD], dtype=fp16, memory_space="RIGHT" + ) + pv_acc_ty = pto.TileBufType(shape=[CUBE_S0, HEAD], dtype=fp32, memory_space="ACC") + + qk_vec_ty = pto.TileBufType( + shape=[VEC_ROWS, TILE_S1], dtype=fp32, memory_space="VEC" + ) + p_fp32_ty = pto.TileBufType( + shape=[VEC_ROWS, TILE_S1], dtype=fp32, memory_space="VEC" + ) + p_fp16_ty = pto.TileBufType( + shape=[VEC_ROWS, TILE_S1], dtype=fp16, memory_space="VEC" + ) + pv_vec_ty = pto.TileBufType(shape=[VEC_ROWS, HEAD], dtype=fp32, memory_space="VEC") + o_vec_ty = pto.TileBufType(shape=[VEC_ROWS, HEAD], dtype=fp32, memory_space="VEC") + tri_ty = pto.TileBufType(shape=[VEC_ROWS, TILE_S1], dtype=fp32, memory_space="VEC") + red_ty = pto.TileBufType( + shape=[VEC_ROWS, 1], + dtype=fp32, + memory_space="VEC", + config=pto.TileBufConfig(blayout="ColMajor", slayout="NoneBox"), + ) + red_row_ty = pto.TileBufType(shape=[1, VEC_ROWS], dtype=fp32, memory_space="VEC") + + return locals() + + +@to_ir_module(meta_data=meta_data, module=True) +def module(): + @pto.func(kernel="cube") + def cube_kernel( + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_h: "ptr_fp16", + gm_q: "ptr_fp16", + gm_k: "ptr_fp16", + gm_v: "ptr_fp16", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + c0 = const(0) + c1 = const(1) + cS0 = const(CUBE_S0) + cHEAD = const(HEAD) + cTILE = const(TILE_S1) + cCUBE_S1 = const(CUBE_S1) + cGM_BLOCK = const(GM_ELEMS_PER_BLOCK) + cGM_BLOCK_H = const(GM_HALF_ELEMS_PER_BLOCK) + + bid = 
s.index_cast(pto.get_block_idx()) + s0 = s.index_cast(s0_i64) + s1 = s.index_cast(s1_i64) + num_tiles_s1 = s1 // cTILE + q_row_off = bid * cS0 + tiles_this_block = num_tiles_s1 + + gm_blk = pto.add_ptr(gm_slot_buffer, bid * cGM_BLOCK) + gm_blk_h = pto.add_ptr(gm_slot_buffer_h, bid * cGM_BLOCK_H) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_p = pto.add_ptr(gm_blk_h, const(GM_P_OFF_F16)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + + qk_slot_desc = pto.as_tensor( + qk_slot_ty, + ptr=gm_qk, + shape=[cS0, cTILE], + strides=[cTILE, c1], + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + gm_addr=qk_slot_desc, + flag_base=0, + ) + + p_slot_desc = pto.as_tensor( + p_slot_ty, + ptr=gm_p, + shape=[cS0, cTILE], + strides=[cTILE, c1], + ) + p_pipe = pto.initialize_l2g2l_pipe( + dir_mask=2, + slot_size=SLOT_SIZE_P, + slot_num=SLOT_NUM, + gm_addr=p_slot_desc, + flag_base=2, + ) + + pv_slot_desc = pto.as_tensor( + pv_slot_ty, + ptr=gm_pv, + shape=[cS0, cHEAD], + strides=[cHEAD, c1], + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + gm_addr=pv_slot_desc, + flag_base=4, + ) + + tv_q = pto.as_tensor( + qkv_tensor_ty, ptr=gm_q, shape=[s0, cHEAD], strides=[cHEAD, c1] + ) + tv_k = pto.as_tensor( + qkv_tensor_ty, + ptr=gm_k, + shape=[cHEAD, s1], + strides=[c1, cHEAD], + layout="DN", + ) + tv_v = pto.as_tensor( + qkv_tensor_ty, ptr=gm_v, shape=[s1, cHEAD], strides=[cHEAD, c1] + ) + + q_mat = pto.alloc_tile(q_mat_ty) + q_left = pto.alloc_tile(q_left_ty) + k_mat = pto.alloc_tile(k_mat_ty) + k_right = pto.alloc_tile(k_right_ty) + qk_acc = pto.alloc_tile(qk_acc_ty) + p_recv = pto.alloc_tile(p_recv_ty) + p_left = pto.alloc_tile(p_left_ty) + v_mat = pto.alloc_tile(v_mat_ty) + v_right = pto.alloc_tile(v_right_ty) + pv_acc = pto.alloc_tile(pv_acc_ty) + + q_view = pto.slice_view( + q_sub_ty, source=tv_q, offsets=[q_row_off, c0], sizes=[cS0, cHEAD] + ) + 
pto.load(q_view, q_mat) + tile.mov(q_mat, q_left) + + qk_entry = pto.declare_global(qk_slot_ty) + p_entry = pto.declare_global(p_slot_ty) + pv_entry = pto.declare_global(pv_slot_ty) + + def compute_qk_sub(tile_id, sub): + tile_col_off = tile_id * cTILE + + k_col_off = tile_col_off + const(sub * CUBE_S1) + kt_view = pto.slice_view( + kt_sub_ty, + source=tv_k, + offsets=[c0, k_col_off], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat) + tile.mov(k_mat, k_right) + tile.matmul(q_left, k_right, qk_acc) + qk_part = pto.slice_view( + qk_slot_part_ty, + source=qk_entry, + offsets=[c0, const(sub * CUBE_S1)], + sizes=[cS0, cCUBE_S1], + ) + pto.store(qk_acc, qk_part) + + def compute_qk(tile_id): + pto.talloc(qk_entry, qk_pipe, SPLIT_UP_DOWN) + for sub in range(SUBTILES): + compute_qk_sub(tile_id, sub) + pto.tpush(qk_entry, qk_pipe, SPLIT_UP_DOWN) + + def compute_pv_sub(tile_id, sub): + tile_col_off = tile_id * cTILE + + v_col_off = tile_col_off + const(sub * CUBE_S1) + v_view = pto.slice_view( + v_sub_ty, + source=tv_v, + offsets=[v_col_off, c0], + sizes=[cCUBE_S1, cHEAD], + ) + pto.load(v_view, v_mat) + p_part = pto.slice_view( + p_slot_part_ty, + source=p_entry, + offsets=[c0, const(sub * CUBE_S1)], + sizes=[cS0, cCUBE_S1], + ) + pto.load(p_part, p_recv) + tile.mov(p_recv, p_left) + tile.mov(v_mat, v_right) + if sub == 0: + tile.matmul(p_left, v_right, pv_acc) + else: + tile.matmul_acc(pv_acc, p_left, v_right, pv_acc) + + def push_pv(): + pto.tfree(p_pipe, SPLIT_UP_DOWN, entry=p_entry) + + pto.talloc(pv_entry, pv_pipe, SPLIT_UP_DOWN) + pv_part = pto.slice_view( + pv_slot_part_ty, + source=pv_entry, + offsets=[c0, c0], + sizes=[cS0, cHEAD], + ) + pto.store(pv_acc, pv_part) + pto.tpush(pv_entry, pv_pipe, SPLIT_UP_DOWN) + + def compute_pv(tile_id): + pto.tpop_into(p_entry, p_pipe, SPLIT_UP_DOWN) + for sub in range(SUBTILES): + compute_pv_sub(tile_id, sub) + push_pv() + + def compute_qk_pv_interleaved(next_tile, tile_id): + pto.tpop_into(p_entry, p_pipe, SPLIT_UP_DOWN) 
+ for sub in range(SUBTILES): + compute_pv_sub(tile_id, sub) + if sub == 0: + pto.talloc(qk_entry, qk_pipe, SPLIT_UP_DOWN) + if sub == SUBTILES - 1: + push_pv() + compute_qk_sub(next_tile, sub) + if sub == SUBTILES - 1: + pto.tpush(qk_entry, qk_pipe, SPLIT_UP_DOWN) + + for preload in range(QK_PRELOAD): + compute_qk(const(preload)) + + cPRELOAD = const(QK_PRELOAD) + steady_end = tiles_this_block - cPRELOAD + for tile_id in pto.range(c0, steady_end, c1): + next_tile = tile_id + cPRELOAD + compute_qk_pv_interleaved(next_tile, tile_id) + + for drain in range(QK_PRELOAD): + tile_id = steady_end + const(drain) + compute_pv(tile_id) + + @pto.func(kernel="vector") + def vector_kernel( + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_h: "ptr_fp16", + gm_o: "ptr_fp32", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + c0 = const(0) + c1 = const(1) + cS0 = const(CUBE_S0) + cS0_HALF = const(S0_HALF) + cVEC_ROWS = const(VEC_ROWS) + cHEAD = const(HEAD) + cTILE = const(TILE_S1) + cCUBE_S1 = const(CUBE_S1) + cGM_BLOCK = const(GM_ELEMS_PER_BLOCK) + cGM_BLOCK_H = const(GM_HALF_ELEMS_PER_BLOCK) + + bid = s.index_cast(pto.get_block_idx()) + sbid = s.index_cast(pto.get_subblock_idx()) + s0 = s.index_cast(s0_i64) + s1 = s.index_cast(s1_i64) + num_tiles_s1 = s1 // cTILE + q_row_off = bid * cS0 + row_off_sb = sbid * cS0_HALF + q_row_off_sb = q_row_off + row_off_sb + tiles_this_block = num_tiles_s1 + + gm_blk = pto.add_ptr(gm_slot_buffer, bid * cGM_BLOCK) + gm_blk_h = pto.add_ptr(gm_slot_buffer_h, bid * cGM_BLOCK_H) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_p = pto.add_ptr(gm_blk_h, const(GM_P_OFF_F16)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + + qk_slot_desc = pto.as_tensor( + qk_vec_slot_ty, + ptr=gm_qk, + shape=[cS0_HALF, cTILE], + strides=[cTILE, c1], + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + gm_addr=qk_slot_desc, + flag_base=0, + ) + + p_slot_desc = pto.as_tensor( + p_vec_slot_ty, + ptr=gm_p, + 
shape=[cS0_HALF, cTILE], + strides=[cTILE, c1], + ) + p_pipe = pto.initialize_l2g2l_pipe( + dir_mask=2, + slot_size=SLOT_SIZE_P, + slot_num=SLOT_NUM, + gm_addr=p_slot_desc, + flag_base=2, + ) + + pv_slot_desc = pto.as_tensor( + pv_vec_slot_ty, + ptr=gm_pv, + shape=[cS0_HALF, cHEAD], + strides=[cHEAD, c1], + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + gm_addr=pv_slot_desc, + flag_base=4, + ) + + tv_o = pto.as_tensor( + o_tensor_ty, ptr=gm_o, shape=[s0, cHEAD], strides=[cHEAD, c1] + ) + + qk_tile = pto.alloc_tile(qk_vec_ty) + reduce_tmp = pto.alloc_tile(p_fp32_ty) + p_fp16 = pto.alloc_tile(p_fp16_ty) + recv_tile = pto.alloc_tile(pv_vec_ty) + o_tiles = [pto.alloc_tile(o_vec_ty) for _ in range(SUBTILES)] + global_max_tiles = [pto.alloc_tile(red_ty) for _ in range(SUBTILES)] + global_sum_tiles = [pto.alloc_tile(red_ty) for _ in range(SUBTILES)] + local_max = pto.alloc_tile(red_ty) + local_sum = pto.alloc_tile(red_ty) + exp_max_tiles = [ + [pto.alloc_tile(red_ty) for _ in range(EXP_RING)] for _ in range(SUBTILES) + ] + + scale = const(1.0 / math.sqrt(HEAD), s.float32) + + def init_softmax_slice(qk, global_max, global_sum): + tile.muls(qk, scale, qk) + tile.row_max(qk, reduce_tmp, global_max) + tile.row_expand_sub(qk, global_max, qk) + tile.exp(qk, qk) + tile.row_sum(qk, reduce_tmp, global_sum) + + def update_softmax_slice(qk, exp_max, global_max, global_sum): + tile.muls(qk, scale, qk) + tile.row_max(qk, reduce_tmp, local_max) + local_max_r = tile.reshape(red_row_ty, local_max) + exp_max_r = tile.reshape(red_row_ty, exp_max) + global_max_r = tile.reshape(red_row_ty, global_max) + global_sum_r = tile.reshape(red_row_ty, global_sum) + local_sum_r = tile.reshape(red_row_ty, local_sum) + + tile.max(local_max_r, global_max_r, local_max_r) + tile.sub(global_max_r, local_max_r, exp_max_r) + tile.exp(exp_max_r, exp_max_r) + tile.mov(local_max_r, global_max_r) + tile.mul(global_sum_r, exp_max_r, global_sum_r) + + 
tile.row_expand_sub(qk, local_max, qk) + tile.exp(qk, qk) + tile.row_sum(qk, reduce_tmp, local_sum) + tile.add(global_sum_r, local_sum_r, global_sum_r) + + def init_softmax_row(row_slice): + init_softmax_slice( + qk_tile, + global_max_tiles[row_slice], + global_sum_tiles[row_slice], + ) + + def update_softmax_row(row_slice, exp_max): + update_softmax_slice( + qk_tile, + exp_max, + global_max_tiles[row_slice], + global_sum_tiles[row_slice], + ) + + cEXP_RING = const(EXP_RING) + if EXP_RING != 4: + raise ValueError("fa_dsl_builder_tile512.py dispatch hard-codes 4 ring slots; expects EXP_RING == 4") + + qk_entry = pto.declare_global(qk_vec_slot_ty) + p_entry = pto.declare_global(p_vec_slot_ty) + pv_entry = pto.declare_global(pv_vec_slot_ty) + + def load_qk_row(row_slice): + qk_part = pto.slice_view( + qk_vec_slot_part_ty, + source=qk_entry, + offsets=[const(row_slice * VEC_ROWS), c0], + sizes=[cVEC_ROWS, cTILE], + ) + pto.load(qk_part, qk_tile) + + def store_p_row(row_slice): + tile.cvt(qk_tile, p_fp16, rmode="cast_rint") + p_part = pto.slice_view( + p_vec_slot_part_ty, + source=p_entry, + offsets=[const(row_slice * VEC_ROWS), c0], + sizes=[cVEC_ROWS, cTILE], + ) + pto.store(p_fp16, p_part) + + def compute_p_init(): + pto.tpop_into(qk_entry, qk_pipe, SPLIT_UP_DOWN) + pto.talloc(p_entry, p_pipe, SPLIT_UP_DOWN) + for row_slice in range(SUBTILES): + load_qk_row(row_slice) + init_softmax_row(row_slice) + store_p_row(row_slice) + pto.tfree(qk_pipe, SPLIT_UP_DOWN, entry=qk_entry) + pto.tpush(p_entry, p_pipe, SPLIT_UP_DOWN) + + def compute_p_update(tile_id, ring_idx): + pto.tpop_into(qk_entry, qk_pipe, SPLIT_UP_DOWN) + pto.talloc(p_entry, p_pipe, SPLIT_UP_DOWN) + for row_slice in range(SUBTILES): + load_qk_row(row_slice) + update_softmax_row(row_slice, exp_max_tiles[row_slice][ring_idx]) + store_p_row(row_slice) + pto.tfree(qk_pipe, SPLIT_UP_DOWN, entry=qk_entry) + pto.tpush(p_entry, p_pipe, SPLIT_UP_DOWN) + + def compute_p_update_dispatch(tile_id): + mod = tile_id % cEXP_RING + with pto.if_context(mod 
== c0, has_else=True) as branch0: + compute_p_update(tile_id, 0) + with branch0.else_context(): + with pto.if_context(mod == c1, has_else=True) as branch1: + compute_p_update(tile_id, 1) + with branch1.else_context(): + with pto.if_context(mod == const(2), has_else=True) as branch2: + compute_p_update(tile_id, 2) + with branch2.else_context(): + compute_p_update(tile_id, 3) + + def load_pv_row(row_slice): + pv_part = pto.slice_view( + pv_vec_slot_part_ty, + source=pv_entry, + offsets=[const(row_slice * VEC_ROWS), c0], + sizes=[cVEC_ROWS, cHEAD], + ) + pto.load(pv_part, recv_tile) + + def free_pv_slot(): + pto.tfree(pv_pipe, SPLIT_UP_DOWN, entry=pv_entry) + + def compute_gu_init(): + pto.tpop_into(pv_entry, pv_pipe, SPLIT_UP_DOWN) + for row_slice in range(SUBTILES): + load_pv_row(row_slice) + tile.mov(recv_tile, o_tiles[row_slice]) + free_pv_slot() + + def apply_gu_update_row(row_slice, exp_max): + tile.row_expand_mul(o_tiles[row_slice], exp_max, o_tiles[row_slice]) + tile.add(o_tiles[row_slice], recv_tile, o_tiles[row_slice]) + + def compute_gu_update_dispatch(tile_id): + pto.tpop_into(pv_entry, pv_pipe, SPLIT_UP_DOWN) + mod = tile_id % cEXP_RING + for row_slice in range(SUBTILES): + load_pv_row(row_slice) + with pto.if_context(mod == c0, has_else=True) as branch0: + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][0]) + with branch0.else_context(): + with pto.if_context(mod == c1, has_else=True) as branch1: + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][1]) + with branch1.else_context(): + with pto.if_context(mod == const(2), has_else=True) as branch2: + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][2]) + with branch2.else_context(): + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][3]) + free_pv_slot() + + def compute_gu(tile_id): + pto.tpop_into(pv_entry, pv_pipe, SPLIT_UP_DOWN) + with pto.if_context(tile_id == c0, has_else=True) as branch: + for row_slice in range(SUBTILES): + load_pv_row(row_slice) + tile.mov(recv_tile, 
o_tiles[row_slice]) + with branch.else_context(): + mod = tile_id % cEXP_RING + for row_slice in range(SUBTILES): + load_pv_row(row_slice) + with pto.if_context(mod == c0, has_else=True) as branch0: + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][0]) + with branch0.else_context(): + with pto.if_context(mod == c1, has_else=True) as branch1: + apply_gu_update_row(row_slice, exp_max_tiles[row_slice][1]) + with branch1.else_context(): + with pto.if_context( + mod == const(2), has_else=True + ) as branch2: + apply_gu_update_row( + row_slice, exp_max_tiles[row_slice][2] + ) + with branch2.else_context(): + apply_gu_update_row( + row_slice, exp_max_tiles[row_slice][3] + ) + free_pv_slot() + + compute_p_init() + for preload in range(1, QK_PRELOAD): + compute_p_update( + const(preload), + preload % EXP_RING, + ) + + cPRELOAD = const(QK_PRELOAD) + steady_end = tiles_this_block - cPRELOAD + with pto.if_context(steady_end > c0): + compute_gu_init() + compute_p_update( + cPRELOAD, + QK_PRELOAD % EXP_RING, + ) + + for tile_id in pto.range(c1, steady_end, c1): + next_tile = tile_id + cPRELOAD + compute_gu_update_dispatch(tile_id) + compute_p_update_dispatch(next_tile) + + for drain in range(QK_PRELOAD): + tile_id = steady_end + const(drain) + compute_gu(tile_id) + + for row_slice in range(SUBTILES): + tile.row_expand_div( + o_tiles[row_slice], + global_sum_tiles[row_slice], + o_tiles[row_slice], + ) + o_view = pto.slice_view( + o_sub_vec_ty, + source=tv_o, + offsets=[q_row_off_sb + const(row_slice * VEC_ROWS), c0], + sizes=[cVEC_ROWS, cHEAD], + ) + pto.store(o_tiles[row_slice], o_view) + + @pto.func + def call_both( + ffts_addr: "ffts_ty", + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_h: "ptr_fp16", + gm_q: "ptr_fp16", + gm_k: "ptr_fp16", + gm_v: "ptr_fp16", + gm_o: "ptr_fp32", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + pto.set_ffts(ffts_addr) + pto.call( + cube_kernel, + gm_slot_buffer, + gm_slot_buffer_h, + gm_q, + gm_k, + gm_v, + s0_i64, + s1_i64, + ) + 
pto.call( + vector_kernel, + gm_slot_buffer, + gm_slot_buffer_h, + gm_o, + s0_i64, + s1_i64, + ) + + +if __name__ == "__main__": + print(module.operation.get_asm(print_generic_op_form=True)) diff --git a/examples/aot/flash_attention/140tflops/naive_tpush_dsl_plot.png b/examples/aot/flash_attention/140tflops/naive_tpush_dsl_plot.png new file mode 100644 index 00000000..fd51c28a Binary files /dev/null and b/examples/aot/flash_attention/140tflops/naive_tpush_dsl_plot.png differ diff --git a/examples/aot/flash_attention/140tflops/run.py b/examples/aot/flash_attention/140tflops/run.py index 84932ede..5a0235d5 100644 --- a/examples/aot/flash_attention/140tflops/run.py +++ b/examples/aot/flash_attention/140tflops/run.py @@ -21,7 +21,7 @@ import math import argparse import ctypes -import subprocess +import importlib.util from pathlib import Path import matplotlib.pyplot as plt @@ -111,12 +111,22 @@ def torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p: def load_dsl_flash(lib_path: Path | None = None): if lib_path is None: lib_path = THIS_DIR / "build_artifacts" / "fa_dsl.so" - print("Compiling PTODSL flash kernel...") - subprocess.run(["bash", str(THIS_DIR / "compile.sh")], cwd=THIS_DIR, check=True) + print(f"Using default lib path: {lib_path}") + builder_path = THIS_DIR / "build_artifacts" / "fa_dsl_runtime_builder.py" if not lib_path.exists(): - raise FileNotFoundError(f"compile.sh did not create {lib_path}") + raise FileNotFoundError( + f"Missing {lib_path}; run compile.sh or compile_tile512.sh first" + ) + if not builder_path.exists(): + raise FileNotFoundError( + f"Missing {builder_path}; run compile.sh or compile_tile512.sh first" + ) - import fa_dsl_builder + spec = importlib.util.spec_from_file_location("fa_builder_runtime", builder_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not import builder from {builder_path}") + fa_dsl_builder = importlib.util.module_from_spec(spec) + spec.loader.exec_module(fa_dsl_builder) lib = 
ctypes.CDLL(str(lib_path)) lib.call_kernel.argtypes = [ @@ -147,6 +157,7 @@ def alloc_workspace(s0: int, s1: int, head: int, device): device=device, ) ws["o"] = torch.empty((s0, head), dtype=torch.float32, device=device) + ws["block_dim"] = block_dim def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): if q.shape[1] != fa_dsl_builder.HEAD: @@ -176,18 +187,77 @@ def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) return ws["o"] - return flash, fa_dsl_builder.TILE_S1 + def prepare_raw_runner(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + if q.shape[1] != fa_dsl_builder.HEAD: + raise ValueError(f"HEAD must be {fa_dsl_builder.HEAD}, got {q.shape[1]}") + if q.shape[0] % fa_dsl_builder.CUBE_S0 != 0: + raise ValueError( + f"S0 must be divisible by CUBE_S0={fa_dsl_builder.CUBE_S0}" + ) + if k.shape[0] % fa_dsl_builder.TILE_S1 != 0: + raise ValueError( + f"S1 must be divisible by TILE_S1={fa_dsl_builder.TILE_S1}" + ) + + alloc_workspace(q.shape[0], k.shape[0], q.shape[1], q.device) + block_dim = ws["block_dim"] + stream_ptr = torch.npu.current_stream()._as_parameter_ + gm_slot_ptr = torch_to_ctypes(ws["gm_slot"]) + q_ptr = torch_to_ctypes(q) + k_ptr = torch_to_ctypes(k) + v_ptr = torch_to_ctypes(v) + o_ptr = torch_to_ctypes(ws["o"]) + q_rows = q.shape[0] + k_rows = k.shape[0] + + def run_raw(): + lib.call_kernel( + block_dim, + stream_ptr, + gm_slot_ptr, + q_ptr, + k_ptr, + v_ptr, + o_ptr, + q_rows, + k_rows, + ) + return ws["o"] + + return run_raw + + return flash, prepare_raw_runner, fa_dsl_builder.TILE_S1, fa_dsl_builder.QK_PRELOAD + + +def parse_s1_values(raw: str) -> list[int]: + return [int(part.strip()) for part in raw.split(",") if part.strip()] -def test_flash(): +def test_flash(s1_values: list[int] | None = None, perf_mode: bool = False): s0, head = 128 * 24, 128 - s1_values = [1024, 2048, 4096, 8192, 16384, 32768, 64 * 1024, 128 * 1024] + default_s1_values = [1024, 2048, 4096, 8192, 16384, 32768, 64 * 1024, 128 * 1024] + 
user_provided_s1 = s1_values is not None + if s1_values is None: + s1_values = default_s1_values is_causal = False dtype = torch.float16 q2d = torch.randn((s0, head), dtype=dtype).npu() - flash, s1_tile = load_dsl_flash() + flash, prepare_raw_runner, s1_tile, qk_preload = load_dsl_flash() + min_s1 = s1_tile * qk_preload + invalid_s1 = [s1 for s1 in s1_values if s1 < min_s1] + if invalid_s1 and user_provided_s1: + raise ValueError( + f"S1 values {invalid_s1} are too small for TILE_S1={s1_tile} " + f"and QK_PRELOAD={qk_preload}; minimum S1 is {min_s1}" + ) + if invalid_s1: + print( + f"Skipping S1 values {invalid_s1}: TILE_S1={s1_tile}, " + f"QK_PRELOAD={qk_preload}, minimum S1={min_s1}" + ) + s1_values = [s1 for s1 in s1_values if s1 >= min_s1] run_flash = lambda q, k, v: flash(q, k, v) flash_ms_values = [] @@ -196,19 +266,49 @@ def test_flash(): flash_tflops_values = [] npu_tflops_values = [] ref_tflops_values = [] + cases = [ + {"s1": s1, "flops_total": attn_flops_matmul_softmax_scale(1, s0, s1, head)} + for s1 in s1_values + ] - for s1 in s1_values: - flops_total = attn_flops_matmul_softmax_scale(1, s0, s1, head) - - # ========================== - # Inputs - # ========================== + def make_kv(s1: int): + torch.manual_seed(SEED + s1) k2d = torch.randn((s1, head), dtype=dtype).npu() v2d = torch.randn((s1, head), dtype=dtype).npu() + return k2d, v2d + + if perf_mode: + if len(cases) != 1: + raise ValueError( + "--perf-mode requires exactly one S1 via --perf-mode S1 or --s1-values" + ) + case = cases[0] + s1 = case["s1"] + flops_total = case["flops_total"] + k2d, v2d = make_kv(s1) + run_flash_raw = prepare_raw_runner(q2d, k2d, v2d) + flash_ms = do_bench( + run_flash_raw, + warmup_iters=WARMUP, + benchmark_iters=NUM_ITERATIONS, + unit="ms", + ) + print("==== PTODSL perf mode ====") + print(f"S1 : {s1}") + print(f"Causal : {is_causal}") + print(f"GFLOPs total : {flops_total//10e9}") + print( + f"{'PTODSL flash kernel':<27}: {flash_ms:.3f} ms/iter " + 
f"({tflops(flops_total, flash_ms):.3f} TFLOP/s)" + ) + return + + print("==== Reference phase ====") + for case in cases: + s1 = case["s1"] + flops_total = case["flops_total"] + k2d, v2d = make_kv(s1) - # ========================== - # Benchmark reference ops - # ========================== ref_ms = do_bench( lambda: fa_reference(q2d, k2d, v2d, is_causal=is_causal), warmup_iters=WARMUP, @@ -221,26 +321,52 @@ def test_flash(): benchmark_iters=NUM_ITERATIONS, unit="ms", ) + o_ref = fa_reference(q2d, k2d, v2d, is_causal=is_causal).to(torch.float32) + o_npu = fused_attention(q2d, k2d, v2d, is_causal=is_causal).to(torch.float32) + + case["ref_ms"] = ref_ms + case["npu_ms"] = npu_ms + case["o_ref"] = o_ref + case["o_npu"] = o_npu + + ref_ms_values.append(ref_ms) + npu_ms_values.append(npu_ms) + ref_tflops_values.append(tflops(flops_total, ref_ms)) + npu_tflops_values.append(tflops(flops_total, npu_ms)) + + print(f"S1 : {s1}") + print(f"Causal : {is_causal}") + print(f"GFLOPs total : {flops_total//10e9}") + print( + f"npu_fused_infer_attention : {npu_ms:.3f} ms/iter " + f"({tflops(flops_total, npu_ms):.3f} TFLOP/s)" + ) + print( + f"torch reference : {ref_ms:.3f} ms/iter " + f"({tflops(flops_total, ref_ms):.3f} TFLOP/s)" + ) + print("") + del k2d, v2d + + print("==== PTODSL kernel phase ====") + for case in cases: + s1 = case["s1"] + flops_total = case["flops_total"] + k2d, v2d = make_kv(s1) + o_ref = case["o_ref"] + o_npu = case["o_npu"] + run_flash_raw = prepare_raw_runner(q2d, k2d, v2d) + flash_ms = do_bench( - lambda: run_flash(q2d, k2d, v2d), + run_flash_raw, warmup_iters=WARMUP, benchmark_iters=NUM_ITERATIONS, unit="ms", ) - flash_ms_values.append(flash_ms) - npu_ms_values.append(npu_ms) - ref_ms_values.append(ref_ms) flash_tflops_values.append(tflops(flops_total, flash_ms)) - npu_tflops_values.append(tflops(flops_total, npu_ms)) - ref_tflops_values.append(tflops(flops_total, ref_ms)) - # ========================== - # Correctness check - # ========================== 
o_out = run_flash(q2d, k2d, v2d) - o_ref = fa_reference(q2d, k2d, v2d, is_causal=is_causal).to(torch.float32) - o_npu = fused_attention(q2d, k2d, v2d, is_causal=is_causal).to(torch.float32) print(f"S1 : {s1}") print(f"Causal : {is_causal}") @@ -250,18 +376,19 @@ def test_flash(): f"({tflops(flops_total, flash_ms):.3f} TFLOP/s)" ) print( - f"npu_fused_infer_attention : {npu_ms:.3f} ms/iter " - f"({tflops(flops_total, npu_ms):.3f} TFLOP/s)" + f"npu_fused_infer_attention : {case['npu_ms']:.3f} ms/iter " + f"({tflops(flops_total, case['npu_ms']):.3f} TFLOP/s)" ) print( - f"torch reference : {ref_ms:.3f} ms/iter " - f"({tflops(flops_total, ref_ms):.3f} TFLOP/s)" + f"torch reference : {case['ref_ms']:.3f} ms/iter " + f"({tflops(flops_total, case['ref_ms']):.3f} TFLOP/s)" ) torch.testing.assert_close(o_out, o_ref, rtol=1e-3, atol=1e-3) print("vs torch reference: PASSED") torch.testing.assert_close(o_out, o_npu, rtol=1e-3, atol=1e-3) print("vs npu_fused_attention: PASSED") print("") + del k2d, v2d, o_out plot_path = Path(__file__).with_name("naive_tpush_dsl_plot.png") plt.figure(figsize=(8, 5)) @@ -284,4 +411,33 @@ def test_flash(): if __name__ == "__main__": - test_flash() + parser = argparse.ArgumentParser() + parser.add_argument( + "--s1-values", + default=None, + help="Comma-separated S1 values. Default: full sweep.", + ) + parser.add_argument( + "--perf-mode", + nargs="?", + const=True, + default=False, + metavar="S1", + help=( + "Benchmark only the PTODSL raw kernel. Optionally pass one S1 directly, " + "for example: --perf-mode 131072. The old form with --s1-values still works." 
+ ), + ) + args = parser.parse_args() + perf_mode = args.perf_mode is not False + s1_values = parse_s1_values(args.s1_values) if args.s1_values else None + if args.perf_mode is not False and args.perf_mode is not True: + if s1_values is not None: + raise ValueError( + "Pass S1 either as --perf-mode S1 or --s1-values, not both" + ) + s1_values = [int(args.perf_mode)] + test_flash( + s1_values=s1_values, + perf_mode=perf_mode, + ) diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py index f2d80e67..94075fe1 100644 --- a/ptodsl/api/pto.py +++ b/ptodsl/api/pto.py @@ -7,6 +7,7 @@ aic_initialize_pipe, aiv_initialize_pipe, as_tensor, + bitcast, call, set_ffts, sync_set, @@ -25,6 +26,8 @@ reserve_buffer, slice_view, store, + talloc_to_aic, + talloc_to_aiv, tfree_from_aic, tfree_from_aiv, tpop_from_aic, @@ -89,9 +92,13 @@ "alloc_tile", "declare_global", "declare_tile", + "declare_global", + "bitcast", "load_scalar", "load", "store", + "talloc_to_aic", + "talloc_to_aiv", "tpush_to_aiv", "tpush_to_aic", "talloc", @@ -100,6 +107,7 @@ "tfree_from_aic", "tfree_from_aiv", "tpush", + "talloc", "tpop_into", "tpop", "tfree", diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index b6bfe3cf..7121cb7b 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -179,19 +179,29 @@ def aic_initialize_pipe( dir_mask, slot_size, gm_slot_buffer=None, # only needed on a2/a3? 
- c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_tensor=None, + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, nosplit=None, ): + kwargs = {} + if c2v_consumer_buf is not None: + kwargs["c2v_consumer_buf"] = _unwrap(c2v_consumer_buf) + if v2c_consumer_buf is not None: + kwargs["v2c_consumer_buf"] = _unwrap(v2c_consumer_buf) + if gm_slot_buffer is not None: + kwargs["gm_slot_buffer"] = _unwrap(gm_slot_buffer) + if gm_slot_tensor is not None: + kwargs["gm_slot_tensor"] = _unwrap(gm_slot_tensor) return _pto.AicInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), id=id, + local_slot_num=local_slot_num, nosplit=nosplit, + **kwargs, ) @@ -206,22 +216,33 @@ def aiv_initialize_pipe( dir_mask, slot_size, gm_slot_buffer=None, # only needed on a2/a3 - c2v_consumer_buf, - v2c_consumer_buf, + gm_slot_tensor=None, + c2v_consumer_buf=None, + v2c_consumer_buf=None, id=None, + local_slot_num=None, nosplit=None, ): + kwargs = {} + if c2v_consumer_buf is not None: + kwargs["c2v_consumer_buf"] = _unwrap(c2v_consumer_buf) + if v2c_consumer_buf is not None: + kwargs["v2c_consumer_buf"] = _unwrap(v2c_consumer_buf) + if gm_slot_buffer is not None: + kwargs["gm_slot_buffer"] = _unwrap(gm_slot_buffer) + if gm_slot_tensor is not None: + kwargs["gm_slot_tensor"] = _unwrap(gm_slot_tensor) return _pto.AivInitializePipeOp( dir_mask, slot_size, - c2v_consumer_buf=_unwrap(c2v_consumer_buf), - v2c_consumer_buf=_unwrap(v2c_consumer_buf), - gm_slot_buffer=_unwrap(gm_slot_buffer), id=id, + local_slot_num=local_slot_num, nosplit=nosplit, + **kwargs, ) +@with_loc def initialize_l2g2l_pipe( *, dir_mask, @@ -315,6 +336,14 @@ def tpush_to_aic(tile, split, *, id=None): return _pto.TPushToAicOp(_unwrap(tile), split, id=id) +def talloc_to_aic(entry, split, *, id=None): + return _pto.TAllocToAicOp(_unwrap(entry), split, id=id).result + + +def 
talloc_to_aiv(entry, split, *, id=None): + return _pto.TAllocToAivOp(_unwrap(entry), split, id=id).result + + # %recv_tile = pto.tpop_from_aic {split = 0} -> !pto.tile_buf @with_loc def tpop_from_aic(tile_type, split, *, id=None): @@ -328,13 +357,24 @@ def tpop_from_aiv(tile_type, split, *, id=None): # pto.tfree_from_aic {split = 0} @with_loc -def tfree_from_aic(split, *, id=None): - return _pto.TFreeFromAicOp(split, id=id) +def tfree_from_aic(split, *, entry=None, id=None): + kwargs = {} + if entry is not None: + kwargs["entry"] = _unwrap(entry) + return _pto.TFreeFromAicOp(split, id=id, **kwargs) + + +@with_loc +def tfree_from_aiv(split, *, entry=None, id=None): + kwargs = {} + if entry is not None: + kwargs["entry"] = _unwrap(entry) + return _pto.TFreeFromAivOp(split, id=id, **kwargs) @with_loc -def tfree_from_aiv(split, *, id=None): - return _pto.TFreeFromAivOp(split, id=id) +def bitcast(result_type, src): + return _pto.BitcastOp(result_type, _unwrap(src)).result @with_loc @@ -387,6 +427,7 @@ def print(format, scalar): "alloc_tile", "declare_global", "declare_tile", + "declare_global", "reserve_buffer", "import_reserved_buffer", "aic_initialize_pipe", @@ -395,6 +436,9 @@ def print(format, scalar): "load_scalar", "load", "store", + "bitcast", + "talloc_to_aic", + "talloc_to_aiv", "tpush_to_aiv", "tpush_to_aic", "tpop_from_aic", diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py index bab1b9c1..890221e3 100644 --- a/ptodsl/compiler/ir.py +++ b/ptodsl/compiler/ir.py @@ -1,3 +1,4 @@ +import os import inspect from mlir.dialects import func, pto as _pto @@ -195,7 +196,8 @@ def decorator(fn): else: _define(ir_module, ctx, meta_map, fn) - ir_module.operation.verify() + if os.environ.get("PTODSL_SKIP_VERIFY", "0") != "1": + ir_module.operation.verify() return ir_module return decorator