Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ jobs:

env:
RELEASE_REPO: hw-native-sys/PTOAS
RELEASE_VER: 0.31
RELEASE_TAG: v0.31
RELEASE_VER: 0.36
RELEASE_TAG: v0.36
CLI_DIR: /installers/ptoas-cli
PTOISA_COMMIT: 0af942568a4f2868673da0a35b0f5b64f27a20d5
PTOISA_COMMIT: 4e27a104f948e883e0bef44670252381bff794c5

steps:
- name: Install system packages
Expand Down
14 changes: 9 additions & 5 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,26 @@ RUN pip install --no-cache-dir \
pytest pybind11 nanobind setuptools wheel \
ipython jupyterlab matplotlib pandas

# certain operations need latest isa header, not CANN 8.5.0 default
# header on 2026/04/24
ARG PTOISA_COMMIT=0af942568a4f2868673da0a35b0f5b64f27a20d5
# For updated FA style
# https://gitcode.com/cann/pto-isa/commit/4e27a104f948e883e0bef44670252381bff794c5?ref=master
ARG PTOISA_COMMIT=4e27a104f948e883e0bef44670252381bff794c5
WORKDIR /sources
RUN git clone https://gitcode.com/cann/pto-isa.git \
RUN git clone https://gitcode.com/cann/pto-isa \
&& cd pto-isa && git checkout $PTOISA_COMMIT

ENV PTO_LIB_PATH=/sources/pto-isa

# cache above layers unrelated to ptoas version change

# change this ununsed arg if need to force rebuild later lines
ARG CACHE_BURST=1

# ARG ARCH=x86_64
ARG ARCH=aarch64
# https://github.com/hw-native-sys/PTOAS/releases/tag/v0.36
# include the split pipes https://github.com/hw-native-sys/PTOAS/pull/606
ARG RELEASE_REPO=hw-native-sys/PTOAS
ARG RELEASE_VER=0.31
ARG RELEASE_VER=0.36
ARG RELEASE_TAG=v${RELEASE_VER}
ARG WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_${ARCH}.whl
ARG CLI_TAR_NAME=ptoas-bin-${ARCH}.tar.gz
Expand Down
6 changes: 3 additions & 3 deletions docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ Recommend using [Ascend Docker Runtime](https://gitcode.com/Ascend/mind-cluster/
Then, build and run docker image:

```bash
RELEASE_VER=0.29
RELEASE_VER=0.36
sudo docker build \
--build-arg RELEASE_VER=$RELEASE_VER \
. -t pto_dsl:$RELEASE_VER
. -t pto_dsl:fa_hack

# for specific arch (x86_64 vs aarch64)
sudo docker build \
Expand All @@ -30,7 +30,7 @@ sudo docker run --rm -it --ipc=host --privileged \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro \
-v /etc/ascend_install.info:/etc/ascend_install.info:ro \
-v $HOME:/mounted_home -w /mounted_home \
pto_dsl:$RELEASE_VER /bin/bash
pto_dsl:fa_hack /bin/bash
```

## Appendix: NPU driver
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_CV_FIFO_SIZE = 8 # CV_FIFO_SIZE
_CUBE_S0 = 128 # CUBE_S0
_SUPPORTED_TILE_S1 = (256, 512, 1024)
_DEFAULT_TILE_S1 = 256
_DEFAULT_TILE_S1 = 512
_MAX_TILE_S1 = max(_SUPPORTED_TILE_S1)


Expand Down
16 changes: 8 additions & 8 deletions examples/aot/flash_attention/cpp_ref/simplified/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def fused_attention(q, k, v, is_causal=False):
return out.squeeze(0)


def test_flash(tile_s1: int = 256, head: int = 128):
def test_flash(tile_s1: int = 512, head: int = 128):
s0 = 128 * 24
s1_values = [1024, 2048, 4096, 8192, 16384, 32768, 64 * 1024, 128 * 1024]
bad_s1 = [s1 for s1 in s1_values if s1 % tile_s1 != 0]
Expand Down Expand Up @@ -156,28 +156,28 @@ def test_flash(tile_s1: int = 256, head: int = 128):
print(f"Tile S1 : {tile_s1}")
print(f"FLOPs total : {flops_total}")
print(
f"JIT flash kernel : {flash_ms:.3f} ms/iter "
f"PTO custom FlashAttention : {flash_ms:.3f} ms/iter "
f"({tflops(flops_total, flash_ms):.3f} TFLOP/s)"
)
print(
f"npu_fused_infer_attention : {npu_ms:.3f} ms/iter "
f"({tflops(flops_total, npu_ms):.3f} TFLOP/s)"
)
print(
f"torch reference : {ref_ms:.3f} ms/iter "
f"PyTorch Eager Reference : {ref_ms:.3f} ms/iter "
f"({tflops(flops_total, ref_ms):.3f} TFLOP/s)"
)
torch.testing.assert_close(o_out, o_ref, rtol=1e-3, atol=1e-3)
print("vs torch reference: PASSED")
torch.testing.assert_close(o_out, o_npu, rtol=1e-3, atol=1e-3)
print("vs npu_fused_attention: PASSED")
print("vs npu_fused_infer_attention_score: PASSED")
print("")

plot_path = Path(__file__).with_name("fa_compile_and_run_s1_plot.png")
plt.figure(figsize=(8, 5))
plt.plot(s1_values, flash_tflops_values, marker="o", label="flash")
plt.plot(s1_values, ref_tflops_values, marker="o", label="ref")
plt.plot(s1_values, npu_tflops_values, marker="o", label="torch_npu")
plt.plot(s1_values, flash_tflops_values, marker="o", label="PTO custom FlashAttention")
plt.plot(s1_values, ref_tflops_values, marker="o", label="PyTorch Eager Reference")
plt.plot(s1_values, npu_tflops_values, marker="o", label="torch_npu.npu_fused_infer_attention_score")
plt.xscale("log", base=2)
plt.xticks(s1_values, [str(v) for v in s1_values])
plt.xlabel("S1")
Expand All @@ -195,7 +195,7 @@ def test_flash(tile_s1: int = 256, head: int = 128):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tile-s1", type=int, choices=(256, 512, 1024), default=256)
parser.add_argument("--tile-s1", type=int, choices=(256, 512, 1024), default=512)
parser.add_argument("--head", type=int, choices=(32, 64, 128), default=128)
args = parser.parse_args()
test_flash(tile_s1=args.tile_s1, head=args.head)
11 changes: 11 additions & 0 deletions examples/aot/flash_attention/cpp_ref/split_pipe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Try different tile sizes (512 gets highest TFLOPs for long sequence)

```bash
export PTODSL_TEST_DEVICE_ID=7

python ./run.py --tile-s1 512 # default
python ./run.py --tile-s1 256 # slower for long seq
python ./run.py --tile-s1 1024 # wrong result
```

Reference outputs in [./results](./results)
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
Copyright (c) 2026 Huawei Technologies Co., Ltd.
*/

#include <acl/acl.h>
#include <cstdint>
#include <cstdlib>

#include "fa_performance_kernel.h"
#include "generated_cases.h"
#include "runtime/rt.h"

extern "C" void call_kernel(void *stream, int headSize, int s0, int s1, int tile_s1, bool is_causal, uint8_t *q,
uint8_t *k, uint8_t *v, uint8_t *o_out, float *qk_tile_fifo, uint16_t *p_tile_fifo,
float *exp_max_ififo, float *pv_tile_fifo, float *global_sum_out, float *exp_max_out,
float *o_parts_out)
{
if (is_causal) {
return;
}

uint64_t ffts_val = 0;
uint32_t ffts_len = 0;
rtGetC2cCtrlAddr(&ffts_val, &ffts_len);
auto *ffts = reinterpret_cast<uint16_t *>(static_cast<uintptr_t>(ffts_val));

uint8_t *cv_comm_buf = nullptr;

#define LAUNCH_DISPATCH(S0_, HEAD_, S1_, CUBE_S0_, CUBE_S1_, TILE_S1_, QK_PRELOAD_, CAUSAL_MASK_) \
if (headSize == (HEAD_) && (s0) == (S0_) && (s1) == (S1_) && tile_s1 == (TILE_S1_)) { \
LaunchTFA<(S0_), (HEAD_), (S1_), (CUBE_S0_), (CUBE_S1_), (TILE_S1_), (QK_PRELOAD_), kFaCvFifoSize, false, \
(CAUSAL_MASK_), kFaCvFifoConsSyncPeriod>( \
ffts, reinterpret_cast<aclFloat16 *>(q), reinterpret_cast<aclFloat16 *>(k), \
reinterpret_cast<aclFloat16 *>(v), reinterpret_cast<aclFloat16 *>(p_tile_fifo), exp_max_ififo, \
global_sum_out, exp_max_out, reinterpret_cast<float *>(o_out), o_parts_out, qk_tile_fifo, pv_tile_fifo, \
reinterpret_cast<aclrtStream>(stream), cv_comm_buf); \
return; \
}

TFA_FOR_EACH_CASE(LAUNCH_DISPATCH);

#undef LAUNCH_DISPATCH
}
71 changes: 71 additions & 0 deletions examples/aot/flash_attention/cpp_ref/split_pipe/generated_cases.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once
// Auto-generated by scripts/generate_cases.py. Do not edit manually.
// clang-format off
#include <cstddef>

#define TFA_FOR_EACH_CASE(MACRO) \
MACRO(3072, 128, 1024, 128, 128, 256, 4, false) \
MACRO(3072, 128, 1024, 128, 128, 512, 4, false) \
MACRO(3072, 128, 1024, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 2048, 128, 128, 256, 4, false) \
MACRO(3072, 128, 2048, 128, 128, 512, 4, false) \
MACRO(3072, 128, 2048, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 4096, 128, 128, 256, 4, false) \
MACRO(3072, 128, 4096, 128, 128, 512, 4, false) \
MACRO(3072, 128, 4096, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 8192, 128, 128, 256, 4, false) \
MACRO(3072, 128, 8192, 128, 128, 512, 4, false) \
MACRO(3072, 128, 8192, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 16384, 128, 128, 256, 4, false) \
MACRO(3072, 128, 16384, 128, 128, 512, 4, false) \
MACRO(3072, 128, 16384, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 32768, 128, 128, 256, 4, false) \
MACRO(3072, 128, 32768, 128, 128, 512, 4, false) \
MACRO(3072, 128, 32768, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 65536, 128, 128, 256, 4, false) \
MACRO(3072, 128, 65536, 128, 128, 512, 4, false) \
MACRO(3072, 128, 65536, 128, 128, 1024, 4, false) \
MACRO(3072, 128, 131072, 128, 128, 256, 4, false) \
MACRO(3072, 128, 131072, 128, 128, 512, 4, false) \
MACRO(3072, 128, 131072, 128, 128, 1024, 4, false)

struct GeneratedTfaCase {
int s0;
int head_size;
int s1;
int cube_s0;
int cube_s1;
int tile_s1;
int qk_preload;
bool causal_mask;
const char *name;
};

static constexpr GeneratedTfaCase kGeneratedTfaCases[] = {
{3072, 128, 1024, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_1024"},
{3072, 128, 1024, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_1024"},
{3072, 128, 1024, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_1024"},
{3072, 128, 2048, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_2048"},
{3072, 128, 2048, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_2048"},
{3072, 128, 2048, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_2048"},
{3072, 128, 4096, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_4096"},
{3072, 128, 4096, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_4096"},
{3072, 128, 4096, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_4096"},
{3072, 128, 8192, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_8192"},
{3072, 128, 8192, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_8192"},
{3072, 128, 8192, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_8192"},
{3072, 128, 16384, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_16384"},
{3072, 128, 16384, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_16384"},
{3072, 128, 16384, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_16384"},
{3072, 128, 32768, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_32768"},
{3072, 128, 32768, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_32768"},
{3072, 128, 32768, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_32768"},
{3072, 128, 65536, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_65536"},
{3072, 128, 65536, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_65536"},
{3072, 128, 65536, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_65536"},
{3072, 128, 131072, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_131072"},
{3072, 128, 131072, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_131072"},
{3072, 128, 131072, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_131072"}
};
static constexpr std::size_t kGeneratedTfaCasesCount = sizeof(kGeneratedTfaCases) / sizeof(kGeneratedTfaCases[0]);
// clang-format on
Loading