Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions kernels/python/flash_atten/caller.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
*
* Host shim for the pto-dsl flash-attention kernel. compile.sh injects
* -DKERNEL_CPP="\"build_artifacts/fa.cpp\""
* which makes this TU include the ptoas-generated kernel that defines
* `call_both`. The single exported symbol `call_kernel` is what run.py
* calls via ctypes.
*/

#ifndef KERNEL_CPP
#error "KERNEL_CPP must be defined at compile time (see compile.sh)."
#endif

#include <cstdint>

extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen);

#include KERNEL_CPP

extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v,
uint8_t *o)
{
void *fftsAddr = nullptr;
uint32_t fftsLen = 0;
(void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen);
(void)fftsLen;

call_both<<<blockDim, nullptr, stream>>>((__gm__ int64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer,
(__gm__ half *)gmSlotBuffer, (__gm__ half *)q, (__gm__ half *)k,
(__gm__ half *)v, (__gm__ float *)o);
}
90 changes: 90 additions & 0 deletions kernels/python/flash_atten/compile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env bash
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# CANN Open Software License Agreement Version 2.0
#
# Build pto-dsl flash-attention .so variants. Each FA_TILES value N produces
# fa${TAG}.{mlir,cpp,so} under build_artifacts/ where TAG is "" for the
# default N=4 (S1=1024 at S1_TILE=256, same host-visible shape as
# manual case_float_H_128_S0_128_S1_1024) and "_N" otherwise.
#
# Usage:
# bash compile.sh # build the default S1=1024/2048/8192 (N=4,8,32)
# FA_TILES=4 bash compile.sh # build only S1=1024
# FA_TILES=4,8,32 bash compile.sh # build S1=1024 / 2048 / 8192 (same as default)
# PTO_LIB_PATH=/abs/pto-isa bash compile.sh # override include path

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts"
PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}"
PTOAS="${PTOAS:-ptoas}"
BISHENG="${BISHENG:-bisheng}"

mkdir -p "${ARTIFACT_DIR}"

build_variant() {
local builder="$1"
local num_tiles="$2"
local out_basename="$3"

local mlir_path="${ARTIFACT_DIR}/${out_basename}.mlir"
local generated_cpp="${ARTIFACT_DIR}/${out_basename}.cpp"
local lib_path="${ARTIFACT_DIR}/${out_basename}.so"

echo "==> Building ${out_basename} (NUM_TILES=${num_tiles}) -> ${lib_path}"
rm -f "${mlir_path}" "${generated_cpp}" "${lib_path}"

FA_NUM_TILES="${num_tiles}" python "${SCRIPT_DIR}/kernels/${builder}" > "${mlir_path}"
"${PTOAS}" --pto-arch=a3 --enable-insert-sync "${mlir_path}" > "${generated_cpp}"
python3 - "${generated_cpp}" <<'PY'
import re
import sys
from pathlib import Path

path = Path(sys.argv[1])
text = path.read_text()

# Older ptoas builds ignore local_slot_num on gm_slot_tensor-based pipe init and
# emit address-based TPipe instantiations as LocalSlotNum=SlotNum=8. The manual
# flash-attention kernel uses LocalSlotNum=2; patch only when that old lowering
# shape is still present.
patched = re.sub(r"(TPipe<[^>\n]+,\s*8),\s*8,\s*false", r"\1, 2, false", text)
path.write_text(patched)
PY

"${BISHENG}" \
-I"${PTO_LIB_PATH}/include" \
-fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \
-Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \
-xcce -Xhost-start -Xhost-end \
-mllvm -cce-aicore-stack-size=0x8000 \
-mllvm -cce-aicore-function-stack-size=0x8000 \
-mllvm -cce-aicore-record-overflow=true \
-mllvm -cce-aicore-addr-transform \
-mllvm -cce-aicore-dcci-insert-for-scalar=false \
-cce-enable-mix \
--npu-arch=dav-2201 -DMEMORY_BASE \
-std=gnu++17 \
-DKERNEL_CPP="\"${generated_cpp}\"" \
"${SCRIPT_DIR}/caller.cpp" \
-o "${lib_path}"

echo " built ${lib_path}"
}

FA_TILES="${FA_TILES:-4,8,32}"

IFS=',' read -r -a tile_list <<< "${FA_TILES}"
for nt_raw in "${tile_list[@]}"; do
nt="$(echo "${nt_raw}" | tr -d '[:space:]')"
[[ -z "${nt}" ]] && continue
if [[ "${nt}" == "4" ]]; then
out="fa"
else
out="fa_${nt}"
fi
build_variant "fa_builder.py" "${nt}" "${out}"
done

echo "Done. Variants: ${FA_TILES}"
Loading
Loading